diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..19a4d98b37092ac8fa071c93f9a30f2928e07961 Binary files /dev/null and b/.DS_Store differ diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..503e09a0f3b58199172c4535309492f602de28c0 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +resources/web-demo.gif filter=lfs diff=lfs merge=lfs -text diff --git a/.github/ISSUE_TEMPLATE/bug_report.yaml b/.github/ISSUE_TEMPLATE/bug_report.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5798ad0202d7650fe4a8a59fbb949fa12e9d279a --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.yaml @@ -0,0 +1,63 @@ +name: 🐞 Bug/Help +description: File a bug/issue +title: "[BUG/Help] " +labels: [] +body: +- type: checkboxes + attributes: + label: Is there an existing issue for this? + description: Please search to see if an issue already exists for the bug you encountered. + options: + - label: I have searched the existing issues + required: true +- type: textarea + attributes: + label: Current Behavior + description: | + A concise description of what you're experiencing, with screenshot attached if possible. + Tip: You can attach images or log files by clicking this area to highlight it and then dragging files in. + validations: + required: true +- type: textarea + attributes: + label: Expected Behavior + description: A concise description of what you expected to happen. + validations: + required: false +- type: textarea + attributes: + label: Steps To Reproduce + description: Steps to reproduce the behavior. + placeholder: | + 1. In this environment... + 2. With this config... + 3. Run '...' + 4. See error... + validations: + required: true +- type: textarea + attributes: + label: Environment + description: | + examples: + - **OS**: Ubuntu 20.04 + - **Python**: 3.8 + - **Transformers**: 4.26.1 + - **PyTorch**: 1.12 + - **CUDA Support**: True + value: | + - OS: + - Python: + - Transformers: + - PyTorch: + - CUDA Support (`python -c "import torch; print(torch.cuda.is_available())"`) : + render: markdown + validations: + required: true +- type: textarea + attributes: + label: Anything else? + description: | + Links? References? Anything that will give us more context about the issue you are encountering! + validations: + required: false diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 0000000000000000000000000000000000000000..ec4bb386bcf8a4946923eff961cc7cdf70c0aedf --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1 @@ +blank_issues_enabled: false \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml new file mode 100644 index 0000000000000000000000000000000000000000..19725d67fe35d310a629f542a1f40504739f45f5 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -0,0 +1,26 @@ +name: Feature request +description: Suggest an idea for this project +title: "[Feature] <title>" +labels: [] +body: +- type: textarea + attributes: + label: Is your feature request related to a problem? Please describe. + description: | + A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 
+ validations: + required: false +- type: textarea + attributes: + label: Solutions + description: | + Describe the solution you'd like + A clear and concise description of what you want to happen. + validations: + required: true +- type: textarea + attributes: + label: Additional context + description: Add any other context or screenshots about the feature request here. + validations: + required: false diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..52a7f7602c46a2af4c2ff8b04f7c21051bd58762 --- /dev/null +++ b/.gitignore @@ -0,0 +1,134 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST +history/ + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# Mac system file +model/ +.idea \ No newline at end of file diff --git a/FAQ.md b/FAQ.md new file mode 100644 index 0000000000000000000000000000000000000000..cdaf4317791f6d76f25686326ea41aa36f97d3f0 --- /dev/null +++ b/FAQ.md @@ -0,0 +1,15 @@ +## Q1 + +**Mac直接加载量化后的模型出现提示 `clang: error: unsupported option '-fopenmp'** + +这是由于Mac由于本身缺乏omp导致的,此时可运行但是单核。需要单独安装 openmp 依赖,即可在Mac下使用OMP: + +```bash +# 参考`https://mac.r-project.org/openmp/` +## 假设: gcc(clang)是14.x版本,其他版本见R-Project提供的表格 +curl -O https://mac.r-project.org/openmp/openmp-14.0.6-darwin20-Release.tar.gz +sudo tar fvxz openmp-14.0.6-darwin20-Release.tar.gz -C / +``` +此时会安装下面几个文件:`/usr/local/lib/libomp.dylib`, `/usr/local/include/ompt.h`, `/usr/local/include/omp.h`, `/usr/local/include/omp-tools.h`。 + +> 注意:如果你之前运行`ChatGLM`项目失败过,最好清一下Huggingface的缓存,i.e. 
默认下是 `rm -rf ${HOME}/.cache/huggingface/modules/transformers_modules/chatglm-6b-int4`。由于使用了`rm`命令,请明确知道自己在删除什么。 \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..ac4aee55c0698300d21541d5395d452016585f7a --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." 
+ + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. 
+ + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. 
We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright Zhengxiao Du + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/MODEL_LICENSE b/MODEL_LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..1b8899b0fb6d5389dca0ff15a20e98998ca90984 --- /dev/null +++ b/MODEL_LICENSE @@ -0,0 +1,65 @@ +The ChatGLM-6B License + +1. 定义 + +“许可方”是指分发其软件的 ChatGLM-6B 模型团队。 + +“软件”是指根据本许可提供的 ChatGLM-6B 模型参数。(不包括二代模型 ChatGLM2-6B 以及后续模型) + +2. 许可授予 + +根据本许可的条款和条件,许可方特此授予您非排他性、全球性、不可转让、不可再许可、可撤销、免版税的版权许可。 + +上述版权声明和本许可声明应包含在本软件的所有副本或重要部分中。 + +3.限制 + +您不得出于任何军事或非法目的使用、复制、修改、合并、发布、分发、复制或创建本软件的全部或部分衍生作品。 + +您不得利用本软件从事任何危害国家安全和国家统一、危害社会公共利益、侵犯人身权益的行为。 + +4.免责声明 + +本软件“按原样”提供,不提供任何明示或暗示的保证,包括但不限于对适销性、特定用途的适用性和非侵权性的保证。 在任何情况下,作者或版权持有人均不对任何索赔、损害或其他责任负责,无论是在合同诉讼、侵权行为还是其他方面,由软件或软件的使用或其他交易引起、由软件引起或与之相关 软件。 + +5. 责任限制 + +除适用法律禁止的范围外,在任何情况下且根据任何法律理论,无论是基于侵权行为、疏忽、合同、责任或其他原因,任何许可方均不对您承担任何直接、间接、特殊、偶然、示范性、 或间接损害,或任何其他商业损失,即使许可人已被告知此类损害的可能性。 + +6.争议解决 + +本许可受中华人民共和国法律管辖并按其解释。 因本许可引起的或与本许可有关的任何争议应提交北京市海淀区人民法院。 + +请注意,许可证可能会更新到更全面的版本。 有关许可和版权的任何问题,请通过 license@zhipuai.cn 与我们联系。 + +1. Definitions + +“Licensor” means the ChatGLM-6B Model Team that distributes its Software. + +“Software” means the ChatGLM-6B model parameters made available under this license (does not include the second-generation model ChatGLM2-6B and subsequent models). + +2. License Grant + +Subject to the terms and conditions of this License, the Licensor hereby grants to you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license to use the Software. + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +3. Restriction + +You will not use, copy, modify, merge, publish, distribute, reproduce, or create derivative works of the Software, in whole or in part, for any military, or illegal purposes. + +You will not use the Software for any act that may undermine China's national security and national unity, harm the public interest of society, or infringe upon the rights and interests of human beings. + +4. Disclaimer + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +5. 
Limitation of Liability + +EXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER BASED IN TORT, NEGLIGENCE, CONTRACT, LIABILITY, OR OTHERWISE WILL ANY LICENSOR BE LIABLE TO YOU FOR ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES, OR ANY OTHER COMMERCIAL LOSSES, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. + +6. Dispute Resolution + +This license shall be governed and construed in accordance with the laws of People’s Republic of China. Any dispute arising from or in connection with this License shall be submitted to Haidian District People's Court in Beijing. + +Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at license@zhipuai.cn. diff --git a/PROJECT.md b/PROJECT.md new file mode 100644 index 0000000000000000000000000000000000000000..334ddcc89b3a3a6bddfe450e91fb36a9059700ca --- /dev/null +++ b/PROJECT.md @@ -0,0 +1,37 @@ +# 友情链接 + +对 ChatGLM 进行加速或者重新实现的开源项目: +* [lyraChatGLM](https://huggingface.co/TMElyralab/lyraChatGLM): 对 ChatGLM-6B 进行推理加速,最高可以实现 9000+ tokens/s 的推理速度 +* [SwissArmyTransformer](https://github.com/THUDM/SwissArmyTransformer): 一个Transformer统一编程框架,ChatGLM-6B已经在SAT中进行实现并可以进行P-tuning微调。 +* [ChatGLM-MNN](https://github.com/wangzhaode/ChatGLM-MNN): 一个基于 MNN 的 ChatGLM-6B C++ 推理实现,支持根据显存大小自动分配计算任务给 GPU 和 CPU +* [JittorLLMs](https://github.com/Jittor/JittorLLMs):最低3G显存或者没有显卡都可运行 ChatGLM-6B FP16, 支持Linux、windows、Mac部署 +* [InferLLM](https://github.com/MegEngine/InferLLM):轻量级 C++ 推理,可以实现本地 x86,Arm 处理器上实时聊天,手机上也同样可以实时运行,运行内存只需要 4G + + + +基于或使用了 ChatGLM-6B 的开源项目: +* [chatgpt_academic](https://github.com/binary-husky/chatgpt_academic): 支持ChatGLM-6B的学术写作与编程工具箱,具有模块化和多线程调用LLM的特点,可并行调用多种LLM。 +* [闻达](https://github.com/l15y/wenda):大型语言模型调用平台,基于 ChatGLM-6B 实现了类 ChatPDF 功能 +* [glm-bot](https://github.com/initialencounter/glm-bot):将ChatGLM接入Koishi可在各大聊天平台上调用ChatGLM +* [Chinese-LangChain](https://github.com/yanqiangmiffy/Chinese-LangChain):中文langchain项目,基于ChatGLM-6b+langchain实现本地化知识库检索与智能答案生成,增加web search功能、知识库选择功能和支持知识增量更新 +* [bibliothecarius](https://github.com/coderabbit214/bibliothecarius):快速构建服务以集成您的本地数据和AI模型,支持ChatGLM等本地化模型接入。 +* [langchain-ChatGLM](https://github.com/imClumsyPanda/langchain-ChatGLM):基于 langchain 的 ChatGLM 应用,实现基于可扩展知识库的问答 +* [ChatGLM-web](https://github.com/NCZkevin/chatglm-web):基于FastAPI和Vue3搭建的ChatGLM演示网站(支持chatglm流式输出、前端调整模型参数、上下文选择、保存图片、知识库问答等功能) +* [Chuanhu Chat](https://github.com/GaiZhenbiao/ChuanhuChatGPT): 为各个大语言模型和在线模型API提供美观易用、功能丰富、快速部署的用户界面,支持ChatGLM-6B。 +* [ChatGLM-6B-Engineering](https://github.com/LemonQu-GIT/ChatGLM-6B-Engineering):基于 ChatGLM-6B 后期调教,网络爬虫及 [Stable Diffusion](https://github.com/AUTOMATIC1111/stable-diffusion-webui) 实现的网络搜索及图片生成 +* [ChatGLM-OpenAI-API](https://github.com/ninehills/chatglm-openai-api): 将 ChatGLM-6B 封装为 OpenAI API 风格,并通过 ngrok/cloudflare 对外提供服务,从而将 ChatGLM 快速集成到 OpenAI 的各种生态中。 +* [ChatSQL](https://github.com/cubenlp/ChatSQL): 基于ChatGLM+SBERT实现NL2SQL本地化,并直接连接数据库查询数据返回结果,使得生成的SQL语句更具有实用性。 + +对 ChatGLM-6B 进行微调的开源项目: +* [InstructGLM](https://github.com/yanqiangmiffy/InstructGLM):基于ChatGLM-6B进行指令学习,汇总开源中英文指令数据,基于Lora进行指令数据微调,开放了Alpaca、Belle微调后的Lora权重,修复web_demo重复问题 +* [ChatGLM-Efficient-Tuning](https://github.com/hiyouga/ChatGLM-Efficient-Tuning):实现了ChatGLM-6B模型的监督微调和完整RLHF训练,汇总10余种指令数据集和3种微调方案,实现了4/8比特量化和模型权重融合,提供微调模型快速部署方法。 +* 
[ChatGLM-Finetuning](https://github.com/liucongg/ChatGLM-Finetuning):基于ChatGLM-6B模型,进行下游具体任务微调,涉及Freeze、Lora、P-tuning等,并进行实验效果对比。 +* [ChatGLM-Tuning](https://github.com/mymusise/ChatGLM-Tuning): 基于 LoRA 对 ChatGLM-6B 进行微调。类似的项目还包括 [Humanable ChatGLM/GPT Fine-tuning | ChatGLM 微调](https://github.com/hscspring/hcgf) + + +针对 ChatGLM-6B 的教程/文档: +* [Windows部署文档](https://github.com/ZhangErling/ChatGLM-6B/blob/main/deployment_windows.md) +* [搭建深度学习docker容器以运行 ChatGLM-6B - Luck_zy](https://www.luckzym.com/tags/ChatGLM-6B/) + +如果你有其他好的项目/教程的话,欢迎参照上述格式添加到 README 中并提出 [Pull Request](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork)。 + diff --git a/README.md b/README.md index 1785d3a77b6ad02d42a61b5458e2f5c5b24120a1..acd8ab4540a75ec6c77ca8e283a03f796aca3b1f 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,397 @@ --- title: FORAI -emoji: 🚀 -colorFrom: purple -colorTo: blue +app_file: web_demo_old.py sdk: gradio sdk_version: 3.40.1 -app_file: app.py -pinned: false --- +# ChatGLM-6B -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +<p align="center"> + 🌐 <a href="https://chatglm.cn/blog" target="_blank">Blog</a> • 🤗 <a href="https://huggingface.co/THUDM/chatglm-6b" target="_blank">HF Repo</a> • 🐦 <a href="https://twitter.com/thukeg" target="_blank">Twitter</a> • 📃 <a href="https://arxiv.org/abs/2103.10360" target="_blank">[GLM@ACL 22]</a> <a href="https://github.com/THUDM/GLM" target="_blank">[GitHub]</a> • 📃 <a href="https://arxiv.org/abs/2210.02414" target="_blank">[GLM-130B@ICLR 23]</a> <a href="https://github.com/THUDM/GLM-130B" target="_blank">[GitHub]</a> <br> +</p> +<p align="center"> + 👋 加入我们的 <a href="https://join.slack.com/t/chatglm/shared_invite/zt-1y7pqoloy-9b1g6T6JjA8J0KxvUjbwJw" target="_blank">Slack</a> 和 <a href="resources/WECHAT.md" target="_blank">WeChat</a> +</p> + +*Read this in [English](README_en.md).* + +## 介绍 + +ChatGLM-6B 是一个开源的、支持中英双语的对话语言模型,基于 [General Language Model (GLM)](https://github.com/THUDM/GLM) 架构,具有 62 亿参数。结合模型量化技术,用户可以在消费级的显卡上进行本地部署(INT4 量化级别下最低只需 6GB 显存)。 +ChatGLM-6B 使用了和 ChatGPT 相似的技术,针对中文问答和对话进行了优化。经过约 1T 标识符的中英双语训练,辅以监督微调、反馈自助、人类反馈强化学习等技术的加持,62 亿参数的 ChatGLM-6B 已经能生成相当符合人类偏好的回答,更多信息请参考我们的[博客](https://chatglm.cn/blog)。欢迎通过 [chatglm.cn](https://chatglm.cn) 体验更大规模的 ChatGLM 模型。 + +为了方便下游开发者针对自己的应用场景定制模型,我们同时实现了基于 [P-Tuning v2](https://github.com/THUDM/P-tuning-v2) 的高效参数微调方法 [(使用指南)](ptuning/README.md) ,INT4 量化级别下最低只需 7GB 显存即可启动微调。 + +ChatGLM-6B 权重对学术研究**完全开放**,在填写[问卷](https://open.bigmodel.cn/mla/form)进行登记后**亦允许免费商业使用**。 + +想让 ChatGLM-6B 更符合你的应用场景?欢迎参与 [Badcase 反馈计划](improve/README.md)。 + +----- + +ChatGLM-6B 开源模型旨在与开源社区一起推动大模型技术发展,恳请开发者和大家遵守[开源协议](MODEL_LICENSE),勿将开源模型和代码及基于开源项目产生的衍生物用于任何可能给国家和社会带来危害的用途以及用于任何未经过安全评估和备案的服务。**目前,本项目团队未基于 ChatGLM-6B 开发任何应用,包括网页端、安卓、苹果 iOS 及 Windows App 等应用。** + +尽管模型在训练的各个阶段都尽力确保数据的合规性和准确性,但由于 ChatGLM-6B 模型规模较小,且模型受概率随机性因素影响,无法保证输出内容的准确性,且模型易被误导(详见[局限性](README.md#局限性))。**本项目不承担开源模型和代码导致的数据安全、舆情风险或发生任何模型被误导、滥用、传播、不当利用而产生的风险和责任。** + +## 更新信息 +**[2023/07/25]** 发布 [CodeGeeX2](https://github.com/THUDM/CodeGeeX2) ,基于 ChatGLM2-6B 的代码生成模型,代码能力全面提升,更多特性包括: + +* **更强大的代码能力**:CodeGeeX2-6B 进一步经过了 600B 代码数据预训练,相比 CodeGeeX 一代模型,在代码能力上全面提升,[HumanEval-X](https://huggingface.co/datasets/THUDM/humaneval-x) 评测集的六种编程语言均大幅提升 (Python +57%, C++ +71%, Java +54%, JavaScript +83%, Go +56%, Rust +321\%),在Python上达到 35.9\% 的 Pass@1 一次通过率,超越规模更大的 StarCoder-15B。 +* **更优秀的模型特性**:继承 ChatGLM2-6B 模型特性,CodeGeeX2-6B 
更好支持中英文输入,支持最大 8192 序列长度,推理速度较一代 大幅提升,量化后仅需6GB显存即可运行,支持轻量级本地化部署。 +* **更全面的AI编程助手**:CodeGeeX插件([VS Code](https://marketplace.visualstudio.com/items?itemName=aminer.codegeex), [Jetbrains](https://plugins.jetbrains.com/plugin/20587-codegeex))后端升级,支持超过100种编程语言,新增上下文补全、跨文件补全等实用功能。结合 Ask CodeGeeX 交互式AI编程助手,支持中英文对话解决各种编程问题,包括且不限于代码解释、代码翻译、代码纠错、文档生成等,帮助程序员更高效开发。 + +**[2023/06/25]** 发布 [ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B),ChatGLM-6B 的升级版本,在保留了了初代模型对话流畅、部署门槛较低等众多优秀特性的基础之上,ChatGLM**2**-6B 引入了如下新特性: + +1. **更强大的性能**:基于 ChatGLM 初代模型的开发经验,我们全面升级了 ChatGLM2-6B 的基座模型。ChatGLM2-6B 使用了 [GLM](https://github.com/THUDM/GLM) 的混合目标函数,经过了 1.4T 中英标识符的预训练与人类偏好对齐训练,[评测结果](#评测结果)显示,相比于初代模型,ChatGLM2-6B 在 MMLU(+23%)、CEval(+33%)、GSM8K(+571%) 、BBH(+60%)等数据集上的性能取得了大幅度的提升,在同尺寸开源模型中具有较强的竞争力。 +2. **更长的上下文**:基于 [FlashAttention](https://github.com/HazyResearch/flash-attention) 技术,我们将基座模型的上下文长度(Context Length)由 ChatGLM-6B 的 2K 扩展到了 32K,并在对话阶段使用 8K 的上下文长度训练,允许更多轮次的对话。但当前版本的 ChatGLM2-6B 对单轮超长文档的理解能力有限,我们会在后续迭代升级中着重进行优化。 +3. **更高效的推理**:基于 [Multi-Query Attention](http://arxiv.org/abs/1911.02150) 技术,ChatGLM2-6B 有更高效的推理速度和更低的显存占用:在官方的模型实现下,推理速度相比初代提升了 42%,INT4 量化下,6G 显存支持的对话长度由 1K 提升到了 8K。 + +更多信息参见 [ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B)。 + +**[2023/06/14]** 发布 [WebGLM](https://github.com/THUDM/WebGLM),一项被接受于KDD 2023的研究工作,支持利用网络信息生成带有准确引用的长回答。 + +![](resources/webglm.jpg) + +**[2023/05/17]** 发布 [VisualGLM-6B](https://github.com/THUDM/VisualGLM-6B),一个支持图像理解的多模态对话语言模型。 + +![](resources/visualglm.png) + +可以通过本仓库中的 [cli_demo_vision.py](cli_demo_vision.py) 和 [web_demo_vision.py](web_demo_vision.py) 来运行命令行和网页 Demo。注意 VisualGLM-6B 需要额外安装 [SwissArmyTransformer](https://github.com/THUDM/SwissArmyTransformer/) 和 torchvision。更多信息参见 [VisualGLM-6B](https://github.com/THUDM/VisualGLM-6B)。 + +**[2023/05/15]** 更新 v1.1 版本 checkpoint,训练数据增加英文指令微调数据以平衡中英文数据比例,解决英文回答中夹杂中文词语的现象。 + +<details><summary><b>以下是更新前后的英文问题对比:</b></summary> + +* 问题:Describe a time when you had to make a difficult decision. + - v1.0: + ![](resources/english-q1-old.png) + - v1.1: + ![](resources/english-q1-new.png) +* 问题:Describe the function of a computer motherboard + - v1.0: + ![](resources/english-q2-old.png) + - v1.1: + ![](resources/english-q2-new.png) +* 问题:Develop a plan to reduce electricity usage in a home. 
+ - v1.0: + ![](resources/english-q3-old.png) + - v1.1: + ![](resources/english-q3-new.png) +* 问题:未来的NFT,可能真实定义一种现实的资产,它会是一处房产,一辆汽车,一片土地等等,这样的数字凭证可能比真实的东西更有价值,你可以随时交易和使用,在虚拟和现实中无缝的让拥有的资产继续创造价值,未来会是万物归我所用,但不归我所有的时代。翻译成专业的英语 + - v1.0: + ![](resources/english-q4-old.png) + - v1.1: + ![](resources/english-q4-new.png) +</details> + +更多更新信息参见 [UPDATE.md](UPDATE.md) + +## 友情链接 +对 ChatGLM 进行加速的开源项目: +* [lyraChatGLM](https://huggingface.co/TMElyralab/lyraChatGLM): 对 ChatGLM-6B 进行推理加速,最高可以实现 9000+ tokens/s 的推理速度 +* [ChatGLM-MNN](https://github.com/wangzhaode/ChatGLM-MNN): 一个基于 MNN 的 ChatGLM-6B C++ 推理实现,支持根据显存大小自动分配计算任务给 GPU 和 CPU +* [JittorLLMs](https://github.com/Jittor/JittorLLMs):最低3G显存或者没有显卡都可运行 ChatGLM-6B FP16, 支持Linux、windows、Mac部署 +* [InferLLM](https://github.com/MegEngine/InferLLM):轻量级 C++ 推理,可以实现本地 x86,Arm 处理器上实时聊天,手机上也同样可以实时运行,运行内存只需要 4G + +基于或使用了 ChatGLM-6B 的开源项目: +* [langchain-ChatGLM](https://github.com/imClumsyPanda/langchain-ChatGLM):基于 langchain 的 ChatGLM 应用,实现基于可扩展知识库的问答 +* [闻达](https://github.com/l15y/wenda):大型语言模型调用平台,基于 ChatGLM-6B 实现了类 ChatPDF 功能 +* [glm-bot](https://github.com/initialencounter/glm-bot):将ChatGLM接入Koishi可在各大聊天平台上调用ChatGLM +* [Chuanhu Chat](https://github.com/GaiZhenbiao/ChuanhuChatGPT): 为各个大语言模型和在线模型API提供美观易用、功能丰富、快速部署的用户界面,支持ChatGLM-6B。 + +支持 ChatGLM-6B 和相关应用在线训练的示例项目: +* [ChatGLM-6B 的部署与微调教程](https://www.heywhale.com/mw/project/6436d82948f7da1fee2be59e) +* [ChatGLM-6B 结合 langchain 实现本地知识库 QA Bot](https://www.heywhale.com/mw/project/643977aa446c45f4592a1e59) + +第三方评测: +* [Measuring Massive Multitask Chinese Understanding](https://arxiv.org/abs/2304.12986) + +更多开源项目参见 [PROJECT.md](PROJECT.md) + +## 使用方式 + +### 硬件需求 + +| **量化等级** | **最低 GPU 显存**(推理) | **最低 GPU 显存**(高效参数微调) | +| -------------- | ------------------------- | --------------------------------- | +| FP16(无量化) | 13 GB | 14 GB | +| INT8 | 8 GB | 9 GB | +| INT4 | 6 GB | 7 GB | +### 环境安装 + +使用 pip 安装依赖:`pip install -r requirements.txt`,其中 `transformers` 库版本推荐为 `4.27.1`,但理论上不低于 `4.23.1` 即可。 + +此外,如果需要在 cpu 上运行量化后的模型,还需要安装 `gcc` 与 `openmp`。多数 Linux 发行版默认已安装。对于 Windows ,可在安装 [TDM-GCC](https://jmeubank.github.io/tdm-gcc/) 时勾选 `openmp`。 Windows 测试环境 `gcc` 版本为 `TDM-GCC 10.3.0`, Linux 为 `gcc 11.3.0`。在 MacOS 上请参考 [Q1](FAQ.md#q1)。 + +### 代码调用 + +可以通过如下代码调用 ChatGLM-6B 模型来生成对话: + +```python +>>> from transformers import AutoTokenizer, AutoModel +>>> tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) +>>> model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda() +>>> model = model.eval() +>>> response, history = model.chat(tokenizer, "你好", history=[]) +>>> print(response) +你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。 +>>> response, history = model.chat(tokenizer, "晚上睡不着应该怎么办", history=history) +>>> print(response) +晚上睡不着可能会让你感到焦虑或不舒服,但以下是一些可以帮助你入睡的方法: + +1. 制定规律的睡眠时间表:保持规律的睡眠时间表可以帮助你建立健康的睡眠习惯,使你更容易入睡。尽量在每天的相同时间上床,并在同一时间起床。 +2. 创造一个舒适的睡眠环境:确保睡眠环境舒适,安静,黑暗且温度适宜。可以使用舒适的床上用品,并保持房间通风。 +3. 放松身心:在睡前做些放松的活动,例如泡个热水澡,听些轻柔的音乐,阅读一些有趣的书籍等,有助于缓解紧张和焦虑,使你更容易入睡。 +4. 避免饮用含有咖啡因的饮料:咖啡因是一种刺激性物质,会影响你的睡眠质量。尽量避免在睡前饮用含有咖啡因的饮料,例如咖啡,茶和可乐。 +5. 避免在床上做与睡眠无关的事情:在床上做些与睡眠无关的事情,例如看电影,玩游戏或工作等,可能会干扰你的睡眠。 +6. 
尝试呼吸技巧:深呼吸是一种放松技巧,可以帮助你缓解紧张和焦虑,使你更容易入睡。试着慢慢吸气,保持几秒钟,然后缓慢呼气。 + +如果这些方法无法帮助你入睡,你可以考虑咨询医生或睡眠专家,寻求进一步的建议。 +``` +模型的实现仍然处在变动中。如果希望固定使用的模型实现以保证兼容性,可以在 `from_pretrained` 的调用中增加 `revision="v1.1.0"` 参数。`v1.1.0` 是当前最新的版本号,完整的版本列表参见 [Change Log](https://huggingface.co/THUDM/chatglm-6b#change-log)。 + +### 从本地加载模型 +以上代码会由 `transformers` 自动下载模型实现和参数。完整的模型实现可以在 [Hugging Face Hub](https://huggingface.co/THUDM/chatglm-6b)。如果你的网络环境较差,下载模型参数可能会花费较长时间甚至失败。此时可以先将模型下载到本地,然后从本地加载。 + +从 Hugging Face Hub 下载模型需要先[安装Git LFS](https://docs.github.com/zh/repositories/working-with-files/managing-large-files/installing-git-large-file-storage),然后运行 +```Shell +git clone https://huggingface.co/THUDM/chatglm-6b +``` + +如果你从 Hugging Face Hub 上下载 checkpoint 的速度较慢,可以只下载模型实现 +```Shell +GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/THUDM/chatglm-6b +``` +然后从[这里](https://cloud.tsinghua.edu.cn/d/fb9f16d6dc8f482596c2/)手动下载模型参数文件,并将下载的文件替换到本地的 `chatglm-6b` 目录下。 + +将模型下载到本地之后,将以上代码中的 `THUDM/chatglm-6b` 替换为你本地的 `chatglm-6b` 文件夹的路径,即可从本地加载模型。 + +**Optional** 模型的实现仍然处在变动中。如果希望固定使用的模型实现以保证兼容性,可以执行 +```Shell +git checkout v1.1.0 +``` + +## Demo & API + +我们提供了一个基于 [Gradio](https://gradio.app) 的网页版 Demo 和一个命令行 Demo。使用时首先需要下载本仓库: + +```shell +git clone https://github.com/THUDM/ChatGLM-6B +cd ChatGLM-6B +``` + +### 网页版 Demo + +![web-demo](resources/web-demo.gif) + +首先安装 Gradio:`pip install gradio`,然后运行仓库中的 [web_demo.py](web_demo.py): + +```shell +python web_demo.py +``` + +程序会运行一个 Web Server,并输出地址。在浏览器中打开输出的地址即可使用。最新版 Demo 实现了打字机效果,速度体验大大提升。注意,由于国内 Gradio 的网络访问较为缓慢,启用 `demo.queue().launch(share=True, inbrowser=True)` 时所有网络会经过 Gradio 服务器转发,导致打字机体验大幅下降,现在默认启动方式已经改为 `share=False`,如有需要公网访问的需求,可以重新修改为 `share=True` 启动。 + +感谢 [@AdamBear](https://github.com/AdamBear) 实现了基于 Streamlit 的网页版 Demo,运行方式见[#117](https://github.com/THUDM/ChatGLM-6B/pull/117). 
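+
+如果希望在官方 Demo 之外自行封装一个简单的网页对话界面,也可以直接基于上文「代码调用」一节中的 `model.chat` 接口实现。下面是一个示意性的最小示例(假设模型与 tokenizer 已按上文方式加载成功,且所安装的 Gradio 版本提供 `gr.ChatInterface`;实际效果与参数以仓库中的 web_demo.py 为准):
+
+```python
+import gradio as gr
+from transformers import AutoTokenizer, AutoModel
+
+# 按「代码调用」一节的方式加载模型(FP16 约需 13GB 显存,显存不足时可参考下文「低成本部署」)
+tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
+model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda().eval()
+
+def predict(message, history):
+    # Gradio 传入的 history 是 [用户输入, 模型回复] 组成的列表,与 model.chat 的 history 格式兼容
+    response, _ = model.chat(tokenizer, message, history=[tuple(pair) for pair in history])
+    return response
+
+# share=False 时仅在本地访问;如需公网访问可改为 share=True(流量会经 Gradio 服务器转发)
+gr.ChatInterface(predict).queue().launch(share=False, inbrowser=True)
+```
+
+与仓库自带的 web_demo.py 不同,该示例没有实现打字机式的流式输出,仅用于说明最小的集成方式。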
+ +### 命令行 Demo + +![cli-demo](resources/cli-demo.png) + +运行仓库中 [cli_demo.py](cli_demo.py): + +```shell +python cli_demo.py +``` + +程序会在命令行中进行交互式的对话,在命令行中输入指示并回车即可生成回复,输入 `clear` 可以清空对话历史,输入 `stop` 终止程序。 + +### API部署 +首先需要安装额外的依赖 `pip install fastapi uvicorn`,然后运行仓库中的 [api.py](api.py): +```shell +python api.py +``` +默认部署在本地的 8000 端口,通过 POST 方法进行调用 +```shell +curl -X POST "http://127.0.0.1:8000" \ + -H 'Content-Type: application/json' \ + -d '{"prompt": "你好", "history": []}' +``` +得到的返回值为 +```shell +{ + "response":"你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。", + "history":[["你好","你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。"]], + "status":200, + "time":"2023-03-23 21:38:40" +} +``` + +## 低成本部署 +### 模型量化 +默认情况下,模型以 FP16 精度加载,运行上述代码需要大概 13GB 显存。如果你的 GPU 显存有限,可以尝试以量化方式加载模型,使用方法如下: + +```python +# 按需修改,目前只支持 4/8 bit 量化 +model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).quantize(8).half().cuda() +``` + +进行 2 至 3 轮对话后,8-bit 量化下 GPU 显存占用约为 10GB,4-bit 量化下仅需 6GB 占用。随着对话轮数的增多,对应消耗显存也随之增长,由于采用了相对位置编码,理论上 ChatGLM-6B 支持无限长的 context-length,但总长度超过 2048(训练长度)后性能会逐渐下降。 + +模型量化会带来一定的性能损失,经过测试,ChatGLM-6B 在 4-bit 量化下仍然能够进行自然流畅的生成。使用 [GPT-Q](https://arxiv.org/abs/2210.17323) 等量化方案可以进一步压缩量化精度/提升相同量化精度下的模型性能,欢迎大家提出对应的 Pull Request。 + +量化过程需要在内存中首先加载 FP16 格式的模型,消耗大概 13GB 的内存。如果你的内存不足的话,可以直接加载量化后的模型,INT4 量化后的模型仅需大概 5.2GB 的内存: +```python +# INT8 量化的模型将"THUDM/chatglm-6b-int4"改为"THUDM/chatglm-6b-int8" +model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).half().cuda() +``` +量化模型的参数文件也可以从[这里](https://cloud.tsinghua.edu.cn/d/674208019e314311ab5c/)手动下载。 + +### CPU 部署 +如果你没有 GPU 硬件的话,也可以在 CPU 上进行推理,但是推理速度会更慢。使用方法如下(需要大概 32GB 内存) +```python +model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).float() +``` + +如果你的内存不足,可以直接加载量化后的模型: +```python +# INT8 量化的模型将"THUDM/chatglm-6b-int4"改为"THUDM/chatglm-6b-int8" +model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4",trust_remote_code=True).float() +``` + +如果遇到了报错 `Could not find module 'nvcuda.dll'` 或者 `RuntimeError: Unknown platform: darwin` (MacOS) ,请[从本地加载模型](README.md#从本地加载模型) + +### Mac 部署 +对于搭载了 Apple Silicon 或者 AMD GPU 的Mac,可以使用 MPS 后端来在 GPU 上运行 ChatGLM-6B。需要参考 Apple 的 [官方说明](https://developer.apple.com/metal/pytorch) 安装 PyTorch-Nightly(正确的版本号应该是2.1.0.dev2023xxxx,而不是2.0.0)。 + +目前在 MacOS 上只支持[从本地加载模型](README.md#从本地加载模型)。将代码中的模型加载改为从本地加载,并使用 mps 后端: +```python +model = AutoModel.from_pretrained("your local path", trust_remote_code=True).half().to('mps') +``` + +加载半精度的 ChatGLM-6B 模型需要大概 13GB 内存。内存较小的机器(比如 16GB 内存的 MacBook Pro),在空余内存不足的情况下会使用硬盘上的虚拟内存,导致推理速度严重变慢。此时可以使用量化后的模型如 chatglm-6b-int4。因为 GPU 上量化的 kernel 是使用 CUDA 编写的,因此无法在 MacOS 上使用,只能使用 CPU 进行推理。 +```python +# INT8 量化的模型将"THUDM/chatglm-6b-int4"改为"THUDM/chatglm-6b-int8" +model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4",trust_remote_code=True).float() +``` +为了充分使用 CPU 并行,还需要[单独安装 OpenMP](FAQ.md#q1)。 + +### 多卡部署 +如果你有多张 GPU,但是每张 GPU 的显存大小都不足以容纳完整的模型,那么可以将模型切分在多张GPU上。首先安装 accelerate: `pip install accelerate`,然后通过如下方法加载模型: +```python +from utils import load_model_on_gpus +model = load_model_on_gpus("THUDM/chatglm-6b", num_gpus=2) +``` +即可将模型部署到两张 GPU 上进行推理。你可以将 `num_gpus` 改为你希望使用的 GPU 数。默认是均匀切分的,你也可以传入 `device_map` 参数来自己指定。 + +## 高效参数微调 +基于 [P-tuning v2](https://github.com/THUDM/P-tuning-v2) 的高效参数微调。具体使用方法详见 [ptuning/README.md](ptuning/README.md)。 + +## ChatGLM-6B 示例 + +以下是一些使用 `web_demo.py` 得到的示例截图。更多 ChatGLM-6B 的可能,等待你来探索发现! 
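+
+除了在网页中体验,下列示例也可以直接用上文「代码调用」一节的 `model.chat` 接口在 Python 中复现。下面是一个示意片段(假设 `model` 与 `tokenizer` 已按上文方式加载完毕;提示词仅为示意,可替换为任意你感兴趣的问题):
+
+```python
+# 示意:用代码方式复现示例对话(假设 model 与 tokenizer 已按「代码调用」一节加载)
+prompts = [
+    "你是谁?",                          # 对应「自我认知」类示例
+    "帮我写一份周末郊游活动的策划提纲",  # 对应「提纲写作」类示例
+]
+
+for prompt in prompts:
+    # 单轮对话传入空的 history;多轮对话可复用上一轮返回的 history
+    response, history = model.chat(tokenizer, prompt, history=[])
+    print(f"问:{prompt}\n答:{response}\n")
+```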
+ +<details><summary><b>自我认知</b></summary> + +![](examples/self-introduction.png) + +</details> + +<details><summary><b>提纲写作</b></summary> + +![](examples/blog-outline.png) + +</details> + +<details><summary><b>文案写作</b></summary> + +![](examples/ad-writing-2.png) + +![](examples/comments-writing.png) + +</details> + +<details><summary><b>邮件写作助手</b></summary> + +![](examples/email-writing-1.png) + +![](examples/email-writing-2.png) + +</details> + +<details><summary><b>信息抽取</b></summary> + +![](examples/information-extraction.png) + +</details> + +<details><summary><b>角色扮演</b></summary> + +![](examples/role-play.png) + +</details> + +<details><summary><b>评论比较</b></summary> + +![](examples/sport.png) + +</details> + +<details><summary><b>旅游向导</b></summary> + +![](examples/tour-guide.png) + +</details> + +## 局限性 + +由于 ChatGLM-6B 的小规模,其能力仍然有许多局限性。以下是我们目前发现的一些问题: + +- 模型容量较小:6B 的小容量,决定了其相对较弱的模型记忆和语言能力。在面对许多事实性知识任务时,ChatGLM-6B 可能会生成不正确的信息;它也不擅长逻辑类问题(如数学、编程)的解答。 + <details><summary><b>点击查看例子</b></summary> + + ![](limitations/factual_error.png) + + ![](limitations/math_error.png) + + </details> + +- 产生有害说明或有偏见的内容:ChatGLM-6B 只是一个初步与人类意图对齐的语言模型,可能会生成有害、有偏见的内容。(内容可能具有冒犯性,此处不展示) + +- 英文能力不足:ChatGLM-6B 训练时使用的指示/回答大部分都是中文的,仅有极小一部分英文内容。因此,如果输入英文指示,回复的质量远不如中文,甚至与中文指示下的内容矛盾,并且出现中英夹杂的情况。 + +- 易被误导,对话能力较弱:ChatGLM-6B 对话能力还比较弱,而且 “自我认知” 存在问题,并很容易被误导并产生错误的言论。例如当前版本的模型在被误导的情况下,会在自我认知上发生偏差。 + <details><summary><b>点击查看例子</b></summary> + + ![](limitations/self-confusion_google.jpg) + + ![](limitations/self-confusion_openai.jpg) + + ![](limitations/self-confusion_tencent.jpg) + + </details> + +## 协议 + +本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源,ChatGLM-6B 模型的权重的使用则需要遵循 [Model License](MODEL_LICENSE)。ChatGLM-6B 权重对学术研究**完全开放**,在填写[问卷](https://open.bigmodel.cn/mla/form)进行登记后**亦允许免费商业使用**。 + +## 引用 + +如果你觉得我们的工作有帮助的话,请考虑引用下列论文 + +``` +@article{zeng2022glm, + title={Glm-130b: An open bilingual pre-trained model}, + author={Zeng, Aohan and Liu, Xiao and Du, Zhengxiao and Wang, Zihan and Lai, Hanyu and Ding, Ming and Yang, Zhuoyi and Xu, Yifan and Zheng, Wendi and Xia, Xiao and others}, + journal={arXiv preprint arXiv:2210.02414}, + year={2022} +} +``` +``` +@inproceedings{du2022glm, + title={GLM: General Language Model Pretraining with Autoregressive Blank Infilling}, + author={Du, Zhengxiao and Qian, Yujie and Liu, Xiao and Ding, Ming and Qiu, Jiezhong and Yang, Zhilin and Tang, Jie}, + booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)}, + pages={320--335}, + year={2022} +} +``` diff --git a/README_en.md b/README_en.md new file mode 100644 index 0000000000000000000000000000000000000000..1e2c12b5254acb47d73ed5309c6a681473bfd331 --- /dev/null +++ b/README_en.md @@ -0,0 +1,356 @@ +# ChatGLM-6B + +<p align="center"> + 🌐 <a href="https://chatglm.cn/blog" target="_blank">Blog</a> • 🤗 <a href="https://huggingface.co/THUDM/chatglm-6b" target="_blank">HF Repo</a> • 🐦 <a href="https://twitter.com/thukeg" target="_blank">Twitter</a> • 📃 <a href="https://arxiv.org/abs/2103.10360" target="_blank">[GLM@ACL 22]</a> <a href="https://github.com/THUDM/GLM" target="_blank">[GitHub]</a> • 📃 <a href="https://arxiv.org/abs/2210.02414" target="_blank">[GLM-130B@ICLR 23]</a> <a href="https://github.com/THUDM/GLM-130B" target="_blank">[GitHub]</a> <br> +</p> +<p align="center"> + 👋 Join our <a href="https://join.slack.com/t/chatglm/shared_invite/zt-1y7pqoloy-9b1g6T6JjA8J0KxvUjbwJw" target="_blank">Slack</a> and <a href="resources/WECHAT.md" target="_blank">WeChat</a> +</p> + +## 
Introduction
+
+ChatGLM-6B is an open bilingual language model based on the [General Language Model (GLM)](https://github.com/THUDM/GLM) framework, with 6.2 billion parameters. With quantization, users can deploy it locally on consumer-grade graphics cards (only 6GB of GPU memory is required at the INT4 quantization level). You are also welcome to use the larger ChatGLM model on [chatglm.cn](https://chatglm.cn).
+
+ChatGLM-6B uses technology similar to ChatGPT, optimized for Chinese QA and dialogue. The model is trained on about 1T tokens of Chinese and English corpus, supplemented by supervised fine-tuning, feedback bootstrap, and reinforcement learning with human feedback. With only about 6.2 billion parameters, the model is able to generate answers that are in line with human preference.
+
+In order to facilitate downstream developers in customizing the model for their own application scenarios, we also implement a parameter-efficient tuning method based on [P-Tuning v2](https://github.com/THUDM/P-tuning-v2) [(Guidelines)](ptuning/README_en.md). Tuning requires at least 7GB of GPU memory at the INT4 quantization level.
+
+ChatGLM-6B weights are **completely open** for academic research, and **free commercial use** is also allowed after completing the [questionnaire](https://open.bigmodel.cn/mla/form).
+
+Try the [online demo](https://huggingface.co/spaces/ysharma/ChatGLM-6b_Gradio_Streaming) on Hugging Face Spaces.
+
+## Update
+**[2023/07/25]** Release [CodeGeeX2](https://github.com/THUDM/CodeGeeX2), which is based on ChatGLM2-6B and trained on more code data. It has the following features:
+
+* **More Powerful Coding Capabilities**: CodeGeeX2-6B has been further pre-trained on 600B code tokens, and its coding capability is comprehensively improved compared to the first-generation model. On the [HumanEval-X](https://huggingface.co/datasets/THUDM/humaneval-x) benchmark, it improves significantly on all six languages (Python +57%, C++ +71%, Java +54%, JavaScript +83%, Go +56%, Rust +321\%), and in Python it reaches a 35.9% Pass@1 one-time pass rate, surpassing the larger StarCoder-15B.
+* **More Useful Features**: Inheriting the ChatGLM2-6B model features, CodeGeeX2-6B better supports both Chinese and English prompts and a maximum sequence length of 8192, and its inference speed is significantly improved compared to the first generation. After quantization, it only needs 6GB of GPU memory for inference, thus supporting lightweight local deployment.
+* **Comprehensive AI Coding Assistant**: The backend of the CodeGeeX plugin ([VS Code](https://marketplace.visualstudio.com/items?itemName=aminer.codegeex), [Jetbrains](https://plugins.jetbrains.com/plugin/20587-codegeex)) is upgraded, supporting 100+ programming languages and adding practical functions such as infilling and cross-file completion. Combined with the "Ask CodeGeeX" interactive AI coding assistant, it can be used to solve various programming problems via Chinese or English dialogue, including but not limited to code summarization, code translation, debugging, and comment generation, which helps increase the efficiency of developers.
+
+**[2023/06/25]** Release [ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B), the second-generation version of ChatGLM-6B. It retains the smooth conversation flow and low deployment threshold of the first-generation model, while introducing the following new features:
+
+1.
**Stronger Performance**: Based on the development experience of the first-generation ChatGLM model, we have fully upgraded the base model of ChatGLM2-6B. ChatGLM2-6B uses the hybrid objective function of [GLM](https://github.com/THUDM/GLM), and has undergone pre-training with 1.4T bilingual tokens and human preference alignment training. The [evaluation results](README.md#evaluation-results) show that, compared to the first-generation model, ChatGLM2-6B has achieved substantial improvements in performance on datasets like MMLU (+23%), CEval (+33%), GSM8K (+571%), BBH (+60%), showing strong competitiveness among models of the same size. +2. **Longer Context**: Based on [FlashAttention](https://github.com/HazyResearch/flash-attention) technique, we have extended the context length of the base model from 2K in ChatGLM-6B to 32K, and trained with a context length of 8K during the dialogue alignment, allowing for more rounds of dialogue. However, the current version of ChatGLM2-6B has limited understanding of single-round ultra-long documents, which we will focus on optimizing in future iterations. +3. **More Efficient Inference**: Based on [Multi-Query Attention](http://arxiv.org/abs/1911.02150) technique, ChatGLM2-6B has more efficient inference speed and lower GPU memory usage: under the official implementation, the inference speed has increased by 42% compared to the first generation; under INT4 quantization, the dialogue length supported by 6G GPU memory has increased from 1K to 8K. + +Fore more information, please refer to [ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B). + +**[2023/05/17]** Release [VisualGLM-6B](https://github.com/THUDM/VisualGLM-6B), a multimodal conversational language model supporting image understanding. + +![](resources/visualglm.png) + +You can run the command line and web demo through [cli_demo_vision.py](cli_demo_vision.py) and [web_demo_vision.py](web_demo_vision.py) in the repository. Note that VisualGLM-6B requires additional installation of [SwissArmyTransformer](https://github.com/THUDM/SwissArmyTransformer/) and torchvision. For more information, please refer to [VisualGLM-6B](https://github.com/THUDM/VisualGLM-6B). + +**[2023/05/15]** Update the checkpoint of v1.1 version, add English instruction data for training to balance the proportion of Chinese and English data, which solves the phenomenon of Chinese words mixed in English answers . + +<details><summary><b>The following is a comparison of English questions before and after the update</b></summary> + +* Question: Describe a time when you had to make a difficult decision. + - v1.0: + ![](resources/english-q1-old.png) + - v1.1: + ![](resources/english-q1-new.png) +* Question: Describe the function of a computer motherboard + - v1.0: + ![](resources/english-q2-old.png) + - v1.1: + ![](resources/english-q2-new.png) +* Question: Develop a plan to reduce electricity usage in a home. + - v1.0: + ![](resources/english-q3-old.png) + - v1.1: + ![](resources/english-q3-new.png) +* Question:未来的NFT,可能真实定义一种现实的资产,它会是一处房产,一辆汽车,一片土地等等,这样的数字凭证可能比真实的东西更有价值,你可以随时交易和使用,在虚拟和现实中无缝的让拥有的资产继续创造价值,未来会是万物归我所用,但不归我所有的时代。翻译成专业的英语 + - v1.0: + ![](resources/english-q4-old.png) + - v1.1: + ![](resources/english-q4-new.png) +</details> + +For more update info, please refer to [UPDATE.md](UPDATE.md). + +## Projects +Open source projects that accelerate ChatGLM: +* [lyraChatGLM](https://huggingface.co/TMElyralab/lyraChatGLM): Inference acceleration for ChatGLM-6B, up to 9000+ tokens/s inference speed. 
+* [ChatGLM-MNN](https://github.com/wangzhaode/ChatGLM-MNN): An MNN-based implementation of ChatGLM-6B C++ inference, which supports automatic allocation of computing tasks to GPU and CPU according to the size of GPU memory +* [JittorLLMs](https://github.com/Jittor/JittorLLMs): Running ChatGLM-6B in FP16 with a minimum of 3G GPU memory or no GPU at all, with Linux, windows, and Mac support +* [InferLLM](https://github.com/MegEngine/InferLLM): Lightweight C++ inference, which can realize real-time chat on local x86 and Arm processors, and can also run in real time on mobile phones. It only requires 4G of running memory. + +Open source projects using ChatGLM-6B: +* [langchain-ChatGLM](https://github.com/imClumsyPanda/langchain-ChatGLM): ChatGLM application based on langchain, realizing Q&A based on extensible knowledge base +* [Wenda](https://github.com/l15y/wenda): Large-scale language model call platform, based on ChatGLM-6B to achieve ChatPDF-like functions +* [chatgpt_academic](https://github.com/binary-husky/chatgpt_academic): An academic writing and programming toolbox that supports ChatGLM-6B. It has the characteristics of modularization and multi-thread calling LLM, and can call multiple LLMs in parallel. +* [glm-bot](https://github.com/initialencounter/glm-bot): Connect ChatGLM to Koishi to call ChatGLM on major chat platforms + +Example projects supporting online training of ChatGLM-6B and related applications: +* [ChatGLM-6B deployment and fine-tuning tutorial](https://www.heywhale.com/mw/project/6436d82948f7da1fee2be59e) +* [ChatGLM-6B combined with langchain to implement local knowledge base QA Bot](https://www.heywhale.com/mw/project/643977aa446c45f4592a1e59) + +Third-party evaluation: +* [Measuring Massive Multitask Chinese Understanding](https://arxiv.org/abs/2304.12986) + +For more open source projects, see [PROJECT.md](PROJECT.md). + +## Getting Started + +### Hardware Requirements + +| **Quantization Level** | **GPU Memory** | +|------------------------|----------------| +| FP16(no quantization) | 13 GB | +| INT8 | 10 GB | +| INT4 | 6 GB | + +### Environment Setup + +Install the requirements with pip: `pip install -r requirements.txt`. `transformers` library version is recommended to be `4.27.1`, but theoretically any version no lower than `4.23.1` is acceptable. + +In addition, if you need to run the quantified model on the CPU, you also need to install `gcc` and `openmp`. Most Linux distributions are installed by default. For Windows, you can check `openmp` when installing [TDM-GCC](https://jmeubank.github.io/tdm-gcc/). On Windows testing environment, the `gcc` version is `TDM-GCC 10.3.0`, and on Linux is `gcc 11.3.0`. + +### Usage + +Generate dialogue with the following code + +```python +>>> from transformers import AutoTokenizer, AutoModel +>>> tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) +>>> model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda() +>>> model = model.eval() +>>> response, history = model.chat(tokenizer, "你好", history=[]) +>>> print(response) +你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。 +>>> response, history = model.chat(tokenizer, "晚上睡不着应该怎么办", history=history) +>>> print(response) +晚上睡不着可能会让你感到焦虑或不舒服,但以下是一些可以帮助你入睡的方法: + +1. 制定规律的睡眠时间表:保持规律的睡眠时间表可以帮助你建立健康的睡眠习惯,使你更容易入睡。尽量在每天的相同时间上床,并在同一时间起床。 +2. 创造一个舒适的睡眠环境:确保睡眠环境舒适,安静,黑暗且温度适宜。可以使用舒适的床上用品,并保持房间通风。 +3. 放松身心:在睡前做些放松的活动,例如泡个热水澡,听些轻柔的音乐,阅读一些有趣的书籍等,有助于缓解紧张和焦虑,使你更容易入睡。 +4. 
避免饮用含有咖啡因的饮料:咖啡因是一种刺激性物质,会影响你的睡眠质量。尽量避免在睡前饮用含有咖啡因的饮料,例如咖啡,茶和可乐。 +5. 避免在床上做与睡眠无关的事情:在床上做些与睡眠无关的事情,例如看电影,玩游戏或工作等,可能会干扰你的睡眠。 +6. 尝试呼吸技巧:深呼吸是一种放松技巧,可以帮助你缓解紧张和焦虑,使你更容易入睡。试着慢慢吸气,保持几秒钟,然后缓慢呼气。 + +如果这些方法无法帮助你入睡,你可以考虑咨询医生或睡眠专家,寻求进一步的建议。 +``` +The implementation of the model is still in development. If you want to fix the used model implementation to ensure compatibility, you can add the `revision="v1.1.0"` parameter in the `from_pretrained` call. `v1.1.0` is the latest version number. For a complete list of versions, see [Change Log](https://huggingface.co/THUDM/chatglm-6b#change-log). + +### Load the model locally +The above code will automatically download the model implementation and checkpoints by [transformers](https://github.com/huggingface/transformers). The full model implementation can be found at [Hugging Face Hub](https://huggingface.co/THUDM/chatglm-6b). If your network environment is poor, downloading model parameters may take a long time or even fail. At this point, you can download the model to the local first, and then load it from the local. + +To download models from Hugging Face Hub, you need to [install Git LFS](https://docs.github.com/zh/repositories/working-with-files/managing-large-files/installing-git-large-file-storage) , then run +```Shell +git clone https://huggingface.co/THUDM/chatglm-6b +``` + +After downloading the model locally, replace `THUDM/chatglm-6b` in the above code with the path of your local `chatglm-6b` folder to load the model locally. + +**Optional**: The implementation of the model is still in development. If you want to fix the used model implementation to ensure compatibility, you can execute +```Shell +git checkout v1.1.0 +``` + +## Demo & API + +We provide a Web demo based on [Gradio](https://gradio.app) and a command line demo in the repo. First clone our repo with: + +```shell +git clone https://github.com/THUDM/ChatGLM-6B +cd ChatGLM-6B +``` + +### Web Demo + +![web-demo](resources/web-demo.gif) + +Install Gradio `pip install gradio`,and run [web_demo.py](web_demo.py): + +```shell +python web_demo.py +``` + +The program runs a web server and outputs the URL. Open the URL in the browser to use the web demo. + +Thanks to [@AdamBear](https://github.com/AdamBear) for implementing a web demo based on Streamlit, see [#117](https://github.com/THUDM/ChatGLM-6B/pull/117 ). + +#### CLI Demo + +![cli-demo](resources/cli-demo.png) + +Run [cli_demo.py](cli_demo.py) in the repo: + +```shell +python cli_demo.py +``` + +The command runs an interactive program in the shell. Type your instruction in the shell and hit enter to generate the response. Type `clear` to clear the dialogue history and `stop` to terminate the program. + +## API Deployment +First install the additional dependency `pip install fastapi uvicorn`. The run [api.py](api.py) in the repo. +```shell +python api.py +``` +By default the api runs at the`8000`port of the local machine. You can call the API via +```shell +curl -X POST "http://127.0.0.1:8000" \ + -H 'Content-Type: application/json' \ + -d '{"prompt": "你好", "history": []}' +``` +The returned value is +```shell +{ + "response":"你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。", + "history":[["你好","你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。"]], + "status":200, + "time":"2023-03-23 21:38:40" +} +``` + +## Deployment + +### Quantization + +By default, the model parameters are loaded with FP16 precision, which require about 13GB of GPU memory. 
If your GPU memory is limited, you can try to load the model parameters with quantization:
+
+```python
+# Change according to your hardware. Only 4/8 bit quantization is supported now.
+model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().quantize(8).cuda()
+```
+
+After 2 to 3 rounds of dialogue, the GPU memory usage is about 10GB under 8-bit quantization, and only 6GB under 4-bit quantization. As the number of dialogue rounds increases, the corresponding GPU memory consumption also increases. Due to the use of relative position encoding, ChatGLM-6B theoretically supports an infinitely long context length, but performance gradually declines once the total length exceeds 2048 (the training length).
+
+Model quantization brings a certain performance decline. After testing, ChatGLM-6B can still perform natural and smooth generation under 4-bit quantization. Quantization schemes such as [GPT-Q](https://arxiv.org/abs/2210.17323) can further compress the quantized precision or improve model performance at the same quantization precision. You are welcome to submit corresponding Pull Requests.
+
+Quantization requires first loading the FP16 model, which costs about 13GB of CPU memory. If your CPU memory is limited, you can directly load the quantized model, which costs only about 5.2GB of CPU memory:
+```python
+# For the INT8-quantized model, change "chatglm-6b-int4" to "chatglm-6b-int8"
+model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).half().cuda()
+```
+
+### CPU Deployment
+
+If your computer is not equipped with a GPU, you can also run inference on the CPU, but the inference speed is slower (and it takes about 32GB of memory):
+
+```python
+model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).float()
+```
+
+If your CPU memory is limited, you can directly load the quantized model:
+```python
+# For the INT8-quantized model, change "chatglm-6b-int4" to "chatglm-6b-int8"
+model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).float()
+```
+
+If you encounter the error `Could not find module 'nvcuda.dll'` or `RuntimeError: Unknown platform: darwin` (MacOS), please [load the model locally](README_en.md#load-the-model-locally).
+
+### Inference on Mac
+For Macs (and MacBooks) with Apple Silicon, it is possible to use the MPS backend to run ChatGLM-6B on the GPU. First, you need to refer to Apple's [official instructions](https://developer.apple.com/metal/pytorch) to install PyTorch-Nightly (the correct version number should be 2.1.0.dev2023xxxx, not 2.0.0).
+
+Currently you must [load the model locally](README_en.md#load-the-model-locally) on MacOS. Change the code to load the model from your local path, and use the mps backend:
+```python
+model = AutoModel.from_pretrained("your local path", trust_remote_code=True).half().to('mps')
+```
+
+Loading a FP16 ChatGLM-6B model requires about 13GB of memory. Machines with less memory (such as a MacBook Pro with 16GB of memory) will use virtual memory on the hard disk when free memory is insufficient, resulting in a serious slowdown in inference speed. In this case, a quantized model such as chatglm-6b-int4 can be used.
+Because the kernel of the quantized model is written in CUDA, it cannot be used on MacOS, and inference can only run on the CPU:
+
+```python
+# For INT8-quantized model, change "chatglm-6b-int4" to "chatglm-6b-int8"
+model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).float()
+```
+
+### Multi-GPU Deployment
+If you have multiple GPUs, but the memory size of each GPU is not sufficient to accommodate the entire model, you can split the model across multiple GPUs.
+
+First, install accelerate: `pip install accelerate`, and then load the model using the following method:
+```python
+from utils import load_model_on_gpus
+model = load_model_on_gpus("THUDM/chatglm-6b", num_gpus=2)
+```
+
+This will deploy the model onto two GPUs for inference. You can change `num_gpus` to the number of GPUs you want to use. By default, the model is split evenly, but you can also specify the `device_map` parameter to customize the splitting.
+
+## Parameter-efficient Tuning
+Parameter-efficient tuning based on [P-tuning v2](https://github.com/THUDM/P-tuning-v2) is provided. See [ptuning/README.md](ptuning/README.md) for details on how to use it.
+
+## ChatGLM-6B Examples
+
+The following are some Chinese examples generated with `web_demo.py`. You are welcome to explore more possibilities with ChatGLM-6B.
+
+<details><summary><b>Self Cognition</b></summary>
+
+![](examples/self-introduction.png)
+
+</details>
+
+<details><summary><b>Outline</b></summary>
+
+![](examples/blog-outline.png)
+
+</details>
+
+<details><summary><b>Ad</b></summary>
+
+![](examples/ad-writing-2.png)
+
+![](examples/comments-writing.png)
+
+</details>
+
+<details><summary><b>Email</b></summary>
+
+![](examples/email-writing-1.png)
+
+![](examples/email-writing-2.png)
+
+</details>
+
+<details><summary><b>Information Extraction</b></summary>
+
+![](examples/information-extraction.png)
+
+</details>
+
+<details><summary><b>Role Play</b></summary>
+
+![](examples/role-play.png)
+
+</details>
+
+<details><summary><b>Comparison</b></summary>
+
+![](examples/sport.png)
+
+</details>
+
+<details><summary><b>Travel Guide</b></summary>
+
+![](examples/tour-guide.png)
+
+</details>
+
+## License
+
+This repository is licensed under the [Apache-2.0 License](LICENSE).
The use of ChatGLM-6B model weights is subject to the [Model License](MODEL_LICENSE)。 + +## Citation + +If you find our work useful, please consider citing the following papers: + +``` +@inproceedings{ + zeng2023glm-130b, + title={{GLM}-130B: An Open Bilingual Pre-trained Model}, + author={Aohan Zeng and Xiao Liu and Zhengxiao Du and Zihan Wang and Hanyu Lai and Ming Ding and Zhuoyi Yang and Yifan Xu and Wendi Zheng and Xiao Xia and Weng Lam Tam and Zixuan Ma and Yufei Xue and Jidong Zhai and Wenguang Chen and Zhiyuan Liu and Peng Zhang and Yuxiao Dong and Jie Tang}, + booktitle={The Eleventh International Conference on Learning Representations (ICLR)}, + year={2023}, + url={https://openreview.net/forum?id=-Aw0rrrPUF} +} +``` + +``` +@inproceedings{du2022glm, + title={GLM: General Language Model Pretraining with Autoregressive Blank Infilling}, + author={Du, Zhengxiao and Qian, Yujie and Liu, Xiao and Ding, Ming and Qiu, Jiezhong and Yang, Zhilin and Tang, Jie}, + booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)}, + pages={320--335}, + year={2022} +} +``` diff --git a/UPDATE.md b/UPDATE.md new file mode 100644 index 0000000000000000000000000000000000000000..3bb2a80bef9fc62accaddc8b2cb64514cd873604 --- /dev/null +++ b/UPDATE.md @@ -0,0 +1,86 @@ +## 更新信息 +**[2023/05/17]** 发布 [VisualGLM-6B](https://github.com/THUDM/VisualGLM-6B),一个支持图像理解的多模态对话语言模型。 + +![](resources/visualglm.png) + +可以通过本仓库中的 [cli_demo_vision.py](cli_demo_vision.py) 和 [web_demo_vision.py](web_demo_vision.py) 来运行命令行和网页 Demo。注意 VisualGLM-6B 需要额外安装 [SwissArmyTransformer](https://github.com/THUDM/SwissArmyTransformer/) 和 torchvision。更多信息参见 [VisualGLM-6B](https://github.com/THUDM/VisualGLM-6B)。 + +**[2023/05/15]** 更新 v1.1 版本 checkpoint,训练数据增加英文数据以平衡中英文数据比例,解决英文回答中夹杂中文词语的现象。 + +<details><summary><b>以下是更新前后的英文问题对比:</b></summary> + +* 问题:Describe a time when you had to make a difficult decision. + - v1.0: + ![](resources/english-q1-old.png) + - v1.1: + ![](resources/english-q1-new.png) +* 问题:Describe the function of a computer motherboard + - v1.0: + ![](resources/english-q2-old.png) + - v1.1: + ![](resources/english-q2-new.png) +* 问题:Develop a plan to reduce electricity usage in a home. 
+ - v1.0: + ![](resources/english-q3-old.png) + - v1.1: + ![](resources/english-q3-new.png) +* 问题:未来的NFT,可能真实定义一种现实的资产,它会是一处房产,一辆汽车,一片土地等等,这样的数字凭证可能比真实的东西更有价值,你可以随时交易和使用,在虚拟和现实中无缝的让拥有的资产继续创造价值,未来会是万物归我所用,但不归我所有的时代。翻译成专业的英语 + - v1.0: + ![](resources/english-q4-old.png) + - v1.1: + ![](resources/english-q4-new.png) +</details> + +**[2023/04/16]** 增加 INT8 量化后的模型 [ChatGLM-6B-INT8](https://huggingface.co/THUDM/chatglm-6b-int8)。增加多卡部署(感谢 [@Cherrysaber](https://github.com/Cherrysaber))。 + +**[2023/04/06]** 优化web demo的界面(感谢 [@tuteng0915](https://github.com/tuteng0915))。移除embedding中的image token以减小显存占用(需要更新模型文件`pytorch_model-00001-of-00008.bin`和`pytorch_model-00008-of-00008.bin`,感谢 [@silverriver](https://github.com/silverriver) 提出的想法)。去掉了对 `icetk` 的依赖(需要更新模型文件`ice_text.model`)。 + +**[2023/03/31]** 增加基于 [P-Tuning-v2](https://github.com/THUDM/P-tuning-v2) 的高效参数微调实现,INT4 量化级别下最低只需 7GB 显存即可进行模型微调。详见[高效参数微调方法](ptuning/README.md)。 + +**[2023/03/23]** 增加 API 部署(感谢 [@LemonQu-GIT](https://github.com/LemonQu-GIT))。~~增加 Embedding 量化模型 [ChatGLM-6B-INT4-QE](https://huggingface.co/THUDM/chatglm-6b-int4-qe)~~ (已停止维护)。增加配备 Apple Silicon 芯片的 Mac 上 GPU 加速的支持。 + +**[2023/03/19]** 增加流式输出接口 `stream_chat`,已更新到网页版和命令行 Demo。修复输出中的中文标点。增加 INT4 量化后的模型 [ChatGLM-6B-INT4](https://huggingface.co/THUDM/chatglm-6b-int4) + + +## Update +**[2023/05/17]** Release [VisualGLM-6B](https://github.com/THUDM/VisualGLM-6B), a multimodal conversational language model supporting image understanding. + +![](resources/visualglm.png) + +You can run the command line and web demo through [cli_demo_vision.py](cli_demo_vision.py) and [web_demo_vision.py](web_demo_vision.py) in the repository. Note that VisualGLM-6B requires additional installation of [SwissArmyTransformer](https://github.com/THUDM/SwissArmyTransformer/) and torchvision. For more information, please refer to [VisualGLM-6B](https://github.com/THUDM/VisualGLM-6B). + +**[2023/05/15]** Update the checkpoint of v1.1 version, add English instruction data for training to balance the proportion of Chinese and English data, which solves the phenomenon of Chinese words mixed in English answers . + +<details><summary><b>The following is a comparison of English questions before and after the update</b></summary> + +* Question: Describe a time when you had to make a difficult decision. + - v1.0: + ![](resources/english-q1-old.png) + - v1.1: + ![](resources/english-q1-new.png) +* Question: Describe the function of a computer motherboard + - v1.0: + ![](resources/english-q2-old.png) + - v1.1: + ![](resources/english-q2-new.png) +* Question: Develop a plan to reduce electricity usage in a home. + - v1.0: + ![](resources/english-q3-old.png) + - v1.1: + ![](resources/english-q3-new.png) +* Question:未来的NFT,可能真实定义一种现实的资产,它会是一处房产,一辆汽车,一片土地等等,这样的数字凭证可能比真实的东西更有价值,你可以随时交易和使用,在虚拟和现实中无缝的让拥有的资产继续创造价值,未来会是万物归我所用,但不归我所有的时代。翻译成专业的英语 + - v1.0: + ![](resources/english-q4-old.png) + - v1.1: + ![](resources/english-q4-new.png) +</details> + +**[2023/04/16]** Added INT8 quantized model [ChatGLM-6B-INT8](https://huggingface.co/THUDM/chatglm-6b-int8). Added multi-GPU deployment (thanks to [@Cherrysaber](https://github.com/Cherrysaber)). + +**[2023/04/06]** Improve the web demo interface (thanks to [@tuteng0915](https://github.com/tuteng0915)). Remove the image tokens in the embedding layer to reduce the memory usage (need to update the model files `pytorch_model-00001-of-00008.bin` and `pytorch_model-00008-of-00008.bin`, thanks to [@silverriver](https:/ /github.com/silverriver) for proposing the idea). 
Removed dependency on `icetk` (need to update model file `ice_text.model`). + +**[2023/03/31]** Added a parameter-efficient tuning implementation based on [P-Tuning-v2](https://github.com/THUDM/P-tuning-v2). The minimum INT4 quantization level only needs 7GB GPU memory is enough for model tuning. See [Parameter-efficient tuning method](ptuning/README.md) for details. + +**[2023/03/23]** Add API deployment, thanks to [@LemonQu-GIT](https://github.com/LemonQu-GIT). Add embedding-quantized model [ChatGLM-6B-INT4-QE](https://huggingface.co/THUDM/chatglm-6b-int4-qe). Add support for GPU inference on Mac with Apple Silicon. + +**[2023/03/19]** Add streaming output function `stream_chat`, already applied in web and CLI demo. Fix Chinese punctuations in output. Add quantized model [ChatGLM-6B-INT4](https://huggingface.co/THUDM/chatglm-6b-int4). \ No newline at end of file diff --git a/api.py b/api.py new file mode 100644 index 0000000000000000000000000000000000000000..693c70acc4adf397375ea8b24660f9592072809f --- /dev/null +++ b/api.py @@ -0,0 +1,56 @@ +from fastapi import FastAPI, Request +from transformers import AutoTokenizer, AutoModel +import uvicorn, json, datetime +import torch + +DEVICE = "cuda" +DEVICE_ID = "0" +CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE + + +def torch_gc(): + if torch.cuda.is_available(): + with torch.cuda.device(CUDA_DEVICE): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + + +app = FastAPI() + + +@app.post("/") +async def create_item(request: Request): + global model, tokenizer + json_post_raw = await request.json() + json_post = json.dumps(json_post_raw) + json_post_list = json.loads(json_post) + prompt = json_post_list.get('prompt') + history = json_post_list.get('history') + max_length = json_post_list.get('max_length') + top_p = json_post_list.get('top_p') + temperature = json_post_list.get('temperature') + response, history = model.chat(tokenizer, + prompt, + history=history, + max_length=max_length if max_length else 2048, + top_p=top_p if top_p else 0.7, + temperature=temperature if temperature else 0.95) + now = datetime.datetime.now() + time = now.strftime("%Y-%m-%d %H:%M:%S") + answer = { + "response": response, + "history": history, + "status": 200, + "time": time + } + log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"' + print(log) + torch_gc() + return answer + + +if __name__ == '__main__': + tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) + model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda() + model.eval() + uvicorn.run(app, host='0.0.0.0', port=8000, workers=1) diff --git a/cli_demo.py b/cli_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..3559840c6a746735f5f3004d8245517f758329b2 --- /dev/null +++ b/cli_demo.py @@ -0,0 +1,58 @@ +import os +import platform +import signal +from transformers import AutoTokenizer, AutoModel +import readline + +tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) +model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda() +model = model.eval() + +os_name = platform.system() +clear_command = 'cls' if os_name == 'Windows' else 'clear' +stop_stream = False + + +def build_prompt(history): + prompt = "欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序" + for query, response in history: + prompt += f"\n\n用户:{query}" + prompt += f"\n\nChatGLM-6B:{response}" + return prompt + + +def 
signal_handler(signal, frame): + global stop_stream + stop_stream = True + + +def main(): + history = [] + global stop_stream + print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序") + while True: + query = input("\n用户:") + if query.strip() == "stop": + break + if query.strip() == "clear": + history = [] + os.system(clear_command) + print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序") + continue + count = 0 + for response, history in model.stream_chat(tokenizer, query, history=history): + if stop_stream: + stop_stream = False + break + else: + count += 1 + if count % 8 == 0: + os.system(clear_command) + print(build_prompt(history), flush=True) + signal.signal(signal.SIGINT, signal_handler) + os.system(clear_command) + print(build_prompt(history), flush=True) + + +if __name__ == "__main__": + main() diff --git a/cli_demo_vision.py b/cli_demo_vision.py new file mode 100644 index 0000000000000000000000000000000000000000..b556be0ff338946a492b07fea4a732ffc71fc64f --- /dev/null +++ b/cli_demo_vision.py @@ -0,0 +1,64 @@ +import os +import platform +import signal +import sys + +from transformers import AutoTokenizer, AutoModel +import readline + +tokenizer = AutoTokenizer.from_pretrained("THUDM/visualglm-6b", trust_remote_code=True) +model = AutoModel.from_pretrained("THUDM/visualglm-6b", trust_remote_code=True).half().cuda() +model = model.eval() + +os_name = platform.system() +clear_command = 'cls' if os_name == 'Windows' else 'clear' +stop_stream = False + + +def build_prompt(history, prefix): + prompt = prefix + for query, response in history: + prompt += f"\n\n用户:{query}" + prompt += f"\n\nChatGLM-6B:{response}" + return prompt + + +def signal_handler(signal, frame): + global stop_stream + stop_stream = True + + +def main(): + global stop_stream + while True: + history = [] + prefix = "欢迎使用 VisualGLM-6B 模型,输入图片路径和内容即可进行对话,clear 清空对话历史,stop 终止程序" + print(prefix) + image_path = input("\n请输入图片路径:") + if image_path == "stop": + break + prefix = prefix + "\n" + image_path + query = "描述这张图片。" + while True: + count = 0 + for response, history in model.stream_chat(tokenizer, image_path, query, history=history): + if stop_stream: + stop_stream = False + break + else: + count += 1 + if count % 8 == 0: + os.system(clear_command) + print(build_prompt(history, prefix), flush=True) + signal.signal(signal.SIGINT, signal_handler) + os.system(clear_command) + print(build_prompt(history, prefix), flush=True) + query = input("\n用户:") + if query.strip() == "clear": + break + if query.strip() == "stop": + sys.exit(0) + + +if __name__ == "__main__": + main() diff --git a/examples/ad-writing-2.png b/examples/ad-writing-2.png new file mode 100644 index 0000000000000000000000000000000000000000..4a970605bdcced07b296b9944bb3812ae963aaf8 Binary files /dev/null and b/examples/ad-writing-2.png differ diff --git a/examples/blog-outline.png b/examples/blog-outline.png new file mode 100644 index 0000000000000000000000000000000000000000..c26a1a90528581bd948134be61ffae63fc164242 Binary files /dev/null and b/examples/blog-outline.png differ diff --git a/examples/comments-writing.png b/examples/comments-writing.png new file mode 100644 index 0000000000000000000000000000000000000000..3c92af6461eba0b41f5f8045d9861360728a72a7 Binary files /dev/null and b/examples/comments-writing.png differ diff --git a/examples/email-writing-1.png b/examples/email-writing-1.png new file mode 100644 index 0000000000000000000000000000000000000000..2015edc55dd52e73b2705ec937a6e3519bb4f797 Binary files /dev/null and 
b/examples/email-writing-1.png differ diff --git a/examples/email-writing-2.png b/examples/email-writing-2.png new file mode 100644 index 0000000000000000000000000000000000000000..d6abc5e9302b7a12908d293291eb5301070fdbfa Binary files /dev/null and b/examples/email-writing-2.png differ diff --git a/examples/information-extraction.png b/examples/information-extraction.png new file mode 100644 index 0000000000000000000000000000000000000000..8c866df4b236e378f796584c84094e26fff2f1ac Binary files /dev/null and b/examples/information-extraction.png differ diff --git a/examples/role-play.png b/examples/role-play.png new file mode 100644 index 0000000000000000000000000000000000000000..5338c97c232504de34f5194dd86c1a4b4363701a Binary files /dev/null and b/examples/role-play.png differ diff --git a/examples/self-introduction.png b/examples/self-introduction.png new file mode 100644 index 0000000000000000000000000000000000000000..d0d372ca93120a505461ea483686a20068c269a4 Binary files /dev/null and b/examples/self-introduction.png differ diff --git a/examples/sport.png b/examples/sport.png new file mode 100644 index 0000000000000000000000000000000000000000..a900b7bb207f0b26499212ecb62e605444849839 Binary files /dev/null and b/examples/sport.png differ diff --git a/examples/tour-guide.png b/examples/tour-guide.png new file mode 100644 index 0000000000000000000000000000000000000000..6265f3f682a456e284db1794e88bd669315e43ca Binary files /dev/null and b/examples/tour-guide.png differ diff --git a/improve/README.md b/improve/README.md new file mode 100644 index 0000000000000000000000000000000000000000..6da2f30a5c1a1e9b9264ecdaeab6e037494dde82 --- /dev/null +++ b/improve/README.md @@ -0,0 +1,14 @@ +# ChatGLM-6B Badcase 反馈计划 +ChatGLM-6B 自3月14号发布以来受到了广大开发者和用户的喜爱,截至4月22号 GitHub 的 star 数达到 2 万,各个渠道模型的累计下载量过 100 万,并连续 12 天居 Hugging Face 全球大模型下载榜第一名。 与此同时,有一批基于 ChatGLM-6B 的[优秀开源项目](https://github.com/THUDM/ChatGLM-6B)出现,在各个平台也引起了广泛好评和关注。此外,基于 GLM-130B 的千亿对话模型 ChatGLM 也自3月14号开始了第一阶段的邀请制内测,得到了内测用户的好评和支持。谢谢大家对 ChatGLM 及其 6B 开源版本的大力支持! + +接下来,我们想邀请大家一起推动 ChatGLM-6B 的进一步提升,一起推动模型的发展。尽管ChatGLM-6B已初具符合人类偏好的问答对话能力,在相当多的指令和问题上,其回答仍存在不理解复杂指令和任务含义,缺乏领域概念理解,事实性错误,生成有害内容,对话上下文不一致等诸多问题。尽管我们提供的[微调代码](https://github.com/THUDM/ChatGLM-6B/tree/main/ptuning)能够让用户通过自主训练修复部分问题,但因为神经网络的[灾难性遗忘](https://picture.iczhiku.com/weixin/message1587593113355.html)问题,微调后的模型往往会失去在通用领域的对话能力或者因数据较少而缺乏泛化能力。为了解决这些问题,进一步提升 ChatGLM-6B 的能力,我们启动了 ChatGLM-6B Badcase 反馈计划。 + +具体来说,对于在使用 ChatGLM-6B 过程中遇到的表现不佳的Badcase对应的具体指令和提问,您可以修改或从头撰写您认为合适的正确答案,并反馈给我们改进 ChatGLM-6B。**请您确保提交的数据不包含任何个人信息、商业秘密或可能危害国家安全、侵害第三方知识产权的内容。** 我们会定期(每2-4周)对数据的有用性与正确性进行筛选,将筛选通过的数据,与通用域的对话数据一起加入到模型训练中,并**更新发布开源的模型参数**。**您提供的数据无论是否筛选通过,除非获得您的许可或根据国家法律规定和监管要求外,我们不会将您提供的数据对外公开。** + +您提供的数据如被筛选通过,您将同时优先获得最新版本的 ChatGLM-6B 模型的体验资格。此外,如果您愿意,您的用户名还将出现在 ChatGLM-6B Github页面的数据贡献者名单中,以此感谢您对推进大模型开源事业发展的帮助。您的支持和建议将为我们优化大模型提供源源不断的动力,在此向您表达我们最深的敬意与感谢! 
+ +如果您希望参与反馈,请填写[问卷](https://www.wjx.cn/vm/rAoGx9X.aspx#)并按照具体要求上传。提交的数据为 jsonline 格式,每行的内容为 +```json lines +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*版型#宽松*风格#性感*图案#线条*裤型#阔腿x s裤", "response": "宽松的阔腿裤这两年真的吸粉不少,明星时尚达人的心头爱。毕竟好穿时尚,谁都能穿出腿长2米的效果宽松的裤腿,当然是遮肉小能手啊。上身随性自然不拘束,面料亲肤舒适贴身体验感棒棒哒。系带部分增加设计看点,还让单品的设计感更强。腿部线条若隐若现的,性感撩人。颜色敲温柔的,与裤子本身所呈现的风格有点反差萌。"} +``` +其中,`prompt` 部分为模型的输入,`response` 部分为期望的模型输出。为了保证模型的性能,请在输入中尽可能详细地表述任务的类型和期望的输出格式。针对某项具体的任务,为了使模型得到充分的训练,一般需要100条左右的训练数据。 diff --git a/improve/data_sample.jsonl b/improve/data_sample.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..695df919be607a00345d80d8e07e5cd4424fff03 --- /dev/null +++ b/improve/data_sample.jsonl @@ -0,0 +1,100 @@ +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*版型#宽松*风格#性感*图案#线条*裤型#阔腿裤", "response": "宽松的阔腿裤这两年真的吸粉不少,明星时尚达人的心头爱。毕竟好穿时尚,谁都能穿出腿长2米的效果宽松的裤腿,当然是遮肉小能手啊。上身随性自然不拘束,面料亲肤舒适贴身体验感棒棒哒。系带部分增加设计看点,还让单品的设计感更强。腿部线条若隐若现的,性感撩人。颜色敲温柔的,与裤子本身所呈现的风格有点反差萌。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*风格#简约*图案#条纹*图案#线条*图案#撞色*裙型#鱼尾裙*裙袖长#无袖", "response": "圆形领口修饰脖颈线条,适合各种脸型,耐看有气质。无袖设计,尤显清凉,简约横条纹装饰,使得整身人鱼造型更为生动立体。加之撞色的鱼尾下摆,深邃富有诗意。收腰包臀,修饰女性身体曲线,结合别出心裁的鱼尾裙摆设计,勾勒出自然流畅的身体轮廓,展现了婀娜多姿的迷人姿态。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#宽松*颜色#粉红色*图案#字母*图案#文字*图案#线条*衣样式#卫衣*衣款式#不规则", "response": "宽松的卫衣版型包裹着整个身材,宽大的衣身与身材形成鲜明的对比描绘出纤瘦的身形。下摆与袖口的不规则剪裁设计,彰显出时尚前卫的形态。被剪裁过的样式呈现出布条状自然地垂坠下来,别具有一番设计感。线条分明的字母样式有着花式的外观,棱角分明加上具有少女元气的枣红色十分有年轻活力感。粉红色的衣身把肌肤衬托得很白嫩又健康。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*版型#宽松*材质#雪纺*风格#清新*裙型#a字*裙长#连衣裙", "response": "踩着轻盈的步伐享受在午后的和煦风中,让放松与惬意感为你免去一身的压力与束缚,仿佛要将灵魂也寄托在随风摇曳的雪纺连衣裙上,吐露出<UNK>微妙而又浪漫的清新之意。宽松的a字版型除了能够带来足够的空间,也能以上窄下宽的方式强化立体层次,携带出自然优雅的曼妙体验。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*材质#棉*颜色#蓝色*风格#潮*衣样式#polo*衣领型#polo领*衣袖长#短袖*衣款式#拼接", "response": "想要在人群中脱颖而出吗?那么最适合您的莫过于这款polo衫短袖,采用了经典的polo领口和柔软纯棉面料,让您紧跟时尚潮流。再配合上潮流的蓝色拼接设计,使您的风格更加出众。就算单从选料上来说,这款polo衫的颜色沉稳经典,是这个季度十分受大众喜爱的风格了,而且兼具舒适感和时尚感。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#h*材质#蚕丝*风格#复古*图案#条纹*图案#复古*图案#撞色*衣样式#衬衫*衣领型#小立领", "response": "小女人十足的条纹衬衣,缎面一点点的复古,还有蓝绿色这种高级气质复古色,真丝材质,撞色竖条纹特别的现代感味道,直h型的裁剪和特别的衣长款式,更加独立性格。双层小立领,更显脸型。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*材质#网纱*颜色#粉红色*图案#线条*图案#刺绣*裙腰型#高腰*裙长#连衣裙*裙袖长#短袖*裙领型#圆领", "response": "这款连衣裙,由上到下都透出一丝迷人诱惑的女性魅力,经典圆领型,开口度恰好,露出你的迷人修长的脖颈线条,很是优雅气质,短袖设计,在这款上竟是撩人美貌,高腰线,散开的裙摆,到小腿的长度,遮住了腿部粗的部分,对身材有很好的修饰作用,穿起来很女神;裙身粉红色花枝重工刺绣,让人一眼难忘!而且在这种网纱面料上做繁复图案的绣花,是很考验工艺的,对机器的要求会更高,更加凸显我们的高品质做工;"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*颜色#纯色*图案#纯色*图案#文字*图案#印花*衣样式#卫衣", "response": "一款非常简洁大方的纯色卫衣,设计点在于胸前的“<UNK><UNK>”的中文字印花,新颖特别,让人眼前一亮。简单又吸睛的款式,而且不失时髦感,很适合个性年轻人。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#宽松*颜色#黑色*颜色#灰色*颜色#姜黄色*风格#休闲*图案#线条*图案#撞色*衣样式#毛衣*衣袖型#落肩袖", "response": "看惯了灰色的冷淡和黑色的沉闷感,来一点醒目的彩色增添点活力吧。亮眼又吸睛的姜黄色色调,嫩肤显白非常的有设计感。趣味的撞色和宽松的版型相交辉映,修饰身形小缺点的同时,时尚又百搭。优雅的落肩袖,轻松修饰肩部线条,让毛衣上身凸显出一丝慵懒随性的休闲感,时尚魅力尽显。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*风格#休闲*风格#潮*图案#印花*图案#撞色*衣样式#衬衫*衣领型#圆领*衣长#中长款*衣长#常规*衣袖长#无袖", "response": "黑与白,两种最极端的颜色却轻松搭配成了经典,就像此款衬衣,无需过多装饰,仅色调就足够醒目个性,受潮<UNK>所喜欢。做了无袖中长款的样式,走路带风的感觉着实不错,圆领的设计,不是常规的衬衫领,少了点正式反而有种休闲感觉,适合孩子们穿着。后背大面积撞色印花装点,是时尚潮流的象征,也让衣衣不至于单调,轻松就能穿出彩。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#宽松*风格#街头*风格#休闲*风格#朋克*图案#字母*图案#文字*图案#印花*衣样式#卫衣*衣款式#连帽*衣款式#对称", "response": "个性休闲风的连帽卫衣造型时髦大方,宽松的版型剪裁让肉肉的小宝贝也可以穿着,保暖的连帽设计时刻给予宝贝温柔的呵护,袖子和后背别致时髦的字母印花点缀,满满的街头元素融入,演绎休闲朋克风,对称的小口袋美观大方,方便放置更多的随身物品。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*裙款式#链条", "response": "简单大气的设计,不费吹灰之力就能搭配的时髦范儿。时尚的配色一点都不觉得平淡了,有种浑然天成的大气感。强调了整体的装饰,和谐又不失个性,搭配裤装帅气十足,搭配裙子精致优雅。链条和肩带的搭配让使用感更加舒服,单肩手提都好看。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*版型#显瘦*材质#牛仔布*颜色#深蓝色*风格#复古*图案#复古*图案#线条*裤腰型#高腰*裤口#微喇裤", "response": 
"深蓝色的高腰牛仔裤,修身的款式勾勒出纤细的美腿。牛仔裤的裤脚设计<UNK>张开的喇叭型,巧妙地修饰了小腿的线条,洋溢着复古的年代感。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*风格#清新*风格#潮*风格#性感*图案#条纹*图案#蝴蝶结*衣样式#衬衫*衣领型#一字领*衣门襟#系带*衣款式#不对称", "response": "这是一件显得特别清新的衬衣,采用了条纹的设计,给予人一种甜美可人的气质。并且融合了别致的斜肩一字领设计,高调的展示出性感的锁骨,将迷人的香肩展现在外,性感中不失去清纯的气息。袖口处的蝴蝶结系带装饰,增添了俏皮的韵味,简洁大方。且在下摆处采用了不对称的设计,增强了视觉效果,更显潮流。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*材质#牛仔布*风格#复古*图案#复古*裤型#直筒裤*裤款式#纽扣*裤腰型#高腰", "response": "作为基础款单品,牛仔裤也<UNK><UNK>,想要呈现给大家的是——每次搭配都有新感觉。裤子经过复古做旧处理,风格鲜明,也很注重细节,连纽扣也做了统一的做旧处理,融入个性十足的磨破设计,高腰直筒basic裤型,修饰身材,穿出高挑长腿。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#宽松*版型#显瘦*图案#线条*图案#刺绣*衣样式#针织衫*衣领型#v领", "response": "一款温暖柔软又富有弹性的针织衫,不仅可以抵御严寒侵袭,还能更好地进行搭配。v领的设计,能勾勒出迷人的天鹅颈以及衬托出娇小的脸型。宽松又别致的剪裁,能从视觉上显露纤长的下半身,起到显瘦的效果。直筒造型的袖子,修饰出优美的手臂线条,衣身上的方格刺绣,时尚又吸睛。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*颜色#绿色*风格#清新*图案#线条*衣样式#衬衫*衣领型#翻领", "response": "绿色的衣身上镶嵌着<UNK>,就是这款衬衫最大的迷人之处,“红花配绿叶”般的色调,将清新气息阐述的淋漓尽致。经典的翻领更是贴心,修饰颈部线条的同时,尽显精致干练的气质,出街轻松凹造型。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*图案#字母*图案#文字*图案#印花*图案#撞色*衣样式#外套*衣门襟#拉链*衣款式#拉链", "response": "这款外套采用了撞色拉链织带以及字母印花设计。这两种元素的融入使外套不会显得过于单调沉闷,吸睛而亮眼,充满年轻与朝气感,非常减龄。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*版型#显瘦*版型#h*风格#复古*图案#复古*图案#刺绣*裙长#连衣裙*裙袖长#长袖*裙领型#翻领*裙衣门襟#单排扣", "response": "本款连衣裙整体采用h型的轮廓设计,藏肉显瘦,不挑身材,适合各种身形的人穿着。小翻领的领口设计,使得本款连衣裙穿在身上看起来十分的精神帅气,具有青春活力。单排扣的衣门襟设计,又给本款连衣裙带来了一丝的复古味道。裙身上的刺绣花朵装饰,使得本款连衣裙不显得单调,富有层次感,上身给人一种独特的时尚魅力。长袖的设计,更加的贴合手臂曲线,上身更加的舒适贴身。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*颜色#粉色*风格#清新*衣样式#外套*衣样式#西装*衣门襟#双排扣", "response": "这款外套设计成西装的版型,彰显经典优雅的气质,结合了粉色又添清新气息,甜美百搭时尚感满满。利落的版型简洁流畅,亮色双排扣更添精致感。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*风格#休闲*图案#线条*衣样式#风衣*衣样式#外套*衣门襟#拉链*衣款式#拉链*衣款式#松紧带*衣款式#连帽*衣款式#收腰", "response": "选自品牌江南布衣的一款女士长风衣外套,选用轻薄的<UNK><UNK>,穿着灵活毫无压力。直筒版型简洁利落,长过膝盖的长度穿着个性十足,连帽宽大有型,富有活力,<UNK>拉链开合,拉上拉链有一丝酷劲,敞开穿则更休闲,连帽领翻开修饰颈部线条。松紧带收腰设计,低调的分割上下比例,打造显高小心机。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*材质#棉*材质#牛仔布*风格#街头*风格#简约*图案#刺绣*裤长#短裤*裤款式#钉珠*裤口#毛边", "response": "又到了光腿穿裙子和短裤的时候了,BRAND的这款短裤,采用柔软透气的纯棉牛仔面料,穿着舒适无束缚感。而简约的版型加入了精美的刺绣和钉珠装饰,提升了整体的品质感,显得精美而又立体饱满。搭配下摆的毛边装饰,散发出不羁的街头感。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*材质#牛仔布*颜色#黑色*图案#条纹*衣样式#衬衫*衣领型#翻领*衣门襟#系带*衣款式#拼接*衣款式#露肩", "response": "一款老鹰图案露肩衬衫,露肩系带的设计,少女感十足。老鹰图案的设计,更添几分趣味感。条纹面料和牛仔面料的拼接设计,给人一种风度的层次感。小翻领的设计十分的精致,搭配一件黑色打底裤也吸晴万分。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*材质#雪纺*裙型#百褶*裙长#半身裙*裙款式#拼接*裙款式#腰带", "response": "一款颇有设计感的半身裙,单侧雪纺百褶的拼接设计,规整排列的层次感带来立体效果,增加了裙身的廓形,行走间更是带来柔美的灵动气息,轻而易举穿出优雅的轻熟风,呈现十足的女人味来。同面料延伸处理的半固定腰带,可以自然的垂落下来,也算是为整体打造造型亮点,彰显你独特的时尚品味,迎合早春对轻盈的追求。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*版型#显瘦*风格#性感*裙型#包臀裙*裙型#鱼尾裙", "response": "修身包臀版型结合性感鱼尾裙摆设计,彰显婉约优雅风情之余,为整体注入几分俏皮灵动气息。且下摆辅以律动感摺裥元素,更烘托出女性浪漫精致的一面。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*颜色#绿色*图案#线条*裙长#连衣裙*裙领型#v领*裙款式#勾花镂空", "response": "连衣裙可以让你在旋转与跳跃间,都散发出万种风情,受到了万千女性的喜爱。这款连衣裙选用绿色调,既散发出活力气息,又增添了高雅的气质。而镂空的钩花设计,则为其增添了浪漫的风情,同时更显美观与时尚。再加上v领的设计,不仅映衬出精致的脸颊,还打造出优美的颈部线条。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*风格#淑女*衣样式#毛衣*衣领型#高领", "response": "高领毛衣一直是网红妹子,因为穿着它有一种淑女甜美气质。它最大的亮点在于它的高领设计和花边装饰。在淑女干练的气质基础上又增加了一些少女的甜美气息,穿着非常有型,最佳搭配小白鞋。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*裤长#九分裤*裤型#阔腿裤*裤款式#拉链*裤腰型#高腰*裤口#开叉", "response": "九分裤长,把妹子的拉长了腿的比例,配合高腰设计,瞬间显得妹子的腿长了很多,一下子自信满满啦。采用侧面隐藏拉链设计,穿脱方便又舒适。设计感十足的开叉裤脚,身上的摩登<UNK>浓了。这个春天妹子的腿型,就交给阔腿裤啦。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*颜色#黑白*风格#复古*风格#文艺*图案#格子*图案#复古*衣样式#衬衫*衣领型#翻领*衣门襟#套头*衣款式#纽扣", "response": "经典的套头翻领衬衫与黑白格纹元素组合,一直以来的气场经久不衰。而采用复古精致的纽扣装点的半门襟设计,简单的小细节处理,彰显出浓浓的文艺气息。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*颜色#黑白*风格#复古*风格#文艺*图案#格子*图案#复古*衣样式#衬衫*衣领型#翻领*衣门襟#套头*衣款式#纽扣", "response": "套头翻领衬衫与黑白格纹元素组合,一直以来的气场经久不衰。而采用复古精致的纽扣装点的半门襟设计,简单的小细节处理,彰显出浓浓的文艺气息。"} +{"prompt": 
"请根据以下标签为商品编写一段广告\n类型#裙*图案#卡通*裙长#连衣裙", "response": "传奇而又经典的卡通形象,米老鼠似乎已经成为了孩童风格的一种标志,大小不一的头像以及奇趣的表情设计。满版的点缀风格让整个连衣裙洋溢着独特的天真气质,加之面料小口袋的点缀,小小的造型呈现出灵巧而又可爱的格调,让宝贝俏皮萌动。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*图案#条纹*图案#刺绣*裙型#背带裙*裙下摆#毛边", "response": "假两件版型的设计,给人一种错觉,大大增添自身时髦感。毛边裙摆的采用,看起来活力十足。设计师解决了以往穿脱不方便的问题,应用的可调节背带设计,非常的人性化。裙子上的花朵刺绣图案,看起来也栩栩如生,同时也展示出了精湛的做工手艺。为了与女人自身清纯的一面形成呼应,应用的条纹图案非常完美。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*版型#宽松*材质#牛仔布*风格#街头*风格#休闲*裤长#五分裤*裤腰型#松紧腰", "response": "这款休闲五分裤,采用亲肤软牛仔,洗水磨白形成深浅对比,更加个性。大弹力松紧腰,舒适贴合,一点都不紧勒。裤子门襟时尚的设计,为细节加分。立体双贴袋,腰间系带的点缀更吸睛。精致的裁剪,或是干练整洁的走线和宽松版型,是对街头的描写。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*材质#蕾丝*图案#条纹*图案#蕾丝*衣样式#衬衫*衣领型#立领*衣款式#荷叶边", "response": "条纹衬衫是引领时尚圈的常青树,尤其给人舒适感官享受的蓝白条纹,更是深得时尚icon的喜爱。加之搭载经典的立领秀出纤长的玉颈,更显气质优雅。肩膀上饰有薄薄的蕾丝,打破了条纹衬衫的干练,更添别样风情。荷叶边的蔓延更显气质甜美,自然吸睛无数。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*材质#牛仔布*风格#休闲*图案#线条*裙型#牛仔裙*裙长#半身裙", "response": "牛仔半身裙作为时尚宠儿,一直被很多潮人捧在手心,zui妙的莫过于它<UNK>的时髦感以及百搭易驾驭的属性。裙身设计了自然的腰型,拉长腿部线条,让小仙女们感受到大长腿。以及两侧插袋的造型,显得比较随意,休闲的感觉。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*颜色#纯色*图案#纯色*图案#刺绣*衣样式#卫衣*衣袖型#罗纹袖口", "response": "乍一看很平凡的纯色系卫衣,<UNK>暗藏的刺绣玄机,就足够把时髦的张力表现得不凡。很有包容性的廓形,舒适的罗纹收口,宽大的样子却依旧<UNK>既定的风格,让你的潇洒随性表现得收放自如。呆萌查理的袖间刺绣,极简的漫画笔触巧塑生动有趣的风格。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*衣样式#衬衫*衣样式#风衣*衣款式#抽绳", "response": "风衣带有一种随性大气的感觉,在春风拂面的日子里能衬托出你的气质。草绿色的衣身配色,与与自身清纯干净的性格形成了呼应。具有一定实用性的下摆抽绳,可以让你任意的变换风格。抛弃了衬衫领的设计应用的设计,更能将你帅气的一面展示出来。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*风格#职场*图案#线条*裙型#包臀裙*裙下摆#开叉", "response": "优质挺括的面料。包臀版型,长度在膝盖往下一点,显得利落而大方,与生俱来的气场感。适合职场女性,包臀裙的优势在于凸显腰身线条,侧边开叉的设计不仅让整体造型更具曾层次感,也增添一份恰到好处的妩媚"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*风格#职场*图案#线条*裙型#包臀裙*裙下摆#开叉", "response": "优质挺括的面料。包臀版型,长度在膝盖往下一点,显得利落而大方,与生俱来的气场感。适合职场女性,包臀裙的优势在于凸显腰身线条,侧边开叉的设计不仅让整体造型更具曾层次感,也增添一份恰到好处的妩媚。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*风格#职场*图案#线条*裙型#包臀裙*裙下摆#开叉", "response": "长度在膝盖往下一点,显得利落而大方,与生俱来的气场感。适合职场女性,包臀裙的优势在于凸显腰身线条,侧边开叉的设计不仅让整体造型更具曾层次感,也增添一份恰到好处的妩媚"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*材质#针织*风格#复古*风格#清新*图案#条纹*图案#复古*衣样式#针织衫*衣样式#开衫*衣长#常规*衣款式#拼接*衣款式#纽扣*衣款式#罗纹", "response": "慵懒气质的针织开衫,充满了复古的情调,奶奶级的麻花编织手法,充满立体感的同时保暖效果也是满分。下摆的罗纹拼接,让针织衫回暖性更棒。活泼的<UNK>条纹拼接,跳脱出常规配色,清新色调的选用,更加衬托出肌肤的雪白。精致的纽扣点缀,反光的质感让针织衫充满现代感。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*风格#复古*图案#蝴蝶结*图案#复古*图案#波点*衣样式#衬衫*衣领型#立领*衣门襟#系带*衣款式#木耳", "response": "【<UNK>说】<UNK>衬衫,大波点气质复古从立领上延伸的长系带,可轻松绑成蝴蝶结,甜美感加分采用打缆工艺的松紧袖口边边处的木耳<UNK>很可爱"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*风格#简约*风格#青春*图案#字母*图案#文字*裙型#网纱裙*裙袖长#无袖*裙领型#圆领", "response": "大气的圆领舒适贴合,彰显出女孩儿精神的气质。无袖的款式与圆领相迎合,简约的同时又不失时尚风采。前身由可爱蝴蝶图案点缀,亮丽的字母映衬其上,诉说着一丝精美感。橙色网纱裙摆造型优雅唯美,与上身的图案相呼应,十分富有青春的气息,伴随着步伐的行走间,带出一丝别致浪漫的风情。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*材质#丝绒*风格#复古*图案#复古*衣样式#雪纺衫*衣袖型#喇叭袖*衣款式#木耳边*衣款式#飘带*衣款式#荷叶边", "response": "这款雪纺衫,采用具有复古韵味的荷叶边元素,加上丝绒质感的加长飘带,洋溢着浪漫古典的韵味。<UNK>两侧镶有包扣,和立体木耳边装饰,大大提升时髦指数。而流线型喇叭袖设计,充满灵动质感,为造型平添活力。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*版型#宽松*版型#显瘦*裙下摆#荷叶边*裙腰型#高腰*裙长#半身裙", "response": "很简洁百搭的一款半裙,裙身荷叶边设计,飘逸灵动,上身更显层次感丰富。高腰造型,版型优良,衬显修长双腿。裙子做的比较宽松,包容性敲好,遮肉效果棒棒的。非常的显瘦哦,选用精品梭织面料,垂感好,肌理细致,上身敲舒服哟。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*风格#青春*风格#性感*图案#线条*裙下摆#开叉*裙长#连衣裙*裙领型#翻领*裙款式#腰带*裙款式#衬衫式", "response": "设计师以衬衫式的创作灵感,巧妙地搬运到连衣裙身上,中性又不失性感;时尚小翻领设计,巧妙衬托颈部线条,彰显青春派的艺术时尚,小资派的精彩演绎。耳目一新的双腰带设计,既突出了腰线又感觉很前卫;下面走心的大开叉设计,更能激发人的好奇心,营造出无人超越的高级性感,只需一眼就令人<UNK>。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*版型#宽松*风格#性感*图案#印花*裙下摆#荷叶边*裙长#连衣裙*裙袖型#灯笼袖", "response": "这款连衣裙走的是性感大方的风格路线,展现出你的大大咧咧的性情,非常的有趣。选用了宽松的版型,配合星空印花的图案,塑造出新颖有趣,不失活力四射的印象感。荷叶边的裙摆设计,突显出飘逸性感的一面。配合灯笼袖的袖型细节,体现出<UNK>的一面。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*版型#显瘦*材质#水洗*颜色#浅色*风格#休闲*风格#性感*图案#线条*裙型#牛仔裙*裙型#直筒裙*裙下摆#开叉*裙下摆#毛边*裙腰型#高腰", 
"response": "浅色水洗效果牛仔裙,高腰设计融合修身直筒廓形,凸显纤细腰部和迷人翘臀,美化勾勒性感身材曲线。正面开叉细节有效拉长腿部线条,灵动性感。磨毛边下摆设计,带来休闲随性气息。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*风格#休闲*图案#条纹*图案#印花*衣样式#卫衣*衣款式#连帽*衣款式#罗纹", "response": "这款连帽卫衣自带休闲魅力,将杜嘉班纳的品牌标志以印花的形式装饰在衣身前幅,展现出华丽不失看点的视觉效果,每时每刻都在彰显不凡品味。罗纹条纹袖口和下摆,不仅能使卫衣更帅气惹眼,还能为整体增加一股前卫之风。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*风格#简约*图案#卡通*图案#蝴蝶结*图案#印花*衣样式#衬衫", "response": "大面积的卡通兔子印花,童趣满满,再加上领口的蝴蝶结装饰织带。充满童趣的同时又不失小女生的甜美气息,相当减龄。这款衬衫选用真丝面料,真丝面料不仅轻薄,而且柔滑、亲肤,就好像人的第二层肌肤般带给你清凉舒适的穿着感觉。合身的版型,裁剪得干净利落,简约又不失时尚气息,打造干练的气场。这款衬衫日常十分百搭,不仅可以与其他服饰搭配,作为一件单品也十分出彩。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*材质#牛仔布*材质#水洗*风格#复古*风格#简约*图案#复古*图案#线条*裤长#九分裤*裤款式#不规则*裤口#毛边", "response": "misssixty的这款单品延续经典的九分牛仔裤版型,结合贴合身形的剪裁,展现出柔美修长的腿部线条;不同的位置做了不同程度的水洗复古工艺,使得裤身更加立体厚重;此外,裤脚处采用了微微不规则的毛边剪裁,为简约的整体注入一丝随性之感;再加上<UNK>相互呼应的翅膀状图案点缀,瞬间带来一丝浪漫唯美的味道。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*颜色#纯色*图案#纯色*图案#线条*衣样式#卫衣*衣领型#圆领*衣袖型#收口*衣门襟#套头*衣款式#螺纹", "response": "使用经典的螺纹圆领来展开设计,将衣型打造成套头卫衣的款式,穿着时轻松收口,将颈部线条修饰出挺拔优美的的效果,让穿着更加具有精气神。衣身以纯色作为主色调,配上经典的小企鹅logo,将正面点缀,它拥有一个俏皮的小蝴蝶领结,充满细节感,使得衣身吸睛耀眼。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*材质#牛仔布*裤长#九分裤*裤型#直筒裤", "response": "c小小的这样一条迷人的牛仔裤彰显出你的大气个性,它的别致直筒版型十分的高端迷人,让你吸睛十足。个性九分的版型展示出你的迷人小脚踝。它的大气牛仔材质,十分的舒适洒脱,迷人更有型。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*材质#蕾丝*颜色#纯色*风格#简约*图案#纯色*图案#线条*图案#蕾丝*衣样式#衬衫*衣领型#v领", "response": "一款简约的纯色衬衫,采用了个性的大v领,露出柔美的锁骨和颈部线条,散发出清爽迷人的气质;点缀精美的蕾丝花边装饰,波浪形的花边很有美感,增加了视觉亮点。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*图案#撞色*裙下摆#垂坠*裙长#连衣裙*裙袖长#无袖*裙袖型#收口*裙款式#拼接*裙款式#绑带*裙款式#波浪", "response": "来自奥芝国的推出的无袖连衣裙,精选弹力冰丝材质穿柔软垂坠性很好,适合春夏秋三季穿搭。腰部的撞色波浪纹弹力腰封拼接,并以交叉绑带式收口,修饰腰身轻松穿<UNK>人大长腿。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#显瘦*材质#针织*颜色#灰色*颜色#深蓝色*图案#线条*衣样式#毛衣", "response": "这是一款专为胖孩子设计的针织毛衣,加肥加大的立体版型,利落有型穿着合体不臃肿,穿着更加帅气显瘦;领口、袖口和下摆收紧处理使衣衣更加利落有型,久穿久洗也不易磨损和变形,颇具品质感;深蓝色的大身巧妙地加入一些灰色线条修饰活泼大方,孩子穿上它,洋溢着青春活力。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*材质#牛仔布*材质#网纱*风格#街头*衣样式#衬衫*衣款式#拼接*衣款式#勾花镂空*衣款式#钉珠", "response": "时髦又帅气的牛仔拼接裙,利用多材质拼接演绎刚柔并济的设计。硬朗的牛仔衬衫以镂空拼接,构造出深浅的色系变化,加上钉珠铆钉的装饰,更是玩味出十足的街头帅气。下身拼接的网纱半裙,层次细腻又丰富,两侧加入牛仔插袋呼应上身面料,带来一体感设计。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*材质#牛仔布*颜色#蓝色*颜色#浅蓝色*风格#性感*裙型#牛仔裙*裙型#包臀裙*裙下摆#开叉*裙款式#拼接*裙款式#纽扣", "response": "mm们<UNK>拼接风呢?这款牛仔裤是非常有趣的拼接风,浅蓝色和原蓝色的牛仔拼接在一起,非常吸引眼球。在左侧的裙摆处还做了开叉设计,微微露出腿部皮肤,展现性感姿态。包臀的设计,凸显圆润的臀部。前幅一排金属纽扣,增添细节感和精致度。喜欢的mm千万不要错过~"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*材质#蕾丝*颜色#粉色*风格#清新*图案#碎花*图案#线条*图案#蕾丝*裙型#a字*裙下摆#花边*裙领型#圆领*裙款式#飘带", "response": "清新的小碎花缀满衣身,以淡雅的粉色调为底色,焕发出甜美温婉的少女气息。简洁的圆领设计,柔化脸部线条,加上蕾丝飘带点缀,更显娇俏减龄。下摆蕾丝花边分割裙裾,转身间将浪漫挥洒。散开的a字裙摆,恰到好处遮住了臀部和腿部粗的部分,有很好的修饰作用。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*版型#显瘦*风格#淑女*图案#植物*图案#印花*裙型#百褶*裙长#连衣裙*裙领型#娃娃领*裙款式#拼接*裙款式#腰带", "response": "法式浪漫情怀,由这款印花连衣裙为你抒写。蝴蝶花卉印花铺陈裙身,蝴蝶翩跹BRAND花丛,浪漫迷人美如画,法式风情呼之欲出。娃娃领的设计,凸显一身柔美的淑女气质。裙摆百褶的设计,<UNK>飞舞更添灵动飘逸的美。腰带拼接的设计,完美打造显瘦显高的身材比例。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*材质#雪纺*风格#复古*风格#简约*风格#休闲*图案#复古*图案#线条*图案#印花*裙长#连衣裙", "response": "这一款雪纺连衣裙复古的小立领带来不一样的惊喜,不仅拉伸了脖颈的线条,同时衬托出娇小的脸型。衣身大大的印花很有质感,简约休闲中透露着复古精致的美丽。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*风格#文艺*风格#简约*风格#清新*衣样式#外套*衣门襟#拉链*衣款式#拉链", "response": "飘飘落落,暖色的布料上纷纷落落的铺着羽毛,灰常有意境的一款连衣裙。羽毛是这款连衣裙最大的亮点,色彩也丰富饱满,凸显的文艺感也灰常强烈,满满的文艺清新气息;简约大方的设计,有种不喧嚣的热烈感;凸显内敛的气质。搭大衣、棉服外套不仅保暖又灰常的有韵味,而且这款不仅做了开扣的设计,还做了隐形的小拉链!是可哺乳的款式,方便孕后哺乳穿,墙裂推荐!"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*材质#针织*风格#简约*风格#青春*风格#清新*风格#性感*图案#条纹*图案#撞色*裙下摆#开叉*裙长#连衣裙*裙款式#拼接*裙款式#吊带", "response": "这款针织吊带连衣裙展现青春时尚的格调,双侧撞色条纹的拼接简约经典,散发出清新爽朗的气息,显得格外惹眼,营造出明媚动人的视觉吸引力。赋予简约的吊带裙满满的活力,开叉的剪裁性感别致,充满小女人的韵味。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*风格#街头*风格#青春*衣样式#t恤*衣领型#圆领", "response": 
"三叶草的这款体恤面料比较舒适,穿起来也能很好的透气排汗。整体的设计风格就是经典的款式,所以说是街头常年流行的必备。圆领的领口设计在穿脱时起到了方便。同时修饰脸部轮廓,更显小脸。三叶草的标志也是最为独特的品牌标识,穿出了个人的品味。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*颜色#白色*风格#休闲*风格#清新*衣样式#外套*衣款式#连帽", "response": "春天家以清新白色为主基调打造的这款外套,整体采用了直筒的极简剪裁配合休闲感的连帽设计,穿着在身上的舒适度较高。设计师为这款上衣做了<UNK>口的袖子和下摆的处理,穿着后对于身形的修饰效果会更为出众,显得较为得体、大方。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#宽松*颜色#军绿色*风格#复古*风格#文艺*风格#知性*风格#休闲*风格#潮*图案#复古*图案#撞色*衣样式#外套*衣样式#西装*衣领型#西装领*衣长#短款*衣袖型#插肩袖", "response": "短款西装小外套,结合了知性和休闲两种风格,在现代的潮流款式中又融入了淡淡的复古韵味。端庄典雅的军绿色衣身,带着自由舒畅的旅行感,款式上选用利落率性色西装领,宽松闲适的插肩袖,门襟选用撞色的两粒扣设计,复古文艺又简洁随性。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*图案#线条*裙款式#勾花镂空*裙款式#收腰", "response": "亮眼的橙红色展示出迎面而来的热情感,衬托肤色白皙红润,在宴会上气场十足。方形的镂空点缀着衣领下方,增加看点散发出小女人的妩媚感。独特的衣袖造型倾斜而下,修饰手臂线条非常修长,在举手投足间优雅又大气。收腰的版型设计修饰腰部线条更纤细,打褶的裙摆在行走时灵动十足,仿佛<UNK>的精灵一般。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*图案#线条*裙款式#勾花镂空*裙款式#收腰", "response": "裙子表面的镂空花网就使其充满了很强的设计美感,首先是肩部将落肩袖和镂空图案相结合,白皙的肌肤隐隐约约,而且能够很好的缩小肩宽比例。v型领口修饰拉长颈部线条和显得脸小。裙子做了收腰裁剪,并将腰线提高,轻松拉长下半身身材比例,裙摆也更加挺括,从而能够解决胯宽等身材烦恼。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#显瘦*材质#涤纶*衣样式#风衣*衣袖型#灯笼袖*衣款式#纽扣*衣款式#飘带", "response": "风衣在摒弃了传统的版型样式,将袖子设计成花苞型的灯笼袖,与春天搭配得恰到好处。并在袖子处装饰了四颗纽扣,采用飘带作为松紧调节,增添层次感更显个性别致。除此之外,风衣采用涤纶材质制成,垂顺感好挺括修身,结合小a字形轮廓,更显身形高挑秀美,并且让矮小个的女性也能撑起风衣的气场。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*风格#英伦*风格#简约*风格#休闲*图案#格子*图案#线条*衣样式#西装*衣领型#翻领*衣门襟#双排扣", "response": "这一款休闲西装简约利落的翻领,可以很好地修饰脸型和颈部线条,显脸小的同时又让脖子看上去更纤细。加上精致的格纹装饰,视觉美丽凸显英伦风。而且双排扣设计,时尚大气美观实用。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*材质#针织*风格#文艺*风格#休闲*风格#性感*裙长#半身裙*裙长#连衣裙*裙款式#拼接", "response": "连衣裙的灵感来自于<UNK>匠人穿着围裙的状态,设计师将针织上衣与半裙结合,整体松软舒适,且不失休闲随性感。裙摆不同材质的拼接,带来丰富的层次细节,让时髦度倍增。偏暗调的配色融入文艺田园气息,显随性姿态。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*风格#复古*图案#复古*裙下摆#荷叶边*裙长#连衣裙", "response": "对于女孩子来说,喜爱连衣裙是与生俱来的!几乎没有问题是一条裙纸<UNK>的~BRAND这款裙子整体的设计有点小复古的感觉,而且艳丽的枣红色也是复古色的代表,上身穿着十分衬肤显白哦。个性而时髦的挂脖式领口露出锁骨很是撩人,另外领口至腰间的衣身前片还加入了很有灵动感的荷叶边作为点缀,瞬间点亮了整体的造型感,由内而外散发的优雅而温柔的气质无人能挡。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*材质#牛仔布*风格#日系*风格#简约*图案#线条", "response": "很喜欢这款简单却不简约的时尚牛仔裤,在夏天可以穿出个性与时尚。整个风格比较偏向于日系的身体,任何妹子都能够轻松驾驭,最重要的是版型。穿上特别修饰腿部的线条,打造出了高挑的身材,让你看起来非常有自信的呢,这手工的工艺凸显出了无限的高级质感。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*风格#街头*风格#潮*裙型#a字", "response": "孕期就一定要穿的沉闷单调吗?热爱潮流的怎能束缚自己个性的心呢,这款裙子采用a字型设计,让你搭配更为轻松随意,飘逸的撞色织带设计,即刻将原本沉闷的空气也带动的活跃起来。从街头到<UNK>,尽显潮流个性时尚。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*材质#牛仔布*颜色#浅蓝色*风格#街头*风格#休闲*裤型#直筒裤*裤款式#破洞", "response": "破洞元素已变成彰显个性的元素,这款浅蓝色牛仔裤糅合磨白磨破设计,弥漫摩登个性格调,而且破洞设计,打破裤装闷热形象,休闲时髦;直筒款巧妙糅合酷帅感与时髦感,塑造街头潮人印象。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*版型#宽松*材质#雪纺*风格#知性*风格#性感*图案#线条*裤长#连体裤*裤款式#木耳边", "response": "雪纺面料的一袭连体裤,舒适的手感,轻盈的穿着,宽松的版型,让上身穿着没有束缚感。一字肩的设计,木耳的花边,显露颈部柔美的线条,与性感的锁骨,展现女性知性的一面,木耳花边的<UNK>设计,显露穿着的甜美感,与少女味。高收腰的设计,拉伸腰部的曲线,提高腰线,显露穿着高挑的身姿。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*风格#简约*图案#线条*裤款式#口袋*裤款式#拉链", "response": "侧缝处添置有立体拉链口袋作为装饰,实用性强且兼备美观性。净色的大体外观,简约低调,大方得体,易于搭配。裤腰处植入张弛有度的弹性带,贴合腰部,适合于大多数人穿着。衣身剪裁干净利落,线条流畅。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*颜色#白色*图案#条纹*图案#线条*衣样式#衬衫", "response": "白色的衬衫采用了百褶的袖子设计,既修饰了手臂线条,又为整体增强了设计感。背带裤是永不过时的条纹款式,加上阔腿裤的设计,更显女性身材。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*材质#棉*材质#牛仔布*风格#简约*风格#休闲*裤长#短裤*裤款式#破洞", "response": "选用优质的纯棉面料打造出舒适的质感,而且上身不会扎身。同时,个性破洞细节设计,增加了牛仔短裤的细节感和吸睛度。此外,简约好搭的配色,柔和你的棱角,让你看起来温柔又平易近人。适合约会等休闲场景,是你衣柜里不可或缺的时髦单品之一。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*材质#水洗*风格#潮*裤款式#不规则*裤口#毛边", "response": "年轻潮流的设计品味,洋气又好穿。细节相当丰富有看点,融入水洗磨白,使其充满时尚不羁的气息。裤脚前后毛边处理,配上不规则脚口,更添青春活力。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*颜色#蓝色*风格#简约*裤型#背带裤*裤款式#纽扣", "response": "背带裤的选用天蓝色的主题,远远看上去就像是蓝色<UNK>悬挂在活跃孩子的身上。简约的背带设计,可随时拆开的纽扣,让稚嫩孩子穿衣时不费吹灰之力。腰部更是搭配弹性材料缝制的腰带,不仅方便穿戴而且完美的起到了修饰作用。后背交叉背带,更是独特新颖的处理,更好更牢固的穿搭,不易滑落。"} 
+{"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*材质#涤纶*裤款式#拼接*裤款式#口袋", "response": "前置的口袋盖拼接,为本来单调的设计布局增加了亮点,更突出了裤子的个性化特点。加上精致的涤纶梭织面料制作,具备更加亲肤不刺激的丝质般触感,给你带来更加柔软舒适的穿着体验。其良好的透气性,有效提升了裤子的吸湿排汗性能,为你提供更加清爽舒适的体感。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*裙下摆#荷叶边*裙领型#圆领*裙袖型#收口*裙款式#螺纹", "response": "此款上衣采用了经典的圆领款式打造,贴合舒适并能修饰出完美的脸型。同时螺纹的收口贴合身材更完美,在前身处采用了可爱的小狮子造型,带<UNK>真的感觉,而狮子的毛发更是立体精致,显得真实又有丰富的层次。裙身的下摆处采用了荷叶边的设计,俏皮活泼更可爱。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#宽松*版型#显瘦*材质#网纱*风格#青春*图案#印花*衣样式#衬衫*衣领型#v领*衣款式#拼接", "response": "这一款衬衫交叠v领的设计,修饰脖颈尽显女人味,宽松的廓形,穿上非常轻松有范毫不拘束,并很好的遮盖身材,非常显瘦。时尚的网纱拼接,自然美感特别出彩。精致印花,青春减龄特别活力。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*风格#运动*风格#性感*衣样式#西装*衣领型#一字领*衣款式#荷叶边", "response": "荷叶边能够表达出女性的优雅,BRAND的这款上衣,将荷叶边很好地运动起来。性感的一字肩设计,荷叶边从一侧手臂的手肘从前胸绕到另一侧,有着前短后长的感觉,自然垂坠很有层次感,举手投足之间,灵动而优雅。西装袖很好地融合,优雅之中透着小帅气。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#宽松*风格#运动*风格#休闲*风格#青春*图案#字母*图案#形状*图案#文字*图案#刺绣*图案#撞色*衣样式#卫衣*衣袖型#收口*衣款式#抽绳*衣款式#连帽", "response": "这款dolce&gabbana的连帽卫衣,撞色的<UNK>字母加上桃心形状的刺绣图案令人耳目一新,举手投足间散发阳光活力少女的青春气息;连帽款式尽显帅气利落风范,细节上采用抽绳处理实用又美观,洋溢满满的运动休闲范儿;加之袖口处的收口设计别出心裁,宽松的衣身烘托出慵懒率性的格调。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*版型#宽松*材质#牛仔布*风格#休闲*图案#字母*图案#文字*图案#线条*图案#印花*图案#撞色*裤款式#拼接*裤口#小脚", "response": "上下<UNK>拼接撞色设计,吸睛十足,轻松聚焦视线,个性前卫。字母印花设计,巧添时尚细节看点,以鲜明撞色渲染,展现年轻活力气息。长袖套头轮廓,线条处理恰到好处,呼应休闲基调。宽松的版型,不挑身材,上身好看。连帽的设计美观实用,防风保暖。时尚百搭,可以搭配牛仔裤、紧身裤、休闲裤、束脚裤等。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#显瘦*颜色#灰色*风格#复古*风格#文艺*风格#青春*图案#卡通*图案#复古*衣样式#风衣*衣长#中长款", "response": "一款好看的风衣大概能为这个姹紫嫣红的春天多一份色彩,沉静的灰色上身具有非常好的效果,显得热更加内敛沉稳,有一股淡淡的复古文艺风格。而中长的版型自然下垂,修身显高又瞬间提升气场。后背的卡通图案别致可爱,更添青春气息。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#宽松*风格#性感*衣样式#针织衫*衣款式#露肩", "response": "这一款针织衫露肩设计,风情浪漫性感迷人。略微宽松的版型舒适随意,很好的掩饰身材小小的缺陷,看起来精致高挑。加上时尚的花边下摆,错落有致视觉美丽。精致袖口,修饰手臂特别出彩。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*裙款式#松紧带*裙款式#飘带", "response": "<UNK>冷风的气质感,干净利落的feel,小露香肩有一种含蓄撩拨的趣味,袖口领口的飘带设计很是巧妙,让整个小衫更加优美,领子部分的两边肩部松紧带设计,大胆随意的穿出多种效果。让仙女们走在时尚<UNK>的道路上更加自信。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*材质#针织*衣样式#卫衣*衣领型#圆领", "response": "针织卫衣采用了简洁的圆领设计,非常百搭,免去了你<UNK>找不到搭配的烦恼。合体的剪裁设计,让你在跑步健身时轻巧灵便,活动自如,达到更好的锻炼效果。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#宽松*颜色#白色*风格#简约*图案#蝴蝶结*图案#刺绣*衣样式#衬衫*衣袖型#喇叭袖", "response": "这一款很好穿的白色衬衫,利落的宽松版型几乎是不挑身材的,无门襟的设计也符合整体的气息。胸前做了绣花的点缀,为简约的衬衫增添了几分柔美的气质。七分的喇叭袖露出小臂,蝴蝶结的点缀显得气质更加的浪漫。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*材质#雪纺*颜色#灰色*风格#英伦*风格#复古*图案#格子*图案#复古*裙型#百褶*裙长#半身裙*裙款式#波浪*裙款式#收腰", "response": "BRAND这款半身裙,用复古的灰色格纹,打造出十足英伦范儿。搭配百褶裙身,为整体增添层次感,穿出减龄风。同时,波浪边的收腰设计,不仅更好的修饰腰部曲线,还为整体气质增添了优雅美感。而雪纺面料,使你在夏日也能穿出清爽感。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*版型#显瘦*风格#复古*风格#文艺*风格#中国风*风格#性感*图案#复古*图案#刺绣*裙型#a字*裙领型#v领", "response": "超级具有中国风气息的一款裙子,带着古典的柔婉。花朵刺绣的运用,色彩缤纷靓丽,冲击视觉,演绎复古文艺范儿。经典的气质v领,既凸显了小性感与时尚,又起到点睛的效果。腰部系的设计,配上a字版型,显瘦又遮肚子。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*风格#清新*风格#性感*图案#线条*衣样式#马甲*衣领型#翻领*衣款式#露背*衣款式#绑带*衣款式#吊带*衣款式#收腰", "response": "小吊带马甲叠穿造型,年轻而不失时尚格调,有着绑带收腰设计,强调出纤细的腰肢,摩登帅气;小翻领露出纤细修长的脖颈线条,散发清爽利落的小清新气息;性感交叉露背设计,别致吸睛,女人味十足;高腰伞形裙摆自然撑开,上身塑造黄金比例,突显得腰更细,巧妙地修饰身型。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#显瘦*颜色#黑白*风格#英伦*风格#简约*图案#格子*图案#线条*衣样式#外套*衣样式#西装*衣门襟#一粒扣", "response": "这款西装外套,版型加长修身,能更好凸显成熟与稳重。细细密密的黑白图案,远远看形成自然的格纹,时髦英伦范儿。平整肩线将线条感拉伸,让身姿显得更挺拔有型。一粒扣设计,简约大气。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*风格#潮*图案#线条*图案#撞色*裙领型#圆领", "response": "采用经典的圆领设计,修饰颈部线条的同时,且上身穿着舒适不易变形,轻松演绎时髦造型。大面积撞色贴花装饰,无疑是点睛之笔,为简洁的款式轮廓带来了更多的视觉层次感与潮流气息。与众不同的你,不在畏惧撞衫的尴尬。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*材质#雪纺*颜色#纯色*风格#清新*图案#纯色*图案#碎花*衣样式#衬衫*衣款式#荷叶边", "response": "这件荷叶边雪纺碎花衬衫和其他的碎花衬衫相比整体的风格会更优雅柔美一些。颜色上也是比较<UNK>清新的花型配色和纯色的大身相结合,会让人看着很舒服,而且每个碎花之间都会<UNK>限视觉上不会觉得太紧密,更有法式的浪漫优雅。"} +{"prompt": 
"请根据以下标签为商品编写一段广告\n类型#上衣*材质#蕾丝*风格#简约*风格#青春*风格#潮*风格#性感*图案#线条*图案#蕾丝*衣样式#雪纺衫*衣领型#圆领*衣款式#勾花镂空", "response": "这款时尚镂空雪纺衫,带有性感蕾丝工艺,精致百搭的圆领设计,彰显显独特的质感。背部线条流畅,笔挺而有型干练,彰显潮流时尚之风。走线十分笔直,针脚均匀,尽显裁缝之细致。简约时尚的透视蕾丝袖口,彰显成熟又不乏活力的青春气质。青春优雅的独特风格,流露出满满的潮流感。"} +{"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*风格#休闲*裤长#短裤", "response": "来自英国<UNK>的这款儿童休闲短裤,以趣味的小恐龙图案满印裤身,可爱童真,彰显出宝宝的活泼天真范儿。柔软的全棉布料质地,手感细腻顺滑,亲和宝宝的肌肤,带来舒适自在的穿着体验。"} \ No newline at end of file diff --git a/limitations/factual_error.png b/limitations/factual_error.png new file mode 100644 index 0000000000000000000000000000000000000000..fc3a9b669dc94045aa0f59de0b124c72a61296e1 Binary files /dev/null and b/limitations/factual_error.png differ diff --git a/limitations/math_error.png b/limitations/math_error.png new file mode 100644 index 0000000000000000000000000000000000000000..02a496f2c227116205d9d8343a82d0d5af0bb355 Binary files /dev/null and b/limitations/math_error.png differ diff --git a/limitations/self-confusion_google.jpg b/limitations/self-confusion_google.jpg new file mode 100644 index 0000000000000000000000000000000000000000..162665c8647669f4d0ed2eb181f2ebdcdab9700f Binary files /dev/null and b/limitations/self-confusion_google.jpg differ diff --git a/limitations/self-confusion_openai.jpg b/limitations/self-confusion_openai.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b78b99fe5c714bdb214b45d4874d134ede09ce8b Binary files /dev/null and b/limitations/self-confusion_openai.jpg differ diff --git a/limitations/self-confusion_tencent.jpg b/limitations/self-confusion_tencent.jpg new file mode 100644 index 0000000000000000000000000000000000000000..47f89ff75f20f8b89b86e4cbb2beb6cc6b5a9719 Binary files /dev/null and b/limitations/self-confusion_tencent.jpg differ diff --git a/ptuning/README.md b/ptuning/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f92a328f8b681c3c87b13152d5390252740f8263 --- /dev/null +++ b/ptuning/README.md @@ -0,0 +1,213 @@ +# ChatGLM-6B-PT +本仓库实现了对于 ChatGLM-6B 模型基于 [P-Tuning v2](https://github.com/THUDM/P-tuning-v2) 的微调。P-Tuning v2 将需要微调的参数量减少到原来的 0.1%,再通过模型量化、Gradient Checkpoint 等方法,最低只需要 7GB 显存即可运行。 + +下面以 [ADGEN](https://aclanthology.org/D19-1321.pdf) (广告生成) 数据集为例介绍代码的使用方法。 + +*Read this in [English](README_en.md). 
+ +## 软件依赖 +运行微调需要4.27.1版本的`transformers`。除 ChatGLM-6B 的依赖之外,还需要安装以下依赖 +``` +pip install rouge_chinese nltk jieba datasets +``` +## 使用方法 + +### 下载数据集 +ADGEN 数据集任务为根据输入(content)生成一段广告词(summary)。 + +```json +{ + "content": "类型#上衣*版型#宽松*版型#显瘦*图案#线条*衣样式#衬衫*衣袖型#泡泡袖*衣款式#抽绳", + "summary": "这件衬衫的款式非常的宽松,利落的线条可以很好的隐藏身材上的小缺点,穿在身上有着很好的显瘦效果。领口装饰了一个可爱的抽绳,漂亮的绳结展现出了十足的个性,配合时尚的泡泡袖型,尽显女性甜美可爱的气息。" +} +``` + +从 [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) 或者 [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1) 下载处理好的 ADGEN 数据集,将解压后的 `AdvertiseGen` 目录放到本目录下。 + +### 训练 + +#### P-Tuning v2 + +运行以下指令进行训练: +```shell +bash train.sh +``` +`train.sh` 中的 `PRE_SEQ_LEN` 和 `LR` 分别是 soft prompt 长度和训练的学习率,可以进行调节以取得最佳的效果。P-Tuning-v2 方法会冻结全部的模型参数,可通过调整 `quantization_bit` 来被原始模型的量化等级,不加此选项则为 FP16 精度加载。 + +在默认配置 `quantization_bit=4`、`per_device_train_batch_size=1`、`gradient_accumulation_steps=16` 下,INT4 的模型参数被冻结,一次训练迭代会以 1 的批处理大小进行 16 次累加的前后向传播,等效为 16 的总批处理大小,此时最低只需 6.7G 显存。若想在同等批处理大小下提升训练效率,可在二者乘积不变的情况下,加大 `per_device_train_batch_size` 的值,但也会带来更多的显存消耗,请根据实际情况酌情调整。 + +如果你想要[从本地加载模型](../README_en.md#load-the-model-locally),可以将 `train.sh` 中的 `THUDM/chatglm-6b` 改为你本地的模型路径。 + +#### Finetune + +如果需要进行全参数的 Finetune,需要安装 [Deepspeed](https://github.com/microsoft/DeepSpeed),然后运行以下指令: + +```shell +bash ds_train_finetune.sh +``` + +### 推理 + +在 P-tuning v2 训练时模型只保存 PrefixEncoder 部分的参数,所以在推理时需要同时加载原 ChatGLM-6B 模型以及 PrefixEncoder 的权重,因此需要指定 `evaluate.sh` 中的参数: + +```shell +--model_name_or_path THUDM/chatglm-6b +--ptuning_checkpoint $CHECKPOINT_PATH +``` + +仍然兼容旧版全参保存的 Checkpoint,只需要跟之前一样设定 `model_name_or_path`: + +```shell +--model_name_or_path $CHECKPOINT_PATH +``` + +评测指标为中文 Rouge score 和 BLEU-4。生成的结果保存在 +`./output/adgen-chatglm-6b-pt-8-1e-2/generated_predictions.txt`。 + +### 例子 +#### 示例1 +* Input: 类型#上衣\*材质#牛仔布\*颜色#白色\*风格#简约\*图案#刺绣\*衣样式#外套\*衣款式#破洞 +* Label: 简约而不简单的牛仔外套,白色的衣身十分百搭。衣身多处有做旧破洞设计,打破单调乏味,增加一丝造型看点。衣身后背处有趣味刺绣装饰,丰富层次感,彰显别样时尚。 +* Output[微调前]: 这件上衣的材质是牛仔布,颜色是白色,风格是简约,图案是刺绣,衣样式是外套,衣款式是破洞。 +* Output[微调后]: 这是一款简约的牛仔外套,破洞设计,将牛仔布破洞,带来一种随意与个性。破洞的牛仔外套,展现出时尚气息,带来一种休闲感。同时,刺绣图案,让整件外套更加立体。 + +#### 示例2 + +* Input: 类型#裙\*版型#显瘦\*风格#文艺\*风格#简约\*图案#印花\*图案#撞色\*裙下摆#压褶\*裙长#连衣裙\*裙领型#圆领 +* Label: 文艺个性的印花连衣裙,藏青色底蕴,低调又大气,撞色太阳花分布整个裙身,绚丽而美好,带来时尚减龄的气质。基础款的舒适圆领,简约不失大方,勾勒精致脸庞。领后是一粒包布扣固定,穿脱十分方便。前片立体的打褶设计,搭配后片压褶的做工,增添层次和空间感,显瘦又有型。 +* Output[微调前]: 类型#裙*版型#显瘦*风格#文艺*风格#简约*图案#印花*图案#撞色*裙下摆#压褶*裙长#连衣裙*裙领型#圆领 1\. 连衣裙:简约风格,裙长为膝盖以上,裙领型为圆领。2\. 裙下摆:压褶设计,使裙摆呈现出流畅的褶皱效果。3\. 裙领型:裙领型为圆领,使穿上连衣裙后更加有型。4\. 版型:采用显瘦设计,让连衣裙看起来更加苗条。5\. 风格:文艺风格,让连衣裙更加有内涵和品味。6\. 图案:印花设计,在连衣裙上印有独特的图案。7\. 
撞色:采用撞色设计,让连衣裙在色彩上更加鲜明、富有层次感。 +* Output[微调后]: 这是一款文艺范的连衣裙,以印花为元素,采用简约的印花,既能够突出文艺气质,又能够展现简约风。在印花的同时又有领子和裙摆的压褶设计,更加凸显文艺气质。简约而不会过于单调,搭配出街,穿着十分舒适。 + +### 评估结果 + +| | Finetune | P-tuning v2 | LoRA | +| ------------- | ----------- | ----- | ------------- | +| BLEU-4 | 8.01 | 8.10 | 7.62 | +| Rouge-1 | 31.23 | 31.12 | 30.60 | +| Rouge-2 | 7.36 | 7.11 | 6.96 | +| Rouge-l | 25.08 | 24.97 | 24.80 | +| Training Loss | 3.00 | 3.74 | 3.32 | + + + +#### 实验设置 + +``` +max_source_length=64 +max_target_length=64 +max_steps=3000 +``` + +##### P-tuning v2 + +``` +pre_seq_len=128 +learning_rate=2e-2 +quantization_bit=4 +per_device_train_batch_size=16 +gradient_accumulation_steps=1 +``` + +##### Finetune + +``` +learning_rate=1e-4 +fp16 +num_gpus=4 +per_device_train_batch_size=4 +gradient_accumulation_steps=1 +``` + +##### LoRA + +实现采用的是 [simple_thu_chatglm6b](https://github.com/yuanzhoulvpi2017/zero_nlp/tree/main/simple_thu_chatglm6b) + +``` +learning_rate=5e-4 +per_device_train_batch_size=16 +gradient_accumulation_steps=1 +``` + +## 模型部署 +首先载入Tokenizer: + +```python +from transformers import AutoConfig, AutoModel, AutoTokenizer + +# 载入Tokenizer +tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) +``` + +1. 如果需要加载的是新 Checkpoint(只包含 PrefixEncoder 参数): + +```python +config = AutoConfig.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, pre_seq_len=128) +model = AutoModel.from_pretrained("THUDM/chatglm-6b", config=config, trust_remote_code=True) +prefix_state_dict = torch.load(os.path.join(CHECKPOINT_PATH, "pytorch_model.bin")) +new_prefix_state_dict = {} +for k, v in prefix_state_dict.items(): + if k.startswith("transformer.prefix_encoder."): + new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v +model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict) +``` +注意你可能需要将 `pre_seq_len` 改成你训练时的实际值。如果你是[从本地加载模型](https://github.com/THUDM/ChatGLM-6B#%E4%BB%8E%E6%9C%AC%E5%9C%B0%E5%8A%A0%E8%BD%BD%E6%A8%A1%E5%9E%8B)的话,需要将 `THUDM/chatglm-6b` 改成本地的模型路径(注意不是checkpoint路径)。 + +2. 
如果需要加载的是旧 Checkpoint(包含 ChatGLM-6B 以及 PrefixEncoder 参数),或者进行的是全参数微调,则直接加载整个 Checkpoint: + +```python +model = AutoModel.from_pretrained(CHECKPOINT_PATH, trust_remote_code=True) +``` + +之后根据需求可以进行量化,也可以直接使用: + +```python +# Comment out the following line if you don't use quantization +model = model.quantize(4) +model = model.half().cuda() +model.transformer.prefix_encoder.float() +model = model.eval() + +response, history = model.chat(tokenizer, "你好", history=[]) +``` + +**[23/04/19]** 你也可以直接运行支持加载 P-Tuning v2 checkpoint 的 [web demo](./web_demo.py) +```shell +bash web_demo.sh +``` +可能需要修改 [web_demo.sh](./web_demo.sh) 的内容以符合你实际的 checkpoint 情况。 + +## 使用自己的数据集 +修改 `train.sh` 和 `evaluate.sh` 中的 `train_file`、`validation_file`和`test_file`为你自己的 JSON 格式数据集路径,并将 `prompt_column` 和 `response_column` 改为 JSON 文件中输入文本和输出文本对应的 KEY。可能还需要增大 `max_source_length` 和 `max_target_length` 来匹配你自己的数据集中的最大输入输出长度。 + +## 对话数据集 + +如需要使用多轮对话数据对模型进行微调,可以提供聊天历史,例如以下是一个三轮对话的训练数据: + +```json lines +{"prompt": "长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "response": "用电脑能读数据流吗?水温多少", "history": []} +{"prompt": "95", "response": "上下水管温差怎么样啊?空气是不是都排干净了呢?", "history": [["长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "用电脑能读数据流吗?水温多少"]]} +{"prompt": "是的。上下水管都好的", "response": "那就要检查线路了,一般风扇继电器是由电脑控制吸合的,如果电路存在断路,或者电脑坏了的话会出现继电器不吸合的情况!", "history": [["长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "用电脑能读数据流吗?水温多少"], ["95", "上下水管温差怎么样啊?空气是不是都排干净了呢?"]]} +``` + +训练时需要指定 `--history_column` 为数据中聊天历史的 key(在此例子中是 `history`),将自动把聊天历史拼接。要注意超过输入长度 `max_source_length` 的内容会被截断。 + +可以参考以下指令: + +```shell +bash train_chat.sh +``` + +## 引用 + +``` +@inproceedings{liu2022p, + title={P-tuning: Prompt tuning can be comparable to fine-tuning across scales and tasks}, + author={Liu, Xiao and Ji, Kaixuan and Fu, Yicheng and Tam, Weng and Du, Zhengxiao and Yang, Zhilin and Tang, Jie}, + booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short Papers)}, + pages={61--68}, + year={2022} +} +``` + + + diff --git a/ptuning/README_en.md b/ptuning/README_en.md new file mode 100644 index 0000000000000000000000000000000000000000..34a68a637e4f5e997606907d107d03d725ae3db6 --- /dev/null +++ b/ptuning/README_en.md @@ -0,0 +1,206 @@ +# ChatGLM-6B-PT +This repository implements tuning of the ChatGLM-6B model based on [P-Tuning v2](https://github.com/THUDM/P-tuning-v2). P-Tuning v2 reduces the amount of parameters that need to be optimized to 0.1% of the full fine-tuning, and then through model quantization, Gradient Checkpoint and other methods, it only needs a minimum of 7GB of video memory to run. + +The following uses the [ADGEN](https://aclanthology.org/D19-1321.pdf) (advertising generation) dataset as an example to introduce how to use the code. + +## Software dependencies +Running p-tuning requires version 4.27.1 of `transformers`. In addition to the dependencies of ChatGLM-6B, the following dependencies are required +``` +pip install rouge_chinese nltk jieba datasets +``` +## Instructions + +### Download the dataset +The task of the ADGEN dataset is to generate an advertisement word (summary) based on the input (content). 
+ +```json +{ + "content": "类型#上衣*版型#宽松*版型#显瘦*图案#线条*衣样式#衬衫*衣袖型#泡泡袖*衣款式#抽绳", + "summary": "这件衬衫的款式非常的宽松,利落的线条可以很好的隐藏身材上的小缺点,穿在身上有着很好的显瘦效果。领口装饰了一个可爱的抽绳,漂亮的绳结展现出了十足的个性,配合时尚的泡泡袖型,尽显女性甜美可爱的气息。" +} +``` + +From [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1) download the processed ADGEN dataset, and put the decompressed `AdvertiseGen` directory into this directory. + +### Training + +#### P-Tuning v2 + +Run the following commands for training: +```shell +bash train.sh +``` +`PRE_SEQ_LEN` and `LR` in `train.sh` are soft prompt length and training learning rate respectively, which can be adjusted to achieve the best results. The P-Tuning-v2 method will freeze all model parameters, and the quantization level of the original model can be adjusted by adjusting `quantization_bit`. If this option is not added, it will be loaded with FP16 precision. + +Under the default configuration of `per_device_train_batch_size=1`, `gradient_accumulation_steps=16`, the model parameters of INT4 are frozen, and a training iteration will perform 16 cumulative forward and backward propagations with a batch size of 1, which is equivalent to the total batch size of 16, and only 6.7G GPU memory is required at this time with `quantization_bit=4`. If you want to improve the training efficiency under the same batch size, you can increase the value of `per_device_train_batch_size` while keeping the product of the two unchanged, but it will also bring more GPU memory consumption, please adjust it according to the actual situation. + +If you want to [load the model locally](../README_en.md#load-the-model-locally), you can change `THUDM/chatglm-6b` in `train.sh` to your local model path. + +#### Finetune +To finetune the full parameters, you need to install [Deepspeed](https://github.com/microsoft/DeepSpeed), and then run the following command: + +```shell +bash ds_train_finetune.sh +``` + +### Inference + +During P-tuning v2 training, the model only saves the parameters of the PrefixEncoder part, so the original ChatGLM-6B model and the weight of the PrefixEncoder need to be loaded at the same time during inference, and the arguments need to be specified in `evaluate.sh`: + +```shell +--model_name_or_path THUDM/chatglm-6b +--ptuning_checkpoint $CHECKPOINT_PATH +``` + +It is still compatible with the old version of Checkpoint saved with full parameters, just set `model_name_or_path` as before: + +```shell +--model_name_or_path $CHECKPOINT_PATH +``` + +The evaluation indicators are Chinese Rouge score and BLEU-4. The generated results are saved in +`./output/adgen-chatglm-6b-pt-8-1e-2/generated_predictions.txt`. + +### Example +#### Example 1 +* Input: 类型#上衣\*材质#牛仔布\*颜色#白色\*风格#简约\*图案#刺绣\*衣样式#外套\*衣款式#破洞 +* Label: 简约而不简单的牛仔外套,白色的衣身十分百搭。衣身多处有做旧破洞设计,打破单调乏味,增加一丝造型看点。衣身后背处有趣味刺绣装饰,丰富层次感,彰显别样时尚。 +* Output[before tuning]: 这件上衣的材质是牛仔布,颜色是白色,风格是简约,图案是刺绣,衣样式是外套,衣款式是破洞。 +* Output[after tuning]: 这是一款简约的牛仔外套,破洞设计,将牛仔布破洞,带来一种随意与个性。破洞的牛仔外套,展现出时尚气息,带来一种休闲感。同时,刺绣图案,让整件外套更加立体。 + +#### Example 2 + +* Input: 类型#裙\*版型#显瘦\*风格#文艺\*风格#简约\*图案#印花\*图案#撞色\*裙下摆#压褶\*裙长#连衣裙\*裙领型#圆领 +* Label: 文艺个性的印花连衣裙,藏青色底蕴,低调又大气,撞色太阳花分布整个裙身,绚丽而美好,带来时尚减龄的气质。基础款的舒适圆领,简约不失大方,勾勒精致脸庞。领后是一粒包布扣固定,穿脱十分方便。前片立体的打褶设计,搭配后片压褶的做工,增添层次和空间感,显瘦又有型。 +* Output[before tuning]: 类型#裙*版型#显瘦*风格#文艺*风格#简约*图案#印花*图案#撞色*裙下摆#压褶*裙长#连衣裙*裙领型#圆领 1\. 连衣裙:简约风格,裙长为膝盖以上,裙领型为圆领。2\. 裙下摆:压褶设计,使裙摆呈现出流畅的褶皱效果。3\. 裙领型:裙领型为圆领,使穿上连衣裙后更加有型。4\. 版型:采用显瘦设计,让连衣裙看起来更加苗条。5\. 风格:文艺风格,让连衣裙更加有内涵和品味。6\. 
图案:印花设计,在连衣裙上印有独特的图案。7\. 撞色:采用撞色设计,让连衣裙在色彩上更加鲜明、富有层次感。 +* Output[after tuning]: 这是一款文艺范的连衣裙,以印花为元素,采用简约的印花,既能够突出文艺气质,又能够展现简约风。在印花的同时又有领子和裙摆的压褶设计,更加凸显文艺气质。简约而不会过于单调,搭配出街,穿着十分舒适。 + +### evaluation result + +| | Finetune | P-tuning v2 | LoRA | +| ------------- | ----------- | ----- | ------------- | +| BLEU-4 | 8.01 | 8.10 | 7.62 | +| Rouge-1 | 31.23 | 31.12 | 30.60 | +| Rouge-2 | 7.36 | 7.11 | 6.96 | +| Rouge-l | 25.08 | 24.97 | 24.80 | +| Training Loss | 3.00 | 3.74 | 3.32 | + +#### Experiment Settings + +``` +max_source_length=64 +max_target_length=64 +max_steps=3000 +``` + +##### P-tuning v2 + +``` +pre_seq_len=128 +learning_rate=2e-2 +quantization_bit=4 +per_device_train_batch_size=16 +gradient_accumulation_steps=1 +``` + +##### Finetune + +``` +learning_rate=1e-4 +fp16 +num_gpus=4 +per_device_train_batch_size=4 +gradient_accumulation_steps=1 +``` + +##### LoRA + +The implementation uses [simple_thu_chatglm6b](https://github.com/yuanzhoulvpi2017/zero_nlp/tree/main/simple_thu_chatglm6b) + +``` +learning_rate=5e-4 +per_device_train_batch_size=16 +gradient_accumulation_steps=1 +``` + +## Model Deployment +First load the tokenizer: + +```python +from transformers import AutoConfig, AutoModel, AutoTokenizer + +# Load Tokenizer +tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) +``` + +1. If a new Checkpoint needs to be loaded (only contains the PrefixEncoder parameter): + +```python +config = AutoConfig.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, pre_seq_len=128) +model = AutoModel.from_pretrained("THUDM/chatglm-6b", config=config, trust_remote_code=True) +prefix_state_dict = torch.load(os.path.join(CHECKPOINT_PATH, "pytorch_model.bin")) +new_prefix_state_dict = {} +for k, v in prefix_state_dict.items(): + if k.startswith("transformer.prefix_encoder."): + new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v +model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict) +``` +Note that you may need to change `pre_seq_len` to the actual value of your training. If you [load model from local](../README_en.md#load-the-model-locally), you need to change `THUDM/chatglm-6b` to the local model path (not the checkpoint path). + +2. If you need to load the old checkpoint (including both ChatGLM-6B and PrefixEncoder parameters), or perform full parameter fine-tuning, then directly load the entire checkpoint: + +```python +model = AutoModel.from_pretrained(CHECKPOINT_PATH, trust_remote_code=True) +``` + +Then it can be quantified according to the needs, or it can be used directly: + +```python +# Comment out the following line if you don't use quantization +model = model. quantize(4) +model = model.half().cuda() +model.transformer.prefix_encoder.float() +model = model.eval() + +response, history = model.chat(tokenizer, "Hello", history=[]) +``` + +**[23/04/19]** You can also directly run [web demo](./web_demo.py) which supports loading P-Tuning v2 checkpoint +```shell +bash web_demo.sh +``` +It may be necessary to modify the content of [web_demo.sh](./web_demo.sh) to match your actual checkpoint situation. + +## Use your own dataset +Modify `train_file`, `validation_file` and `test_file` in `train.sh` and `evaluate.sh` to your own JSON format dataset paths, and change `prompt_column` and `response_column` to the keys in the JSON file corresponding to input text and output text. +You may also need to increase `max_source_length` and `max_target_length` to match the maximum input and output lengths in your own dataset. 
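+For example, a quick sanity check that every record contains the columns you plan to pass as `prompt_column` and `response_column` could look like the sketch below (the file name `my_train.json` and the `question`/`answer` keys are hypothetical placeholders for your own dataset path and column names, assuming a JSON Lines file like the examples in this repo):
+
+```python
+import json
+
+prompt_column, response_column = "question", "answer"  # hypothetical column names
+
+with open("my_train.json", encoding="utf-8") as f:  # hypothetical dataset path
+    for i, line in enumerate(f, 1):
+        record = json.loads(line)
+        missing = [c for c in (prompt_column, response_column) if c not in record]
+        assert not missing, f"line {i} is missing column(s): {missing}"
+```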
+ +## Dialog Dataset + +If you need to use multiple rounds of dialogue data to train the model, you can provide chat history. For example, the following is the training data for a three-round dialogue: + +```json lines +{"prompt": "长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "response": "用电脑能读数据流吗?水温多少", "history": []} +{"prompt": "95", "response": "上下水管温差怎么样啊?空气是不是都排干净了呢?", "history": [["长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "用电脑能读数据流吗?水温多少"]]} +{"prompt": "是的。上下水管都好的", "response": "那就要检查线路了,一般风扇继电器是由电脑控制吸合的,如果电路存在断路,或者电脑坏了的话会出现继电器不吸合的情况!", "history": [["长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "用电脑能读数据流吗?水温多少"], ["95", "上下水管温差怎么样啊?空气是不是都排干净了呢?"]]} +``` + +During training, you need to specify `--history_column` as the key of the chat history in the data (`history` in this example), and the chat history will be stitched automatically. Note that content exceeding the input length `max_source_length` will be truncated. + +You can refer to the following instructions: + +```shell +bash train_chat.sh +``` + +## Citation + +``` +@inproceedings{liu2022p, + title={P-tuning: Prompt tuning can be comparable to fine-tuning across scales and tasks}, + author={Liu, Xiao and Ji, Kaixuan and Fu, Yicheng and Tam, Weng and Du, Zhengxiao and Yang, Zhilin and Tang, Jie}, + booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short Papers)}, + pages={61--68}, + year={2022} +} +``` \ No newline at end of file diff --git a/ptuning/arguments.py b/ptuning/arguments.py new file mode 100644 index 0000000000000000000000000000000000000000..fda1f3522261f50768984402d9ac691557ea63f3 --- /dev/null +++ b/ptuning/arguments.py @@ -0,0 +1,224 @@ +from dataclasses import dataclass, field +from typing import Optional + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. + """ + + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + ) + ptuning_checkpoint: str = field( + default=None, metadata={"help": "Path to p-tuning v2 checkpoints"} + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + tokenizer_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}, + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, + ) + model_revision: str = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + use_auth_token: bool = field( + default=False, + metadata={ + "help": ( + "Will use the token generated when running `huggingface-cli login` (necessary to use this script " + "with private models)." + ) + }, + ) + resize_position_embeddings: Optional[bool] = field( + default=None, + metadata={ + "help": ( + "Whether to automatically resize the position embeddings if `max_source_length` exceeds " + "the model's position embeddings." 
+ ) + }, + ) + quantization_bit: Optional[int] = field( + default=None + ) + pre_seq_len: Optional[int] = field( + default=None + ) + prefix_projection: bool = field( + default=False + ) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + + lang: Optional[str] = field(default=None, metadata={"help": "Language id for summarization."}) + + dataset_name: Optional[str] = field( + default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} + ) + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} + ) + prompt_column: Optional[str] = field( + default=None, + metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."}, + ) + response_column: Optional[str] = field( + default=None, + metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."}, + ) + history_column: Optional[str] = field( + default=None, + metadata={"help": "The name of the column in the datasets containing the history of chat."}, + ) + train_file: Optional[str] = field( + default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."} + ) + validation_file: Optional[str] = field( + default=None, + metadata={ + "help": ( + "An optional input evaluation data file to evaluate the metrics (rouge) on (a jsonlines or csv file)." + ) + }, + ) + test_file: Optional[str] = field( + default=None, + metadata={ + "help": "An optional input test data file to evaluate the metrics (rouge) on (a jsonlines or csv file)." + }, + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + max_source_length: Optional[int] = field( + default=1024, + metadata={ + "help": ( + "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + ) + }, + ) + max_target_length: Optional[int] = field( + default=128, + metadata={ + "help": ( + "The maximum total sequence length for target text after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + ) + }, + ) + val_max_target_length: Optional[int] = field( + default=None, + metadata={ + "help": ( + "The maximum total sequence length for validation target text after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`." + "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used " + "during ``evaluate`` and ``predict``." + ) + }, + ) + pad_to_max_length: bool = field( + default=False, + metadata={ + "help": ( + "Whether to pad all samples to model maximum sentence length. " + "If False, will pad the samples dynamically when batching to the maximum length in the batch. More " + "efficient on GPU but very bad for TPU." + ) + }, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": ( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." 
+ ) + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": ( + "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + ) + }, + ) + max_predict_samples: Optional[int] = field( + default=None, + metadata={ + "help": ( + "For debugging purposes or quicker training, truncate the number of prediction examples to this " + "value if set." + ) + }, + ) + num_beams: Optional[int] = field( + default=None, + metadata={ + "help": ( + "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, " + "which is used during ``evaluate`` and ``predict``." + ) + }, + ) + ignore_pad_token_for_loss: bool = field( + default=True, + metadata={ + "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not." + }, + ) + source_prefix: Optional[str] = field( + default="", metadata={"help": "A prefix to add before every source text (useful for T5 models)."} + ) + + forced_bos_token: Optional[str] = field( + default=None, + metadata={ + "help": ( + "The token to force as the first generated token after the decoder_start_token_id." + "Useful for multilingual models like mBART where the first generated token" + "needs to be the target language token (Usually it is the target language token)" + ) + }, + ) + + + + def __post_init__(self): + if self.dataset_name is None and self.train_file is None and self.validation_file is None and self.test_file is None: + raise ValueError("Need either a dataset name or a training/validation/test file.") + else: + if self.train_file is not None: + extension = self.train_file.split(".")[-1] + assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." + if self.validation_file is not None: + extension = self.validation_file.split(".")[-1] + assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." 
+ if self.val_max_target_length is None: + self.val_max_target_length = self.max_target_length + diff --git a/ptuning/deepspeed.json b/ptuning/deepspeed.json new file mode 100644 index 0000000000000000000000000000000000000000..798932966f38b2df8a468c72a4b41d8b47033ccc --- /dev/null +++ b/ptuning/deepspeed.json @@ -0,0 +1,21 @@ +{ + "train_micro_batch_size_per_gpu": "auto", + "zero_allow_untested_optimizer": true, + "fp16": { + "enabled": "auto", + "loss_scale": 0, + "initial_scale_power": 16, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "zero_optimization": { + "stage": 2, + "allgather_partitions": true, + "allgather_bucket_size": 5e8, + "overlap_comm": false, + "reduce_scatter": true, + "reduce_bucket_size": 5e8, + "contiguous_gradients" : true + } +} \ No newline at end of file diff --git a/ptuning/ds_train_finetune.sh b/ptuning/ds_train_finetune.sh new file mode 100644 index 0000000000000000000000000000000000000000..531a8004dbed00819aa767c420cdc483e7c0abed --- /dev/null +++ b/ptuning/ds_train_finetune.sh @@ -0,0 +1,28 @@ + +LR=1e-4 + +MASTER_PORT=$(shuf -n 1 -i 10000-65535) + +deepspeed --num_gpus=4 --master_port $MASTER_PORT main.py \ + --deepspeed deepspeed.json \ + --do_train \ + --train_file AdvertiseGen/train.json \ + --test_file AdvertiseGen/dev.json \ + --prompt_column content \ + --response_column summary \ + --overwrite_cache \ + --model_name_or_path THUDM/chatglm-6b \ + --output_dir ./output/adgen-chatglm-6b-ft-$LR \ + --overwrite_output_dir \ + --max_source_length 64 \ + --max_target_length 64 \ + --per_device_train_batch_size 4 \ + --per_device_eval_batch_size 1 \ + --gradient_accumulation_steps 1 \ + --predict_with_generate \ + --max_steps 5000 \ + --logging_steps 10 \ + --save_steps 1000 \ + --learning_rate $LR \ + --fp16 + diff --git a/ptuning/evaluate.sh b/ptuning/evaluate.sh new file mode 100644 index 0000000000000000000000000000000000000000..ab855367009f472c84d095b62b3c3d49a0c5518c --- /dev/null +++ b/ptuning/evaluate.sh @@ -0,0 +1,21 @@ +PRE_SEQ_LEN=128 +CHECKPOINT=adgen-chatglm-6b-pt-128-2e-2 +STEP=3000 + +CUDA_VISIBLE_DEVICES=0 python3 main.py \ + --do_predict \ + --validation_file AdvertiseGen/dev.json \ + --test_file AdvertiseGen/dev.json \ + --overwrite_cache \ + --prompt_column content \ + --response_column summary \ + --model_name_or_path THUDM/chatglm-6b \ + --ptuning_checkpoint ./output/$CHECKPOINT/checkpoint-$STEP \ + --output_dir ./output/$CHECKPOINT \ + --overwrite_output_dir \ + --max_source_length 64 \ + --max_target_length 64 \ + --per_device_eval_batch_size 1 \ + --predict_with_generate \ + --pre_seq_len $PRE_SEQ_LEN \ + --quantization_bit 4 diff --git a/ptuning/evaluate_finetune.sh b/ptuning/evaluate_finetune.sh new file mode 100644 index 0000000000000000000000000000000000000000..e275c3cbbec9ee65ad5e4a958a0ea52c248964c4 --- /dev/null +++ b/ptuning/evaluate_finetune.sh @@ -0,0 +1,18 @@ +CHECKPOINT=adgen-chatglm-6b-ft-1e-4 +STEP=3000 + +CUDA_VISIBLE_DEVICES=0 python3 main.py \ + --do_predict \ + --validation_file AdvertiseGen/dev.json \ + --test_file AdvertiseGen/dev.json \ + --overwrite_cache \ + --prompt_column content \ + --response_column summary \ + --model_name_or_path ./output/$CHECKPOINT/checkpoint-$STEP \ + --output_dir ./output/$CHECKPOINT \ + --overwrite_output_dir \ + --max_source_length 256 \ + --max_target_length 256 \ + --per_device_eval_batch_size 1 \ + --predict_with_generate \ + --fp16_full_eval diff --git a/ptuning/main.py b/ptuning/main.py new file mode 100644 index 
0000000000000000000000000000000000000000..17e18b5bed35c37b0419eba5f475bdc3b1d9bdbf --- /dev/null +++ b/ptuning/main.py @@ -0,0 +1,430 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fine-tuning the library models for sequence to sequence. +""" +# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments. + +import logging +import os +import sys +import json + +import numpy as np +from datasets import load_dataset +import jieba +from rouge_chinese import Rouge +from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction +import torch + +import transformers +from transformers import ( + AutoConfig, + AutoModel, + AutoTokenizer, + DataCollatorForSeq2Seq, + HfArgumentParser, + Seq2SeqTrainingArguments, + set_seed, +) +from trainer_seq2seq import Seq2SeqTrainer + +from arguments import ModelArguments, DataTrainingArguments + +logger = logging.getLogger(__name__) + +def main(): + + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + + if training_args.should_log: + # The default of training_args.log_level is passive, so we set log level at info here to have that default. + transformers.utils.logging.set_verbosity_info() + + log_level = training_args.get_process_log_level() + logger.setLevel(log_level) + # datasets.utils.logging.set_verbosity(log_level) + transformers.utils.logging.set_verbosity(log_level) + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + + # Log on each process the small summary: + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" + ) + logger.info(f"Training/evaluation parameters {training_args}") + + # Set seed before initializing model. 
+ set_seed(training_args.seed) + + # Load dataset + data_files = {} + if data_args.train_file is not None: + data_files["train"] = data_args.train_file + extension = data_args.train_file.split(".")[-1] + if data_args.validation_file is not None: + data_files["validation"] = data_args.validation_file + extension = data_args.validation_file.split(".")[-1] + if data_args.test_file is not None: + data_files["test"] = data_args.test_file + extension = data_args.test_file.split(".")[-1] + + raw_datasets = load_dataset( + extension, + data_files=data_files, + cache_dir=model_args.cache_dir, + use_auth_token=True if model_args.use_auth_token else None, + ) + + # Load pretrained model and tokenizer + config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True) + config.pre_seq_len = model_args.pre_seq_len + config.prefix_projection = model_args.prefix_projection + + tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True) + + if model_args.ptuning_checkpoint is not None: + # Evaluation + # Loading extra state dict of prefix encoder + model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True) + prefix_state_dict = torch.load(os.path.join(model_args.ptuning_checkpoint, "pytorch_model.bin")) + new_prefix_state_dict = {} + for k, v in prefix_state_dict.items(): + if k.startswith("transformer.prefix_encoder."): + new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v + model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict) + else: + model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True) + + if model_args.quantization_bit is not None: + print(f"Quantized to {model_args.quantization_bit} bit") + model = model.quantize(model_args.quantization_bit) + if model_args.pre_seq_len is not None: + # P-tuning v2 + model = model.half() + model.transformer.prefix_encoder.float() + else: + # Finetune + model = model.float() + + prefix = data_args.source_prefix if data_args.source_prefix is not None else "" + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + if training_args.do_train: + column_names = raw_datasets["train"].column_names + elif training_args.do_eval: + column_names = raw_datasets["validation"].column_names + elif training_args.do_predict: + column_names = raw_datasets["test"].column_names + else: + logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.") + return + + # Get the column names for input/target. + prompt_column = data_args.prompt_column + response_column = data_args.response_column + history_column = data_args.history_column + + # Temporarily set max_target_length for training. 
+ max_target_length = data_args.max_target_length + + def preprocess_function_eval(examples): + inputs, targets = [], [] + for i in range(len(examples[prompt_column])): + if examples[prompt_column][i] and examples[response_column][i]: + query = examples[prompt_column][i] + if history_column is None or len(examples[history_column][i]) == 0: + prompt = query + else: + prompt = "" + history = examples[history_column][i] + for turn_idx, (old_query, response) in enumerate(history): + prompt += "[Round {}]\n问:{}\n答:{}\n".format(turn_idx, old_query, response) + prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) + inputs.append(prompt) + targets.append(examples[response_column][i]) + + inputs = [prefix + inp for inp in inputs] + model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, truncation=True, padding=True) + labels = tokenizer(text_target=targets, max_length=max_target_length, truncation=True) + + if data_args.ignore_pad_token_for_loss: + labels["input_ids"] = [ + [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] + ] + model_inputs["labels"] = labels["input_ids"] + + return model_inputs + + def preprocess_function_train(examples): + max_seq_length = data_args.max_source_length + data_args.max_target_length + + model_inputs = { + "input_ids": [], + "labels": [], + } + for i in range(len(examples[prompt_column])): + if examples[prompt_column][i] and examples[response_column][i]: + query, answer = examples[prompt_column][i], examples[response_column][i] + + if history_column is None: + prompt = query + else: + prompt = "" + history = examples[history_column][i] + for turn_idx, (old_query, response) in enumerate(history): + prompt += "[Round {}]\n问:{}\n答:{}\n".format(turn_idx, old_query, response) + prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) + + prompt = prefix + prompt + a_ids = tokenizer.encode(text=prompt, add_special_tokens=False) + b_ids = tokenizer.encode(text=answer, add_special_tokens=False) + + if len(a_ids) > data_args.max_source_length - 1: + a_ids = a_ids[: data_args.max_source_length - 1] + + if len(b_ids) > data_args.max_target_length - 2: + b_ids = b_ids[: data_args.max_target_length - 2] + + input_ids = tokenizer.build_inputs_with_special_tokens(a_ids, b_ids) + + context_length = input_ids.index(tokenizer.bos_token_id) + mask_position = context_length - 1 + labels = [-100] * context_length + input_ids[mask_position+1:] + + pad_len = max_seq_length - len(input_ids) + input_ids = input_ids + [tokenizer.pad_token_id] * pad_len + labels = labels + [tokenizer.pad_token_id] * pad_len + if data_args.ignore_pad_token_for_loss: + labels = [(l if l != tokenizer.pad_token_id else -100) for l in labels] + + model_inputs["input_ids"].append(input_ids) + model_inputs["labels"].append(labels) + + return model_inputs + + def print_dataset_example(example): + print("input_ids",example["input_ids"]) + print("inputs", tokenizer.decode(example["input_ids"])) + print("label_ids", example["labels"]) + print("labels", tokenizer.decode(example["labels"])) + + if training_args.do_train: + if "train" not in raw_datasets: + raise ValueError("--do_train requires a train dataset") + train_dataset = raw_datasets["train"] + if data_args.max_train_samples is not None: + max_train_samples = min(len(train_dataset), data_args.max_train_samples) + train_dataset = train_dataset.select(range(max_train_samples)) + with training_args.main_process_first(desc="train dataset map pre-processing"): + train_dataset = 
train_dataset.map( + preprocess_function_train, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + desc="Running tokenizer on train dataset", + ) + print_dataset_example(train_dataset[0]) + + if training_args.do_eval: + max_target_length = data_args.val_max_target_length + if "validation" not in raw_datasets: + raise ValueError("--do_eval requires a validation dataset") + eval_dataset = raw_datasets["validation"] + if data_args.max_eval_samples is not None: + max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) + eval_dataset = eval_dataset.select(range(max_eval_samples)) + with training_args.main_process_first(desc="validation dataset map pre-processing"): + eval_dataset = eval_dataset.map( + preprocess_function_eval, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + desc="Running tokenizer on validation dataset", + ) + print_dataset_example(eval_dataset[0]) + + if training_args.do_predict: + max_target_length = data_args.val_max_target_length + if "test" not in raw_datasets: + raise ValueError("--do_predict requires a test dataset") + predict_dataset = raw_datasets["test"] + if data_args.max_predict_samples is not None: + max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples) + predict_dataset = predict_dataset.select(range(max_predict_samples)) + with training_args.main_process_first(desc="prediction dataset map pre-processing"): + predict_dataset = predict_dataset.map( + preprocess_function_eval, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + desc="Running tokenizer on prediction dataset", + ) + print_dataset_example(predict_dataset[0]) + + # Data collator + label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id + data_collator = DataCollatorForSeq2Seq( + tokenizer, + model=model, + label_pad_token_id=label_pad_token_id, + pad_to_multiple_of=None, + padding=False + ) + + # Metric + def compute_metrics(eval_preds): + preds, labels = eval_preds + if isinstance(preds, tuple): + preds = preds[0] + decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) + if data_args.ignore_pad_token_for_loss: + # Replace -100 in the labels as we can't decode them. 
+ labels = np.where(labels != -100, labels, tokenizer.pad_token_id) + decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) + + score_dict = { + "rouge-1": [], + "rouge-2": [], + "rouge-l": [], + "bleu-4": [] + } + for pred, label in zip(decoded_preds, decoded_labels): + hypothesis = list(jieba.cut(pred)) + reference = list(jieba.cut(label)) + rouge = Rouge() + scores = rouge.get_scores(' '.join(hypothesis) , ' '.join(reference)) + result = scores[0] + + for k, v in result.items(): + score_dict[k].append(round(v["f"] * 100, 4)) + bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3) + score_dict["bleu-4"].append(round(bleu_score * 100, 4)) + + for k, v in score_dict.items(): + score_dict[k] = float(np.mean(v)) + return score_dict + + # Override the decoding parameters of Seq2SeqTrainer + training_args.generation_max_length = ( + training_args.generation_max_length + if training_args.generation_max_length is not None + else data_args.val_max_target_length + ) + training_args.generation_num_beams = ( + data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams + ) + # Initialize our Trainer + trainer = Seq2SeqTrainer( + model=model, + args=training_args, + train_dataset=train_dataset if training_args.do_train else None, + eval_dataset=eval_dataset if training_args.do_eval else None, + tokenizer=tokenizer, + data_collator=data_collator, + compute_metrics=compute_metrics if training_args.predict_with_generate else None, + save_prefixencoder=model_args.pre_seq_len is not None + ) + + # Training + if training_args.do_train: + checkpoint = None + if training_args.resume_from_checkpoint is not None: + checkpoint = training_args.resume_from_checkpoint + # elif last_checkpoint is not None: + # checkpoint = last_checkpoint + model.gradient_checkpointing_enable() + model.enable_input_require_grads() + train_result = trainer.train(resume_from_checkpoint=checkpoint) + # trainer.save_model() # Saves the tokenizer too for easy upload + + metrics = train_result.metrics + max_train_samples = ( + data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) + ) + metrics["train_samples"] = min(max_train_samples, len(train_dataset)) + + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() + + # Evaluation + results = {} + max_seq_length = data_args.max_source_length + data_args.max_target_length + 1 + if training_args.do_eval: + logger.info("*** Evaluate ***") + metrics = trainer.evaluate(metric_key_prefix="eval", do_sample=True, top_p=0.7, max_length=max_seq_length, temperature=0.95) + max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) + metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) + + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + if training_args.do_predict: + logger.info("*** Predict ***") + predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict", max_length=max_seq_length, do_sample=True, top_p=0.7, temperature=0.95) + metrics = predict_results.metrics + max_predict_samples = ( + data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset) + ) + metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset)) + + trainer.log_metrics("predict", metrics) + trainer.save_metrics("predict", metrics) + + if trainer.is_world_process_zero(): + 
if training_args.predict_with_generate: + predictions = tokenizer.batch_decode( + predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True + ) + predictions = [pred.strip() for pred in predictions] + labels = tokenizer.batch_decode( + predict_results.label_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True + ) + labels = [label.strip() for label in labels] + output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt") + with open(output_prediction_file, "w", encoding="utf-8") as writer: + for p, l in zip(predictions, labels): + res = json.dumps({"labels": l, "predict": p}, ensure_ascii=False) + writer.write(f"{res}\n") + return results + + +def _mp_fn(index): + # For xla_spawn (TPUs) + main() + + +if __name__ == "__main__": + main() diff --git a/ptuning/train.sh b/ptuning/train.sh new file mode 100644 index 0000000000000000000000000000000000000000..efc9a16c3eeabbd75a8870a81869cb8288a0bc05 --- /dev/null +++ b/ptuning/train.sh @@ -0,0 +1,26 @@ +PRE_SEQ_LEN=128 +LR=2e-2 + +CUDA_VISIBLE_DEVICES=0 python3 main.py \ + --do_train \ + --train_file AdvertiseGen/train.json \ + --validation_file AdvertiseGen/dev.json \ + --prompt_column content \ + --response_column summary \ + --overwrite_cache \ + --model_name_or_path THUDM/chatglm-6b \ + --output_dir output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \ + --overwrite_output_dir \ + --max_source_length 64 \ + --max_target_length 64 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --gradient_accumulation_steps 16 \ + --predict_with_generate \ + --max_steps 3000 \ + --logging_steps 10 \ + --save_steps 1000 \ + --learning_rate $LR \ + --pre_seq_len $PRE_SEQ_LEN \ + --quantization_bit 4 + diff --git a/ptuning/train_chat.sh b/ptuning/train_chat.sh new file mode 100644 index 0000000000000000000000000000000000000000..2309f5a0d1fa9f4ad6c2effa7232ede66ae43746 --- /dev/null +++ b/ptuning/train_chat.sh @@ -0,0 +1,27 @@ +PRE_SEQ_LEN=128 +LR=1e-2 + +CUDA_VISIBLE_DEVICES=0 python3 main.py \ + --do_train \ + --train_file $CHAT_TRAIN_DATA \ + --validation_file $CHAT_VAL_DATA \ + --prompt_column prompt \ + --response_column response \ + --history_column history \ + --overwrite_cache \ + --model_name_or_path THUDM/chatglm-6b \ + --output_dir $CHECKPOINT_NAME \ + --overwrite_output_dir \ + --max_source_length 256 \ + --max_target_length 256 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --gradient_accumulation_steps 16 \ + --predict_with_generate \ + --max_steps 3000 \ + --logging_steps 10 \ + --save_steps 1000 \ + --learning_rate $LR \ + --pre_seq_len $PRE_SEQ_LEN \ + --quantization_bit 4 + diff --git a/ptuning/trainer.py b/ptuning/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..63101bc9d3dfb65ff5a444c7c151b8d4d241f2c9 --- /dev/null +++ b/ptuning/trainer.py @@ -0,0 +1,3830 @@ +# coding=utf-8 +# Copyright 2020-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +The Trainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task. +""" + +import contextlib +import functools +import glob +import inspect +import math +import os +import random +import re +import shutil +import sys +import time +import warnings +from collections.abc import Mapping +from distutils.util import strtobool +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union + +from tqdm.auto import tqdm + + +# Integrations must be imported before ML frameworks: +# isort: off +from transformers.integrations import ( + default_hp_search_backend, + get_reporting_integration_callbacks, + hp_params, + is_fairscale_available, + is_optuna_available, + is_ray_tune_available, + is_sigopt_available, + is_wandb_available, + run_hp_search_optuna, + run_hp_search_ray, + run_hp_search_sigopt, + run_hp_search_wandb, +) + +# isort: on + +import numpy as np +import torch +import torch.distributed as dist +from huggingface_hub import Repository, create_repo +from packaging import version +from torch import nn +from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler +from torch.utils.data.distributed import DistributedSampler + +from transformers import __version__ +from transformers.configuration_utils import PretrainedConfig +from transformers.data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator +from transformers.debug_utils import DebugOption, DebugUnderflowOverflow +from transformers.deepspeed import deepspeed_init, is_deepspeed_zero3_enabled +from transformers.dependency_versions_check import dep_version_check +from transformers.modelcard import TrainingSummary +from transformers.modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model +from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES +from transformers.optimization import Adafactor, get_scheduler +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_10, is_torch_less_than_1_11 +from transformers.tokenization_utils_base import PreTrainedTokenizerBase +from transformers.trainer_callback import ( + CallbackHandler, + DefaultFlowCallback, + PrinterCallback, + ProgressCallback, + TrainerCallback, + TrainerControl, + TrainerState, +) +from transformers.trainer_pt_utils import ( + DistributedLengthGroupedSampler, + DistributedSamplerWithLoop, + DistributedTensorGatherer, + IterableDatasetShard, + LabelSmoother, + LengthGroupedSampler, + SequentialDistributedSampler, + ShardSampler, + distributed_broadcast_scalars, + distributed_concat, + find_batch_size, + get_module_class_from_name, + get_parameter_names, + nested_concat, + nested_detach, + nested_numpify, + nested_truncate, + nested_xla_mesh_reduce, + reissue_pt_warnings, +) +from transformers.trainer_utils import ( + PREFIX_CHECKPOINT_DIR, + BestRun, + EvalLoopOutput, + EvalPrediction, + FSDPOption, + HPSearchBackend, + HubStrategy, + IntervalStrategy, + PredictionOutput, + RemoveColumnsCollator, + ShardedDDPOption, + TrainerMemoryTracker, + TrainOutput, + default_compute_objective, + default_hp_space, + denumpify_detensorize, + enable_full_determinism, + find_executable_batch_size, + get_last_checkpoint, + has_length, + number_of_arguments, + seed_worker, + set_seed, + speed_metrics, +) +from transformers.training_args import OptimizerNames, ParallelMode, TrainingArguments +from transformers.utils import ( + CONFIG_NAME, + 
WEIGHTS_INDEX_NAME, + WEIGHTS_NAME, + can_return_loss, + find_labels, + get_full_repo_name, + is_accelerate_available, + is_apex_available, + is_datasets_available, + is_in_notebook, + is_ipex_available, + is_sagemaker_dp_enabled, + is_sagemaker_mp_enabled, + is_torch_compile_available, + is_torch_neuroncore_available, + is_torch_tpu_available, + logging, +) +from transformers.utils.generic import ContextManagers + + +_is_native_cpu_amp_available = is_torch_greater_or_equal_than_1_10 + +DEFAULT_CALLBACKS = [DefaultFlowCallback] +DEFAULT_PROGRESS_CALLBACK = ProgressCallback + +if is_in_notebook(): + from transformers.utils.notebook import NotebookProgressCallback + + DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback + +if is_apex_available(): + from apex import amp + +if is_datasets_available(): + import datasets + +if is_torch_tpu_available(check_device=False): + import torch_xla.core.xla_model as xm + import torch_xla.debug.metrics as met + import torch_xla.distributed.parallel_loader as pl + +if is_fairscale_available(): + dep_version_check("fairscale") + import fairscale + from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP + from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP + from fairscale.nn.wrap import auto_wrap + from fairscale.optim import OSS + from fairscale.optim.grad_scaler import ShardedGradScaler + + +if is_sagemaker_mp_enabled(): + import smdistributed.modelparallel.torch as smp + from smdistributed.modelparallel import __version__ as SMP_VERSION + + IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10") + + from transformers.trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat +else: + IS_SAGEMAKER_MP_POST_1_10 = False + + +skip_first_batches = None +if is_accelerate_available(): + from accelerate import __version__ as accelerate_version + + if version.parse(accelerate_version) >= version.parse("0.16"): + from accelerate import skip_first_batches + + +if TYPE_CHECKING: + import optuna + +logger = logging.get_logger(__name__) + + +# Name of the files used for checkpointing +TRAINING_ARGS_NAME = "training_args.bin" +TRAINER_STATE_NAME = "trainer_state.json" +OPTIMIZER_NAME = "optimizer.pt" +SCHEDULER_NAME = "scheduler.pt" +SCALER_NAME = "scaler.pt" + + +class Trainer: + """ + Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers. + + Args: + model ([`PreTrainedModel`] or `torch.nn.Module`, *optional*): + The model to train, evaluate or use for predictions. If not provided, a `model_init` must be passed. + + <Tip> + + [`Trainer`] is optimized to work with the [`PreTrainedModel`] provided by the library. You can still use + your own models defined as `torch.nn.Module` as long as they work the same way as the 🤗 Transformers + models. + + </Tip> + + args ([`TrainingArguments`], *optional*): + The arguments to tweak for training. Will default to a basic instance of [`TrainingArguments`] with the + `output_dir` set to a directory named *tmp_trainer* in the current directory if not provided. + data_collator (`DataCollator`, *optional*): + The function to use to form a batch from a list of elements of `train_dataset` or `eval_dataset`. Will + default to [`default_data_collator`] if no `tokenizer` is provided, an instance of + [`DataCollatorWithPadding`] otherwise. + train_dataset (`torch.utils.data.Dataset` or `torch.utils.data.IterableDataset`, *optional*): + The dataset to use for training. 
If it is a [`~datasets.Dataset`], columns not accepted by the + `model.forward()` method are automatically removed. + + Note that if it's a `torch.utils.data.IterableDataset` with some randomization and you are training in a + distributed fashion, your iterable dataset should either use a internal attribute `generator` that is a + `torch.Generator` for the randomization that must be identical on all processes (and the Trainer will + manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally + sets the seed of the RNGs used. + eval_dataset (Union[`torch.utils.data.Dataset`, Dict[str, `torch.utils.data.Dataset`]), *optional*): + The dataset to use for evaluation. If it is a [`~datasets.Dataset`], columns not accepted by the + `model.forward()` method are automatically removed. If it is a dictionary, it will evaluate on each + dataset prepending the dictionary key to the metric name. + tokenizer ([`PreTrainedTokenizerBase`], *optional*): + The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs to the + maximum length when batching inputs, and it will be saved along the model to make it easier to rerun an + interrupted training or reuse the fine-tuned model. + model_init (`Callable[[], PreTrainedModel]`, *optional*): + A function that instantiates the model to be used. If provided, each call to [`~Trainer.train`] will start + from a new instance of the model as given by this function. + + The function may have zero argument, or a single one containing the optuna/Ray Tune/SigOpt trial object, to + be able to choose different architectures according to hyper parameters (such as layer count, sizes of + inner layers, dropout probabilities etc). + compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*): + The function that will be used to compute metrics at evaluation. Must take a [`EvalPrediction`] and return + a dictionary string to metric values. + callbacks (List of [`TrainerCallback`], *optional*): + A list of callbacks to customize the training loop. Will add those to the list of default callbacks + detailed in [here](callback). + + If you want to remove one of the default callbacks used, use the [`Trainer.remove_callback`] method. + optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*): A tuple + containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your model + and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*): + A function that preprocess the logits right before caching them at each evaluation step. Must take two + tensors, the logits and the labels, and return the logits once processed as desired. The modifications made + by this function will be reflected in the predictions received by `compute_metrics`. + + Note that the labels (second parameter) will be `None` if the dataset does not have them. + + Important attributes: + + - **model** -- Always points to the core model. If using a transformers model, it will be a [`PreTrainedModel`] + subclass. + - **model_wrapped** -- Always points to the most external model in case one or more other modules wrap the + original model. This is the model that should be used for the forward pass. For example, under `DeepSpeed`, + the inner model is wrapped in `DeepSpeed` and then again in `torch.nn.DistributedDataParallel`. 
If the inner + model hasn't been wrapped, then `self.model_wrapped` is the same as `self.model`. + - **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from + data parallelism, this means some of the model layers are split on different GPUs). + - **place_model_on_device** -- Whether or not to automatically place the model on the device - it will be set + to `False` if model parallel or deepspeed is used, or if the default + `TrainingArguments.place_model_on_device` is overridden to return `False` . + - **is_in_train** -- Whether or not a model is currently running `train` (e.g. when `evaluate` is called while + in `train`) + + """ + + from transformers.trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state + + def __init__( + self, + model: Union[PreTrainedModel, nn.Module] = None, + args: TrainingArguments = None, + data_collator: Optional[DataCollator] = None, + train_dataset: Optional[Dataset] = None, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, + tokenizer: Optional[PreTrainedTokenizerBase] = None, + model_init: Optional[Callable[[], PreTrainedModel]] = None, + compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, + callbacks: Optional[List[TrainerCallback]] = None, + optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + save_prefixencoder: bool = False, + ): + self.save_prefixencoder = save_prefixencoder + if args is None: + output_dir = "tmp_trainer" + logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.") + args = TrainingArguments(output_dir=output_dir) + self.args = args + # Seed must be set before instantiating the model when using model + enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed) + self.hp_name = None + self.deepspeed = None + self.is_in_train = False + + # memory metrics - must set up as early as possible + self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics) + self._memory_tracker.start() + + # set the correct log level depending on the node + log_level = args.get_process_log_level() + logging.set_verbosity(log_level) + + # force device and distributed setup init explicitly + args._setup_devices + + if model is None: + if model_init is not None: + self.model_init = model_init + model = self.call_model_init() + else: + raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument") + else: + if model_init is not None: + warnings.warn( + "`Trainer` requires either a `model` or `model_init` argument, but not both. `model_init` will" + " overwrite your model when calling the `train` method. This will become a fatal error in the next" + " release.", + FutureWarning, + ) + self.model_init = model_init + + if model.__class__.__name__ in MODEL_MAPPING_NAMES: + raise ValueError( + f"The model you have picked ({model.__class__.__name__}) cannot be used as is for training: it only " + "computes hidden states and does not accept any labels. You should choose a model with a head " + "suitable for your task like any of the `AutoModelForXxx` listed at " + "https://huggingface.co/docs/transformers/model_doc/auto." 
+ ) + + if hasattr(model, "is_parallelizable") and model.is_parallelizable and model.model_parallel: + self.is_model_parallel = True + else: + self.is_model_parallel = False + + # At this stage the model is already loaded + if getattr(model, "is_loaded_in_8bit", False): + if getattr(model, "_is_int8_training_enabled", False): + logger.info( + "The model is loaded in 8-bit precision. To train this model you need to add additional modules" + " inside the model such as adapters using `peft` library and freeze the model weights. Please" + " check " + " the examples in https://github.com/huggingface/peft for more details." + ) + else: + raise ValueError( + "The model you want to train is loaded in 8-bit precision. if you want to fine-tune an 8-bit" + " model, please make sure that you have installed `bitsandbytes>=0.37.0`. " + ) + + # Setup Sharded DDP training + self.sharded_ddp = None + if len(args.sharded_ddp) > 0: + if args.deepspeed: + raise ValueError( + "Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags." + ) + if len(args.fsdp) > 0: + raise ValueError( + "Using --sharded_ddp xxx together with --fsdp is not possible, deactivate one of those flags." + ) + + if args.local_rank == -1: + raise ValueError("Using sharded DDP only works in distributed training.") + elif not is_fairscale_available(): + raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.") + elif ShardedDDPOption.SIMPLE not in args.sharded_ddp and FullyShardedDDP is None: + raise ImportError( + "Sharded DDP in a mode other than simple training requires fairscale version >= 0.3, found " + f"{fairscale.__version__}. Upgrade your fairscale library: `pip install --upgrade fairscale`." + ) + elif ShardedDDPOption.SIMPLE in args.sharded_ddp: + self.sharded_ddp = ShardedDDPOption.SIMPLE + elif ShardedDDPOption.ZERO_DP_2 in args.sharded_ddp: + self.sharded_ddp = ShardedDDPOption.ZERO_DP_2 + elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp: + self.sharded_ddp = ShardedDDPOption.ZERO_DP_3 + + self.fsdp = None + if len(args.fsdp) > 0: + if args.deepspeed: + raise ValueError( + "Using --fsdp xxx together with --deepspeed is not possible, deactivate one of those flags." + ) + if not args.fsdp_config["xla"] and args.local_rank == -1: + raise ValueError("Using fsdp only works in distributed training.") + + # dep_version_check("torch>=1.12.0") + # Would have to update setup.py with torch>=1.12.0 + # which isn't ideally given that it will force people not using FSDP to also use torch>=1.12.0 + # below is the current alternative. 
+ if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.12.0"): + raise ValueError("FSDP requires PyTorch >= 1.12.0") + + from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch, ShardingStrategy + + if FSDPOption.FULL_SHARD in args.fsdp: + self.fsdp = ShardingStrategy.FULL_SHARD + elif FSDPOption.SHARD_GRAD_OP in args.fsdp: + self.fsdp = ShardingStrategy.SHARD_GRAD_OP + elif FSDPOption.NO_SHARD in args.fsdp: + self.fsdp = ShardingStrategy.NO_SHARD + + self.backward_prefetch = BackwardPrefetch.BACKWARD_PRE + if "backward_prefetch" in self.args.fsdp_config and "backward_pos" not in self.backward_prefetch: + self.backward_prefetch = BackwardPrefetch.BACKWARD_POST + + self.forword_prefetch = False + if self.args.fsdp_config.get("forword_prefect", False): + self.forword_prefetch = True + + self.limit_all_gathers = False + if self.args.fsdp_config.get("limit_all_gathers", False): + self.limit_all_gathers = True + + # one place to sort out whether to place the model on device or not + # postpone switching model to cuda when: + # 1. MP - since we are trying to fit a much bigger than 1 gpu model + # 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway, + # and we only use deepspeed for training at the moment + # 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first + # 4. Sharded DDP - same as MP + # 5. FSDP - same as MP + self.place_model_on_device = args.place_model_on_device + if ( + self.is_model_parallel + or args.deepspeed + or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train) + or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3]) + or (self.fsdp is not None) + ): + self.place_model_on_device = False + + default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer) + self.data_collator = data_collator if data_collator is not None else default_collator + self.train_dataset = train_dataset + self.eval_dataset = eval_dataset + self.tokenizer = tokenizer + + if self.place_model_on_device and not getattr(model, "is_loaded_in_8bit", False): + self._move_model_to_device(model, args.device) + + # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs + if self.is_model_parallel: + self.args._n_gpu = 1 + + # later use `self.model is self.model_wrapped` to check if it's wrapped or not + self.model_wrapped = model + self.model = model + + self.compute_metrics = compute_metrics + self.preprocess_logits_for_metrics = preprocess_logits_for_metrics + self.optimizer, self.lr_scheduler = optimizers + if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None): + raise RuntimeError( + "Passing a `model_init` is incompatible with providing the `optimizers` argument. " + "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method." + ) + if is_torch_tpu_available() and self.optimizer is not None: + for param in self.model.parameters(): + model_device = param.device + break + for param_group in self.optimizer.param_groups: + if len(param_group["params"]) > 0: + optimizer_device = param_group["params"][0].device + break + if model_device != optimizer_device: + raise ValueError( + "The model and the optimizer parameters are not on the same device, which probably means you" + " created an optimizer around your model **before** putting on the device and passing it to the" + " `Trainer`. 
Make sure the lines `import torch_xla.core.xla_model as xm` and" + " `model.to(xm.xla_device())` is performed before the optimizer creation in your script." + ) + if ((self.sharded_ddp is not None) or args.deepspeed or (self.fsdp is not None)) and ( + self.optimizer is not None or self.lr_scheduler is not None + ): + raise RuntimeError( + "Passing `optimizers` is not allowed if Fairscale, Deepspeed or PyTorch FSDP is enabled." + "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method." + ) + default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to) + callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks + self.callback_handler = CallbackHandler( + callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler + ) + self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK) + + # Will be set to True by `self._setup_loggers()` on first call to `self.log()`. + self._loggers_initialized = False + + # Create clone of distant repo and output directory if needed + if self.args.push_to_hub: + self.init_git_repo(at_init=True) + # In case of pull, we need to make sure every process has the latest. + if is_torch_tpu_available(): + xm.rendezvous("init git repo") + elif args.local_rank != -1: + dist.barrier() + + if self.args.should_save: + os.makedirs(self.args.output_dir, exist_ok=True) + + if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)): + raise ValueError("The `data_collator` should be a simple callable (function, class with `__call__`).") + + if args.max_steps > 0: + logger.info("max_steps is given, it will override any value given in num_train_epochs") + + if train_dataset is not None and not has_length(train_dataset) and args.max_steps <= 0: + raise ValueError("train_dataset does not implement __len__, max_steps has to be specified") + + if ( + train_dataset is not None + and isinstance(train_dataset, torch.utils.data.IterableDataset) + and args.group_by_length + ): + raise ValueError("the `--group_by_length` option is only available for `Dataset`, not `IterableDataset") + + self._signature_columns = None + + # Mixed precision setup + self.use_apex = False + self.use_cuda_amp = False + self.use_cpu_amp = False + + # Mixed precision setup for SageMaker Model Parallel + if is_sagemaker_mp_enabled(): + # BF16 + model parallelism in SageMaker: currently not supported, raise an error + if args.bf16: + raise ValueError("SageMaker Model Parallelism does not support BF16 yet. Please use FP16 instead ") + + if IS_SAGEMAKER_MP_POST_1_10: + # When there's mismatch between SMP config and trainer argument, use SMP config as truth + if args.fp16 != smp.state.cfg.fp16: + logger.warning( + f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}," + f"but FP16 provided in trainer argument is {args.fp16}," + f"setting to {smp.state.cfg.fp16}" + ) + args.fp16 = smp.state.cfg.fp16 + else: + # smp < 1.10 does not support fp16 in trainer. + if hasattr(smp.state.cfg, "fp16"): + logger.warning( + f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, " + "but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer." 
+ ) + + if args.fp16 or args.bf16: + if args.half_precision_backend == "auto": + if args.device == torch.device("cpu"): + if args.fp16: + raise ValueError("Tried to use `fp16` but it is not supported on cpu") + elif _is_native_cpu_amp_available: + args.half_precision_backend = "cpu_amp" + else: + raise ValueError("Tried to use cpu amp but native cpu amp is not available") + else: + args.half_precision_backend = "cuda_amp" + + logger.info(f"Using {args.half_precision_backend} half precision backend") + + self.do_grad_scaling = False + if (args.fp16 or args.bf16) and not (args.deepspeed or is_sagemaker_mp_enabled() or is_torch_tpu_available()): + # deepspeed and SageMaker Model Parallel manage their own half precision + if args.half_precision_backend == "cuda_amp": + self.use_cuda_amp = True + self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16 + # bf16 does not need grad scaling + self.do_grad_scaling = self.amp_dtype == torch.float16 + if self.do_grad_scaling: + if self.sharded_ddp is not None: + self.scaler = ShardedGradScaler() + elif self.fsdp is not None: + from torch.distributed.fsdp.sharded_grad_scaler import ( + ShardedGradScaler as FSDPShardedGradScaler, + ) + + self.scaler = FSDPShardedGradScaler() + elif is_torch_tpu_available(): + from torch_xla.amp import GradScaler + + self.scaler = GradScaler() + else: + self.scaler = torch.cuda.amp.GradScaler() + elif args.half_precision_backend == "cpu_amp": + self.use_cpu_amp = True + self.amp_dtype = torch.bfloat16 + else: + if not is_apex_available(): + raise ImportError( + "Using FP16 with APEX but APEX is not installed, please refer to" + " https://www.github.com/nvidia/apex." + ) + self.use_apex = True + + # FP16 + model parallelism in SageMaker: gradient clipping does not work for now so we raise a helpful error. + if ( + is_sagemaker_mp_enabled() + and self.use_cuda_amp + and args.max_grad_norm is not None + and args.max_grad_norm > 0 + ): + raise ValueError( + "SageMaker Model Parallelism in mixed precision mode does not support gradient clipping yet. Pass " + "along 'max_grad_norm': 0 in your hyperparameters." + ) + + # Label smoothing + if self.args.label_smoothing_factor != 0: + self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor) + else: + self.label_smoother = None + + self.state = TrainerState( + is_local_process_zero=self.is_local_process_zero(), + is_world_process_zero=self.is_world_process_zero(), + ) + + self.control = TrainerControl() + # Internal variable to count flos in each process, will be accumulated in `self.state.total_flos` then + # returned to 0 every time flos need to be logged + self.current_flos = 0 + self.hp_search_backend = None + self.use_tune_checkpoints = False + default_label_names = find_labels(self.model.__class__) + self.label_names = default_label_names if self.args.label_names is None else self.args.label_names + self.can_return_loss = can_return_loss(self.model.__class__) + self.control = self.callback_handler.on_init_end(self.args, self.state, self.control) + + # Internal variables to keep track of the original batch size + self._train_batch_size = args.train_batch_size + + # very last + self._memory_tracker.stop_and_update_metrics() + + # torch.compile + if args.torch_compile and not is_torch_compile_available(): + raise RuntimeError("Using torch.compile requires PyTorch 2.0 or higher.") + + def add_callback(self, callback): + """ + Add a callback to the current list of [`~transformer.TrainerCallback`]. 
+ + Args: + callback (`type` or [`~transformer.TrainerCallback`]): + A [`~transformer.TrainerCallback`] class or an instance of a [`~transformer.TrainerCallback`]. In the + first case, will instantiate a member of that class. + """ + self.callback_handler.add_callback(callback) + + def pop_callback(self, callback): + """ + Remove a callback from the current list of [`~transformer.TrainerCallback`] and returns it. + + If the callback is not found, returns `None` (and no error is raised). + + Args: + callback (`type` or [`~transformer.TrainerCallback`]): + A [`~transformer.TrainerCallback`] class or an instance of a [`~transformer.TrainerCallback`]. In the + first case, will pop the first member of that class found in the list of callbacks. + + Returns: + [`~transformer.TrainerCallback`]: The callback removed, if found. + """ + return self.callback_handler.pop_callback(callback) + + def remove_callback(self, callback): + """ + Remove a callback from the current list of [`~transformer.TrainerCallback`]. + + Args: + callback (`type` or [`~transformer.TrainerCallback`]): + A [`~transformer.TrainerCallback`] class or an instance of a [`~transformer.TrainerCallback`]. In the + first case, will remove the first member of that class found in the list of callbacks. + """ + self.callback_handler.remove_callback(callback) + + def _move_model_to_device(self, model, device): + model = model.to(device) + # Moving a model to an XLA device disconnects the tied weights, so we have to retie them. + if self.args.parallel_mode == ParallelMode.TPU and hasattr(model, "tie_weights"): + model.tie_weights() + + def _set_signature_columns_if_needed(self): + if self._signature_columns is None: + # Inspect model forward signature to keep only the arguments it accepts. + signature = inspect.signature(self.model.forward) + self._signature_columns = list(signature.parameters.keys()) + # Labels may be named label or label_ids, the default data collator handles that. + self._signature_columns += list(set(["label", "label_ids"] + self.label_names)) + + def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None): + if not self.args.remove_unused_columns: + return dataset + self._set_signature_columns_if_needed() + signature_columns = self._signature_columns + + ignored_columns = list(set(dataset.column_names) - set(signature_columns)) + if len(ignored_columns) > 0: + dset_description = "" if description is None else f"in the {description} set" + logger.info( + f"The following columns {dset_description} don't have a corresponding argument in " + f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}." + f" If {', '.join(ignored_columns)} are not expected by `{self.model.__class__.__name__}.forward`, " + " you can safely ignore this message." 
+ ) + + columns = [k for k in signature_columns if k in dataset.column_names] + + if version.parse(datasets.__version__) < version.parse("1.4.0"): + dataset.set_format( + type=dataset.format["type"], columns=columns, format_kwargs=dataset.format["format_kwargs"] + ) + return dataset + else: + return dataset.remove_columns(ignored_columns) + + def _get_collator_with_removed_columns( + self, data_collator: Callable, description: Optional[str] = None + ) -> Callable: + """Wrap the data collator in a callable removing unused columns.""" + if not self.args.remove_unused_columns: + return data_collator + self._set_signature_columns_if_needed() + signature_columns = self._signature_columns + + remove_columns_collator = RemoveColumnsCollator( + data_collator=data_collator, + signature_columns=signature_columns, + logger=logger, + description=description, + model_name=self.model.__class__.__name__, + ) + return remove_columns_collator + + def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: + if self.train_dataset is None or not has_length(self.train_dataset): + return None + + generator = None + if self.args.world_size <= 1: + generator = torch.Generator() + # for backwards compatibility, we generate a seed here (which is sampled from a generator seeded with + # `args.seed`) if data_seed isn't provided. + # Further on in this method, we default to `args.seed` instead. + if self.args.data_seed is None: + seed = int(torch.empty((), dtype=torch.int64).random_().item()) + else: + seed = self.args.data_seed + generator.manual_seed(seed) + + seed = self.args.data_seed if self.args.data_seed is not None else self.args.seed + + # Build the sampler. + if self.args.group_by_length: + if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset): + lengths = ( + self.train_dataset[self.args.length_column_name] + if self.args.length_column_name in self.train_dataset.column_names + else None + ) + else: + lengths = None + model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None + if self.args.world_size <= 1: + return LengthGroupedSampler( + self.args.train_batch_size * self.args.gradient_accumulation_steps, + dataset=self.train_dataset, + lengths=lengths, + model_input_name=model_input_name, + generator=generator, + ) + else: + return DistributedLengthGroupedSampler( + self.args.train_batch_size * self.args.gradient_accumulation_steps, + dataset=self.train_dataset, + num_replicas=self.args.world_size, + rank=self.args.process_index, + lengths=lengths, + model_input_name=model_input_name, + seed=seed, + ) + + else: + if self.args.world_size <= 1: + return RandomSampler(self.train_dataset, generator=generator) + elif ( + self.args.parallel_mode in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL] + and not self.args.dataloader_drop_last + ): + # Use a loop for TPUs when drop_last is False to have all batches have the same size. + return DistributedSamplerWithLoop( + self.train_dataset, + batch_size=self.args.per_device_train_batch_size, + num_replicas=self.args.world_size, + rank=self.args.process_index, + seed=seed, + ) + else: + return DistributedSampler( + self.train_dataset, + num_replicas=self.args.world_size, + rank=self.args.process_index, + seed=seed, + ) + + def get_train_dataloader(self) -> DataLoader: + """ + Returns the training [`~torch.utils.data.DataLoader`]. + + Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed + training if necessary) otherwise. 
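`_set_signature_columns_if_needed` and `_remove_unused_columns` above hinge on inspecting `model.forward`. A small sketch of that filtering, with a made-up `TinyModel` and column names:

```python
import inspect
import torch.nn as nn

class TinyModel(nn.Module):
    def forward(self, input_ids=None, attention_mask=None, labels=None):
        return input_ids

# Keep only the dataset columns that model.forward actually accepts,
# plus the label column aliases, mirroring the logic above.
signature_columns = list(inspect.signature(TinyModel().forward).parameters.keys())
signature_columns += ["label", "label_ids"]

dataset_columns = ["input_ids", "attention_mask", "labels", "source_file"]
ignored_columns = set(dataset_columns) - set(signature_columns)
print(ignored_columns)  # {'source_file'} would be dropped (and logged) before batching
```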
+ + Subclass and override this method if you want to inject some custom behavior. + """ + if self.train_dataset is None: + raise ValueError("Trainer: training requires a train_dataset.") + + train_dataset = self.train_dataset + data_collator = self.data_collator + if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): + train_dataset = self._remove_unused_columns(train_dataset, description="training") + else: + data_collator = self._get_collator_with_removed_columns(data_collator, description="training") + + if isinstance(train_dataset, torch.utils.data.IterableDataset): + if self.args.world_size > 1: + train_dataset = IterableDatasetShard( + train_dataset, + batch_size=self._train_batch_size, + drop_last=self.args.dataloader_drop_last, + num_processes=self.args.world_size, + process_index=self.args.process_index, + ) + + return DataLoader( + train_dataset, + batch_size=self._train_batch_size, + collate_fn=data_collator, + num_workers=self.args.dataloader_num_workers, + pin_memory=self.args.dataloader_pin_memory, + ) + + train_sampler = self._get_train_sampler() + + return DataLoader( + train_dataset, + batch_size=self._train_batch_size, + sampler=train_sampler, + collate_fn=data_collator, + drop_last=self.args.dataloader_drop_last, + num_workers=self.args.dataloader_num_workers, + pin_memory=self.args.dataloader_pin_memory, + worker_init_fn=seed_worker, + ) + + def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]: + # Deprecated code + if self.args.use_legacy_prediction_loop: + if is_torch_tpu_available(): + return SequentialDistributedSampler( + eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal() + ) + elif is_sagemaker_mp_enabled(): + return SequentialDistributedSampler( + eval_dataset, + num_replicas=smp.dp_size(), + rank=smp.dp_rank(), + batch_size=self.args.per_device_eval_batch_size, + ) + elif self.args.local_rank != -1: + return SequentialDistributedSampler(eval_dataset) + else: + return SequentialSampler(eval_dataset) + + if self.args.world_size <= 1: + return SequentialSampler(eval_dataset) + else: + return ShardSampler( + eval_dataset, + batch_size=self.args.per_device_eval_batch_size, + num_processes=self.args.world_size, + process_index=self.args.process_index, + ) + + def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: + """ + Returns the evaluation [`~torch.utils.data.DataLoader`]. + + Subclass and override this method if you want to inject some custom behavior. + + Args: + eval_dataset (`torch.utils.data.Dataset`, *optional*): + If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted + by the `model.forward()` method are automatically removed. It must implement `__len__`. 
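For the single-process path above, the sampler setup reduces to a seeded `torch.Generator` feeding a `RandomSampler`. A minimal sketch (the seed 42 stands in for `args.data_seed`/`args.seed`):

```python
import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

# A seeded generator makes the shuffle order reproducible across runs,
# which is what the data_seed handling above is for.
dataset = TensorDataset(torch.arange(100).unsqueeze(1))
generator = torch.Generator()
generator.manual_seed(42)

sampler = RandomSampler(dataset, generator=generator)
loader = DataLoader(dataset, batch_size=8, sampler=sampler, drop_last=True)
first_batch = next(iter(loader))
print(first_batch[0].flatten())
```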
+ """ + if eval_dataset is None and self.eval_dataset is None: + raise ValueError("Trainer: evaluation requires an eval_dataset.") + eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset + data_collator = self.data_collator + + if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset): + eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation") + else: + data_collator = self._get_collator_with_removed_columns(data_collator, description="evaluation") + + if isinstance(eval_dataset, torch.utils.data.IterableDataset): + if self.args.world_size > 1: + eval_dataset = IterableDatasetShard( + eval_dataset, + batch_size=self.args.per_device_eval_batch_size, + drop_last=self.args.dataloader_drop_last, + num_processes=self.args.world_size, + process_index=self.args.process_index, + ) + return DataLoader( + eval_dataset, + batch_size=self.args.eval_batch_size, + collate_fn=data_collator, + num_workers=self.args.dataloader_num_workers, + pin_memory=self.args.dataloader_pin_memory, + ) + + eval_sampler = self._get_eval_sampler(eval_dataset) + + return DataLoader( + eval_dataset, + sampler=eval_sampler, + batch_size=self.args.eval_batch_size, + collate_fn=data_collator, + drop_last=self.args.dataloader_drop_last, + num_workers=self.args.dataloader_num_workers, + pin_memory=self.args.dataloader_pin_memory, + ) + + def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader: + """ + Returns the test [`~torch.utils.data.DataLoader`]. + + Subclass and override this method if you want to inject some custom behavior. + + Args: + test_dataset (`torch.utils.data.Dataset`, *optional*): + The test dataset to use. If it is a [`~datasets.Dataset`], columns not accepted by the + `model.forward()` method are automatically removed. It must implement `__len__`. + """ + data_collator = self.data_collator + + if is_datasets_available() and isinstance(test_dataset, datasets.Dataset): + test_dataset = self._remove_unused_columns(test_dataset, description="test") + else: + data_collator = self._get_collator_with_removed_columns(data_collator, description="test") + + if isinstance(test_dataset, torch.utils.data.IterableDataset): + if self.args.world_size > 1: + test_dataset = IterableDatasetShard( + test_dataset, + batch_size=self.args.eval_batch_size, + drop_last=self.args.dataloader_drop_last, + num_processes=self.args.world_size, + process_index=self.args.process_index, + ) + return DataLoader( + test_dataset, + batch_size=self.args.eval_batch_size, + collate_fn=data_collator, + num_workers=self.args.dataloader_num_workers, + pin_memory=self.args.dataloader_pin_memory, + ) + + test_sampler = self._get_eval_sampler(test_dataset) + + # We use the same batch_size as for eval. + return DataLoader( + test_dataset, + sampler=test_sampler, + batch_size=self.args.eval_batch_size, + collate_fn=data_collator, + drop_last=self.args.dataloader_drop_last, + num_workers=self.args.dataloader_num_workers, + pin_memory=self.args.dataloader_pin_memory, + ) + + def create_optimizer_and_scheduler(self, num_training_steps: int): + """ + Setup the optimizer and the learning rate scheduler. + + We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the + Trainer's init through `optimizers`, or subclass and override this method (or `create_optimizer` and/or + `create_scheduler`) in a subclass. 
+ """ + self.create_optimizer() + if IS_SAGEMAKER_MP_POST_1_10 and smp.state.cfg.fp16: + # If smp >= 1.10 and fp16 is enabled, we unwrap the optimizer + optimizer = self.optimizer.optimizer + else: + optimizer = self.optimizer + self.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer) + + def create_optimizer(self): + """ + Setup the optimizer. + + We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the + Trainer's init through `optimizers`, or subclass and override this method in a subclass. + """ + opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model + + if self.optimizer is None: + decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS) + decay_parameters = [name for name in decay_parameters if "bias" not in name] + optimizer_grouped_parameters = [ + { + "params": [ + p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad) + ], + "weight_decay": self.args.weight_decay, + }, + { + "params": [ + p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad) + ], + "weight_decay": 0.0, + }, + ] + + optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) + + if self.sharded_ddp == ShardedDDPOption.SIMPLE: + self.optimizer = OSS( + params=optimizer_grouped_parameters, + optim=optimizer_cls, + **optimizer_kwargs, + ) + else: + self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) + if optimizer_cls.__name__ == "Adam8bit": + import bitsandbytes + + manager = bitsandbytes.optim.GlobalOptimManager.get_instance() + + skipped = 0 + for module in opt_model.modules(): + if isinstance(module, nn.Embedding): + skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values()) + print(f"skipped {module}: {skipped/2**20}M params") + manager.register_module_override(module, "weight", {"optim_bits": 32}) + logger.debug(f"bitsandbytes: will optimize {module} in fp32") + print(f"skipped: {skipped/2**20}M params") + + if is_sagemaker_mp_enabled(): + self.optimizer = smp.DistributedOptimizer(self.optimizer) + + return self.optimizer + + @staticmethod + def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]: + """ + Returns the optimizer class and optimizer parameters based on the training arguments. + + Args: + args (`transformers.training_args.TrainingArguments`): + The training arguments for the training session. 
+ + """ + + # parse args.optim_args + optim_args = {} + if args.optim_args: + for mapping in args.optim_args.replace(" ", "").split(","): + key, value = mapping.split("=") + optim_args[key] = value + + optimizer_kwargs = {"lr": args.learning_rate} + + adam_kwargs = { + "betas": (args.adam_beta1, args.adam_beta2), + "eps": args.adam_epsilon, + } + if args.optim == OptimizerNames.ADAFACTOR: + optimizer_cls = Adafactor + optimizer_kwargs.update({"scale_parameter": False, "relative_step": False}) + elif args.optim == OptimizerNames.ADAMW_HF: + from transformers.optimization import AdamW + + optimizer_cls = AdamW + optimizer_kwargs.update(adam_kwargs) + elif args.optim in [OptimizerNames.ADAMW_TORCH, OptimizerNames.ADAMW_TORCH_FUSED]: + from torch.optim import AdamW + + optimizer_cls = AdamW + optimizer_kwargs.update(adam_kwargs) + if args.optim == OptimizerNames.ADAMW_TORCH_FUSED: + optimizer_kwargs.update({"fused": True}) + elif args.optim == OptimizerNames.ADAMW_TORCH_XLA: + try: + from torch_xla.amp.syncfree import AdamW + + optimizer_cls = AdamW + optimizer_kwargs.update(adam_kwargs) + except ImportError: + raise ValueError("Trainer failed to import syncfree AdamW from torch_xla.") + elif args.optim == OptimizerNames.ADAMW_APEX_FUSED: + try: + from apex.optimizers import FusedAdam + + optimizer_cls = FusedAdam + optimizer_kwargs.update(adam_kwargs) + except ImportError: + raise ValueError("Trainer tried to instantiate apex FusedAdam but apex is not installed!") + elif args.optim == OptimizerNames.ADAMW_BNB: + try: + from bitsandbytes.optim import Adam8bit + + optimizer_cls = Adam8bit + optimizer_kwargs.update(adam_kwargs) + except ImportError: + raise ValueError("Trainer tried to instantiate bnb Adam8bit but bnb is not installed!") + elif args.optim == OptimizerNames.ADAMW_ANYPRECISION: + try: + from torchdistx.optimizers import AnyPrecisionAdamW + + optimizer_cls = AnyPrecisionAdamW + optimizer_kwargs.update(adam_kwargs) + + # TODO Change dtypes back to M=FP32, Var = BF16, Kahan = False once they can be cast together in torchdistx. + optimizer_kwargs.update( + { + "use_kahan_summation": strtobool(optim_args.get("use_kahan_summation", "False")), + "momentum_dtype": getattr(torch, optim_args.get("momentum_dtype", "float32")), + "variance_dtype": getattr(torch, optim_args.get("variance_dtype", "float32")), + "compensation_buffer_dtype": getattr( + torch, optim_args.get("compensation_buffer_dtype", "bfloat16") + ), + } + ) + except ImportError: + raise ValueError("Please install https://github.com/pytorch/torchdistx") + elif args.optim == OptimizerNames.SGD: + optimizer_cls = torch.optim.SGD + elif args.optim == OptimizerNames.ADAGRAD: + optimizer_cls = torch.optim.Adagrad + else: + raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}") + return optimizer_cls, optimizer_kwargs + + def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None): + """ + Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or + passed as an argument. + + Args: + num_training_steps (int): The number of training steps to do. 
+ """ + if self.lr_scheduler is None: + self.lr_scheduler = get_scheduler( + self.args.lr_scheduler_type, + optimizer=self.optimizer if optimizer is None else optimizer, + num_warmup_steps=self.args.get_warmup_steps(num_training_steps), + num_training_steps=num_training_steps, + ) + return self.lr_scheduler + + def num_examples(self, dataloader: DataLoader) -> int: + """ + Helper to get number of samples in a [`~torch.utils.data.DataLoader`] by accessing its dataset. When + dataloader.dataset does not exist or has no length, estimates as best it can + """ + try: + dataset = dataloader.dataset + # Special case for IterableDatasetShard, we need to dig deeper + if isinstance(dataset, IterableDatasetShard): + return len(dataloader.dataset.dataset) + return len(dataloader.dataset) + except (NameError, AttributeError, TypeError): # no dataset or length, estimate by length of dataloader + return len(dataloader) * self.args.per_device_train_batch_size + + def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]): + """HP search setup code""" + self._trial = trial + + if self.hp_search_backend is None or trial is None: + return + if self.hp_search_backend == HPSearchBackend.OPTUNA: + params = self.hp_space(trial) + elif self.hp_search_backend == HPSearchBackend.RAY: + params = trial + params.pop("wandb", None) + elif self.hp_search_backend == HPSearchBackend.SIGOPT: + params = {k: int(v) if isinstance(v, str) else v for k, v in trial.assignments.items()} + elif self.hp_search_backend == HPSearchBackend.WANDB: + params = trial + + for key, value in params.items(): + if not hasattr(self.args, key): + logger.warning( + f"Trying to set {key} in the hyperparameter search but there is no corresponding field in" + " `TrainingArguments`." + ) + continue + old_attr = getattr(self.args, key, None) + # Casting value to the proper type + if old_attr is not None: + value = type(old_attr)(value) + setattr(self.args, key, value) + if self.hp_search_backend == HPSearchBackend.OPTUNA: + logger.info(f"Trial: {trial.params}") + if self.hp_search_backend == HPSearchBackend.SIGOPT: + logger.info(f"SigOpt Assignments: {trial.assignments}") + if self.hp_search_backend == HPSearchBackend.WANDB: + logger.info(f"W&B Sweep parameters: {trial}") + if self.args.deepspeed: + # Rebuild the deepspeed config to reflect the updated training parameters + from transformers.deepspeed import HfTrainerDeepSpeedConfig + + self.args.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.args.deepspeed) + self.args.hf_deepspeed_config.trainer_config_process(self.args) + + def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]): + if self.hp_search_backend is None or trial is None: + return + self.objective = self.compute_objective(metrics.copy()) + if self.hp_search_backend == HPSearchBackend.OPTUNA: + import optuna + + trial.report(self.objective, step) + if trial.should_prune(): + self.callback_handler.on_train_end(self.args, self.state, self.control) + raise optuna.TrialPruned() + elif self.hp_search_backend == HPSearchBackend.RAY: + from ray import tune + + if self.control.should_save: + self._tune_save_checkpoint() + tune.report(objective=self.objective, **metrics) + + def _tune_save_checkpoint(self): + from ray import tune + + if not self.use_tune_checkpoints: + return + with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir: + output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}") + 
self.save_model(output_dir, _internal_call=True) + if self.args.should_save: + self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME)) + torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) + torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) + + def call_model_init(self, trial=None): + model_init_argcount = number_of_arguments(self.model_init) + if model_init_argcount == 0: + model = self.model_init() + elif model_init_argcount == 1: + model = self.model_init(trial) + else: + raise RuntimeError("model_init should have 0 or 1 argument.") + + if model is None: + raise RuntimeError("model_init should not return None.") + + return model + + def torch_jit_model_eval(self, model, dataloader, training=False): + if not training: + if dataloader is None: + logger.warning("failed to use PyTorch jit mode due to current dataloader is none.") + return model + example_batch = next(iter(dataloader)) + example_batch = self._prepare_inputs(example_batch) + try: + jit_model = model.eval() + with ContextManagers([self.autocast_smart_context_manager(cache_enabled=False), torch.no_grad()]): + if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.14.0"): + if isinstance(example_batch, dict): + jit_model = torch.jit.trace(jit_model, example_kwarg_inputs=example_batch, strict=False) + else: + jit_model = torch.jit.trace( + jit_model, + example_kwarg_inputs={key: example_batch[key] for key in example_batch}, + strict=False, + ) + else: + jit_inputs = [] + for key in example_batch: + example_tensor = torch.ones_like(example_batch[key]) + jit_inputs.append(example_tensor) + jit_inputs = tuple(jit_inputs) + jit_model = torch.jit.trace(jit_model, jit_inputs, strict=False) + jit_model = torch.jit.freeze(jit_model) + with torch.no_grad(): + jit_model(**example_batch) + jit_model(**example_batch) + model = jit_model + self.use_cpu_amp = False + self.use_cuda_amp = False + except (RuntimeError, TypeError, ValueError, NameError, IndexError) as e: + logger.warning(f"failed to use PyTorch jit mode due to: {e}.") + + return model + + def ipex_optimize_model(self, model, training=False, dtype=torch.float32): + if not is_ipex_available(): + raise ImportError( + "Using IPEX but IPEX is not installed or IPEX's version does not match current PyTorch, please refer" + " to https://github.com/intel/intel-extension-for-pytorch." + ) + + import intel_extension_for_pytorch as ipex + + if not training: + model.eval() + dtype = torch.bfloat16 if not self.is_in_train and self.args.bf16_full_eval else dtype + # conv_bn_folding is disabled as it fails in symbolic tracing, resulting in ipex warnings + model = ipex.optimize(model, dtype=dtype, level="O1", conv_bn_folding=False, inplace=not self.is_in_train) + else: + if not model.training: + model.train() + model, self.optimizer = ipex.optimize( + model, dtype=dtype, optimizer=self.optimizer, inplace=True, level="O1" + ) + + return model + + def _wrap_model(self, model, training=True, dataloader=None): + if self.args.torch_compile: + model = torch.compile(model, backend=self.args.torch_compile_backend, mode=self.args.torch_compile_mode) + + if self.args.use_ipex: + dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32 + model = self.ipex_optimize_model(model, training, dtype=dtype) + + if is_sagemaker_mp_enabled(): + # Wrapping the base model twice in a DistributedModel will raise an error. 
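The newer-PyTorch branch of `torch_jit_model_eval` above traces the model with keyword example inputs. A sketch assuming PyTorch 2.x (where `torch.jit.trace` accepts `example_kwarg_inputs`) and a made-up `Toy` module:

```python
import torch

class Toy(torch.nn.Module):
    def forward(self, input_ids, attention_mask):
        return (input_ids * attention_mask).float().mean()

model = Toy().eval()
example_batch = {
    "input_ids": torch.ones(2, 4, dtype=torch.long),
    "attention_mask": torch.ones(2, 4, dtype=torch.long),
}
with torch.no_grad():
    # Trace with keyword inputs, strict=False, then freeze, as in the method above.
    jit_model = torch.jit.trace(model, example_kwarg_inputs=example_batch, strict=False)
    jit_model = torch.jit.freeze(jit_model)
    print(jit_model(**example_batch))  # warm-up call
```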
+ if isinstance(self.model_wrapped, smp.model.DistributedModel): + return self.model_wrapped + return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps) + + # already initialized its own DDP and AMP + if self.deepspeed: + return self.deepspeed + + # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again + if unwrap_model(model) is not model: + return model + + # Mixed precision training with apex (torch < 1.6) + if self.use_apex and training: + model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level) + + # Multi-gpu training (should be after apex fp16 initialization) + if self.args.n_gpu > 1: + model = nn.DataParallel(model) + + if self.args.jit_mode_eval: + start_time = time.time() + model = self.torch_jit_model_eval(model, dataloader, training) + self.jit_compilation_time = round(time.time() - start_time, 4) + + # Note: in torch.distributed mode, there's no point in wrapping the model + # inside a DistributedDataParallel as we'll be under `no_grad` anyways. + if not training: + return model + + # Distributed training (should be after apex fp16 initialization) + if self.sharded_ddp is not None: + # Sharded DDP! + if self.sharded_ddp == ShardedDDPOption.SIMPLE: + model = ShardedDDP(model, self.optimizer) + else: + mixed_precision = self.args.fp16 or self.args.bf16 + cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp + zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3 + # XXX: Breaking the self.model convention but I see no way around it for now. + if ShardedDDPOption.AUTO_WRAP in self.args.sharded_ddp: + model = auto_wrap(model) + self.model = model = FullyShardedDDP( + model, + mixed_precision=mixed_precision, + reshard_after_forward=zero_3, + cpu_offload=cpu_offload, + ).to(self.args.device) + # Distributed training using PyTorch FSDP + elif self.fsdp is not None: + if not self.args.fsdp_config["xla"]: + # PyTorch FSDP! 
+ from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision + from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP + from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy + + if FSDPOption.OFFLOAD in self.args.fsdp: + cpu_offload = CPUOffload(offload_params=True) + else: + cpu_offload = CPUOffload(offload_params=False) + + auto_wrap_policy = None + + if FSDPOption.AUTO_WRAP in self.args.fsdp: + if self.args.fsdp_config["fsdp_min_num_params"] > 0: + auto_wrap_policy = functools.partial( + size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"] + ) + elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None: + transformer_cls_to_wrap = set() + for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]: + transformer_cls = get_module_class_from_name(model, layer_class) + if transformer_cls is None: + raise Exception("Could not find the transformer layer class to wrap in the model.") + else: + transformer_cls_to_wrap.add(transformer_cls) + auto_wrap_policy = functools.partial( + transformer_auto_wrap_policy, + # Transformer layer class to wrap + transformer_layer_cls=transformer_cls_to_wrap, + ) + mixed_precision_policy = None + dtype = None + if self.args.fp16: + dtype = torch.float16 + elif self.args.bf16: + dtype = torch.bfloat16 + if dtype is not None: + mixed_precision_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype) + if type(model) != FSDP: + # XXX: Breaking the self.model convention but I see no way around it for now. + self.model = model = FSDP( + model, + sharding_strategy=self.fsdp, + cpu_offload=cpu_offload, + auto_wrap_policy=auto_wrap_policy, + mixed_precision=mixed_precision_policy, + device_id=self.args.device, + backward_prefetch=self.backward_prefetch, + forward_prefetch=self.forword_prefetch, + limit_all_gathers=self.limit_all_gathers, + ) + else: + try: + from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP + from torch_xla.distributed.fsdp import checkpoint_module + from torch_xla.distributed.fsdp.wrap import ( + size_based_auto_wrap_policy, + transformer_auto_wrap_policy, + ) + except ImportError: + raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.") + auto_wrap_policy = None + auto_wrapper_callable = None + if self.args.fsdp_config["fsdp_min_num_params"] > 0: + auto_wrap_policy = functools.partial( + size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"] + ) + elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None: + transformer_cls_to_wrap = set() + for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]: + transformer_cls = get_module_class_from_name(model, layer_class) + if transformer_cls is None: + raise Exception("Could not find the transformer layer class to wrap in the model.") + else: + transformer_cls_to_wrap.add(transformer_cls) + auto_wrap_policy = functools.partial( + transformer_auto_wrap_policy, + # Transformer layer class to wrap + transformer_layer_cls=transformer_cls_to_wrap, + ) + fsdp_kwargs = self.args.xla_fsdp_config + if self.args.fsdp_config["xla_fsdp_grad_ckpt"]: + # Apply gradient checkpointing to auto-wrapped sub-modules if specified + def auto_wrapper_callable(m, *args, **kwargs): + return FSDP(checkpoint_module(m), *args, **kwargs) + + # Wrap the base model 
with an outer FSDP wrapper + self.model = model = FSDP( + model, + auto_wrap_policy=auto_wrap_policy, + auto_wrapper_callable=auto_wrapper_callable, + **fsdp_kwargs, + ) + + # Patch `xm.optimizer_step` should not reduce gradients in this case, + # as FSDP does not need gradient reduction over sharded parameters. + def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}): + loss = optimizer.step(**optimizer_args) + if barrier: + xm.mark_step() + return loss + + xm.optimizer_step = patched_optimizer_step + elif is_sagemaker_dp_enabled(): + model = nn.parallel.DistributedDataParallel( + model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))] + ) + elif self.args.local_rank != -1: + kwargs = {} + if self.args.ddp_find_unused_parameters is not None: + kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters + elif isinstance(model, PreTrainedModel): + # find_unused_parameters breaks checkpointing as per + # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021 + kwargs["find_unused_parameters"] = not model.is_gradient_checkpointing + else: + kwargs["find_unused_parameters"] = True + + if self.args.ddp_bucket_cap_mb is not None: + kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb + if is_torch_neuroncore_available(): + return model + model = nn.parallel.DistributedDataParallel( + model, + device_ids=[self.args.local_rank] if self.args._n_gpu != 0 else None, + output_device=self.args.local_rank if self.args._n_gpu != 0 else None, + **kwargs, + ) + + return model + + def train( + self, + resume_from_checkpoint: Optional[Union[str, bool]] = None, + trial: Union["optuna.Trial", Dict[str, Any]] = None, + ignore_keys_for_eval: Optional[List[str]] = None, + **kwargs, + ): + """ + Main training entry point. + + Args: + resume_from_checkpoint (`str` or `bool`, *optional*): + If a `str`, local path to a saved checkpoint as saved by a previous instance of [`Trainer`]. If a + `bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance + of [`Trainer`]. If present, training will resume from the model/optimizer/scheduler states loaded here. + trial (`optuna.Trial` or `Dict[str, Any]`, *optional*): + The trial run or the hyperparameter dictionary for hyperparameter search. + ignore_keys_for_eval (`List[str]`, *optional*) + A list of keys in the output of your model (if it is a dictionary) that should be ignored when + gathering predictions for evaluation during the training. + kwargs: + Additional keyword arguments used to hide deprecated arguments + """ + if resume_from_checkpoint is False: + resume_from_checkpoint = None + + # memory metrics - must set up as early as possible + self._memory_tracker.start() + + args = self.args + + self.is_in_train = True + + # do_train is not a reliable argument, as it might not be set and .train() still called, so + # the following is a workaround: + if (args.fp16_full_eval or args.bf16_full_eval) and not args.do_train: + self._move_model_to_device(self.model, args.device) + + if "model_path" in kwargs: + resume_from_checkpoint = kwargs.pop("model_path") + warnings.warn( + "`model_path` is deprecated and will be removed in a future version. Use `resume_from_checkpoint` " + "instead.", + FutureWarning, + ) + if len(kwargs) > 0: + raise TypeError(f"train() received got unexpected keyword arguments: {', '.join(list(kwargs.keys()))}.") + # This might change the seed so needs to run first. 
+ self._hp_search_setup(trial) + self._train_batch_size = self.args.train_batch_size + + # Model re-init + model_reloaded = False + if self.model_init is not None: + # Seed must be set before instantiating the model when using model_init. + enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed) + self.model = self.call_model_init(trial) + model_reloaded = True + # Reinitializes optimizer and scheduler + self.optimizer, self.lr_scheduler = None, None + + # Load potential model checkpoint + if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint: + resume_from_checkpoint = get_last_checkpoint(args.output_dir) + if resume_from_checkpoint is None: + raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})") + + if resume_from_checkpoint is not None and not is_sagemaker_mp_enabled() and args.deepspeed is None: + self._load_from_checkpoint(resume_from_checkpoint) + + # If model was re-initialized, put it on the right device and update self.model_wrapped + if model_reloaded: + if self.place_model_on_device: + self._move_model_to_device(self.model, args.device) + self.model_wrapped = self.model + + inner_training_loop = find_executable_batch_size( + self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size + ) + return inner_training_loop( + args=args, + resume_from_checkpoint=resume_from_checkpoint, + trial=trial, + ignore_keys_for_eval=ignore_keys_for_eval, + ) + + def _inner_training_loop( + self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None + ): + self._train_batch_size = batch_size + # Data loader and number of training steps + train_dataloader = self.get_train_dataloader() + + # Setting up training control variables: + # number of training epochs: num_train_epochs + # number of training steps per epoch: num_update_steps_per_epoch + # total number of training steps to execute: max_steps + total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.world_size + + len_dataloader = None + if has_length(train_dataloader): + len_dataloader = len(train_dataloader) + num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps + num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1) + num_examples = self.num_examples(train_dataloader) + if args.max_steps > 0: + max_steps = args.max_steps + num_train_epochs = args.max_steps // num_update_steps_per_epoch + int( + args.max_steps % num_update_steps_per_epoch > 0 + ) + # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's + # the best we can do. + num_train_samples = args.max_steps * total_train_batch_size + else: + max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch) + num_train_epochs = math.ceil(args.num_train_epochs) + num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs + elif args.max_steps > 0: # Rely on max_steps when dataloader does not have a working size + max_steps = args.max_steps + # Setting a very large number of epochs so we go as many times as necessary over the iterator. 
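A worked example of the step arithmetic above, with illustrative numbers:

```python
import math

# 1000 batches per epoch, gradient accumulation of 4, 3 training epochs,
# single process with a per-device batch size of 8 (all values illustrative).
len_dataloader = 1000
gradient_accumulation_steps = 4
num_train_epochs = 3.0
world_size = 1
per_device_train_batch_size = 8

num_update_steps_per_epoch = max(len_dataloader // gradient_accumulation_steps, 1)  # 250
max_steps = math.ceil(num_train_epochs * num_update_steps_per_epoch)                # 750
total_train_batch_size = (
    per_device_train_batch_size * gradient_accumulation_steps * world_size          # 32
)
print(num_update_steps_per_epoch, max_steps, total_train_batch_size)
```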
+ num_train_epochs = sys.maxsize + num_update_steps_per_epoch = max_steps + num_examples = total_train_batch_size * args.max_steps + num_train_samples = args.max_steps * total_train_batch_size + else: + raise ValueError( + "args.max_steps must be set to a positive value if dataloader does not have a length, was" + f" {args.max_steps}" + ) + + if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug: + if self.args.n_gpu > 1: + # nn.DataParallel(model) replicates the model, creating new variables and module + # references registered here no longer work on other gpus, breaking the module + raise ValueError( + "Currently --debug underflow_overflow is not supported under DP. Please use DDP" + " (torch.distributed.launch)." + ) + else: + debug_overflow = DebugUnderflowOverflow(self.model) # noqa + + delay_optimizer_creation = ( + self.sharded_ddp is not None + and self.sharded_ddp != ShardedDDPOption.SIMPLE + or is_sagemaker_mp_enabled() + or self.fsdp is not None + ) + if args.deepspeed: + deepspeed_engine, optimizer, lr_scheduler = deepspeed_init( + self, num_training_steps=max_steps, resume_from_checkpoint=resume_from_checkpoint + ) + self.model = deepspeed_engine.module + self.model_wrapped = deepspeed_engine + self.deepspeed = deepspeed_engine + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + elif not delay_optimizer_creation: + self.create_optimizer_and_scheduler(num_training_steps=max_steps) + + self.state = TrainerState() + self.state.is_hyper_param_search = trial is not None + + # Activate gradient checkpointing if needed + if args.gradient_checkpointing: + self.model.gradient_checkpointing_enable() + + model = self._wrap_model(self.model_wrapped) + + if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None: + self._load_from_checkpoint(resume_from_checkpoint, model) + + # for the rest of this function `model` is the outside model, whether it was wrapped or not + if model is not self.model: + self.model_wrapped = model + + if delay_optimizer_creation: + self.create_optimizer_and_scheduler(num_training_steps=max_steps) + + # Check if saved optimizer or scheduler states exist + self._load_optimizer_and_scheduler(resume_from_checkpoint) + + # important: at this point: + # self.model is the Transformers Model + # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc. + + # Train! + logger.info("***** Running training *****") + logger.info(f" Num examples = {num_examples}") + logger.info(f" Num Epochs = {num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") + logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_train_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {max_steps}") + logger.info( + f" Number of trainable parameters = {sum(p.numel() for p in model.parameters() if p.requires_grad)}" + ) + + self.state.epoch = 0 + start_time = time.time() + epochs_trained = 0 + steps_trained_in_current_epoch = 0 + steps_trained_progress_bar = None + + # Check if continuing training from a checkpoint + if resume_from_checkpoint is not None and os.path.isfile( + os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME) + ): + self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)) + epochs_trained = self.state.global_step // num_update_steps_per_epoch + if not args.ignore_data_skip: + steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch) + steps_trained_in_current_epoch *= args.gradient_accumulation_steps + else: + steps_trained_in_current_epoch = 0 + + logger.info(" Continuing training from checkpoint, will skip to saved global_step") + logger.info(f" Continuing training from epoch {epochs_trained}") + logger.info(f" Continuing training from global step {self.state.global_step}") + if not args.ignore_data_skip: + if skip_first_batches is None: + logger.info( + f" Will skip the first {epochs_trained} epochs then the first" + f" {steps_trained_in_current_epoch} batches in the first epoch. If this takes a lot of time," + " you can install the latest version of Accelerate with `pip install -U accelerate`.You can" + " also add the `--ignore_data_skip` flag to your launch command, but you will resume the" + " training on data already seen by your model." + ) + else: + logger.info( + f" Will skip the first {epochs_trained} epochs then the first" + f" {steps_trained_in_current_epoch} batches in the first epoch." + ) + if self.is_local_process_zero() and not args.disable_tqdm and skip_first_batches is None: + steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch) + steps_trained_progress_bar.set_description("Skipping the first batches") + + # Update the references + self.callback_handler.model = self.model + self.callback_handler.optimizer = self.optimizer + self.callback_handler.lr_scheduler = self.lr_scheduler + self.callback_handler.train_dataloader = train_dataloader + if self.hp_name is not None and self._trial is not None: + # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial + # parameter to Train when using DDP. + self.state.trial_name = self.hp_name(self._trial) + if trial is not None: + assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial + self.state.trial_params = hp_params(assignments) + else: + self.state.trial_params = None + # This should be the same if the state has been saved but in case the training arguments changed, it's safer + # to set this after the load. 
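A worked example of the resume bookkeeping above, with illustrative numbers:

```python
# A checkpoint saved at global step 1234, 250 optimizer steps per epoch,
# gradient accumulation of 4 (all values illustrative).
global_step = 1234
num_update_steps_per_epoch = 250
gradient_accumulation_steps = 4

epochs_trained = global_step // num_update_steps_per_epoch                 # 4 full epochs already done
steps_trained_in_current_epoch = global_step % num_update_steps_per_epoch  # 234 optimizer steps into epoch 5
steps_trained_in_current_epoch *= gradient_accumulation_steps              # 936 batches to skip
print(epochs_trained, steps_trained_in_current_epoch)
```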
+ self.state.max_steps = max_steps + self.state.num_train_epochs = num_train_epochs + self.state.is_local_process_zero = self.is_local_process_zero() + self.state.is_world_process_zero = self.is_world_process_zero() + + # tr_loss is a tensor to avoid synchronization of TPUs through .item() + tr_loss = torch.tensor(0.0).to(args.device) + # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses + self._total_loss_scalar = 0.0 + self._globalstep_last_logged = self.state.global_step + model.zero_grad() + + self.control = self.callback_handler.on_train_begin(args, self.state, self.control) + + # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point. + if not args.ignore_data_skip: + for epoch in range(epochs_trained): + is_random_sampler = hasattr(train_dataloader, "sampler") and isinstance( + train_dataloader.sampler, RandomSampler + ) + if is_torch_less_than_1_11 or not is_random_sampler: + # We just need to begin an iteration to create the randomization of the sampler. + # That was before PyTorch 1.11 however... + for _ in train_dataloader: + break + else: + # Otherwise we need to call the whooooole sampler cause there is some random operation added + # AT THE VERY END! + _ = list(train_dataloader.sampler) + + total_batched_samples = 0 + for epoch in range(epochs_trained, num_train_epochs): + if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler): + train_dataloader.sampler.set_epoch(epoch) + elif hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDatasetShard): + train_dataloader.dataset.set_epoch(epoch) + + if is_torch_tpu_available(): + parallel_loader = pl.ParallelLoader(train_dataloader, [args.device]).per_device_loader(args.device) + epoch_iterator = parallel_loader + else: + epoch_iterator = train_dataloader + + # Reset the past mems state at the beginning of each epoch if necessary. 
+ if args.past_index >= 0: + self._past = None + + steps_in_epoch = ( + len(epoch_iterator) + if len_dataloader is not None + else args.max_steps * args.gradient_accumulation_steps + ) + self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control) + + if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0: + self._load_rng_state(resume_from_checkpoint) + + rng_to_sync = False + steps_skipped = 0 + if skip_first_batches is not None and steps_trained_in_current_epoch > 0: + epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch) + steps_skipped = steps_trained_in_current_epoch + steps_trained_in_current_epoch = 0 + rng_to_sync = True + + step = -1 + for step, inputs in enumerate(epoch_iterator): + total_batched_samples += 1 + if rng_to_sync: + self._load_rng_state(resume_from_checkpoint) + rng_to_sync = False + + # Skip past any already trained steps if resuming training + if steps_trained_in_current_epoch > 0: + steps_trained_in_current_epoch -= 1 + if steps_trained_progress_bar is not None: + steps_trained_progress_bar.update(1) + if steps_trained_in_current_epoch == 0: + self._load_rng_state(resume_from_checkpoint) + continue + elif steps_trained_progress_bar is not None: + steps_trained_progress_bar.close() + steps_trained_progress_bar = None + + if step % args.gradient_accumulation_steps == 0: + self.control = self.callback_handler.on_step_begin(args, self.state, self.control) + + if ( + (total_batched_samples % args.gradient_accumulation_steps != 0) + and args.local_rank != -1 + and args._no_sync_in_gradient_accumulation + ): + # Avoid unnecessary DDP synchronization since there will be no backward pass on this example. + with model.no_sync(): + tr_loss_step = self.training_step(model, inputs) + else: + tr_loss_step = self.training_step(model, inputs) + + if ( + args.logging_nan_inf_filter + and not is_torch_tpu_available() + and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)) + ): + # if loss is nan or inf simply add the average of previous logged losses + tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged) + else: + tr_loss += tr_loss_step + + self.current_flos += float(self.floating_point_ops(inputs)) + + # Optimizer step for deepspeed must be called on every step regardless of the value of gradient_accumulation_steps + if self.deepspeed: + self.deepspeed.step() + + if total_batched_samples % args.gradient_accumulation_steps == 0 or ( + # last step in epoch but step is always smaller than gradient_accumulation_steps + steps_in_epoch <= args.gradient_accumulation_steps + and (step + 1) == steps_in_epoch + ): + # Gradient clipping + if args.max_grad_norm is not None and args.max_grad_norm > 0 and not self.deepspeed: + # deepspeed does its own clipping + + if self.do_grad_scaling: + # Reduce gradients first for XLA + if is_torch_tpu_available(): + gradients = xm._fetch_gradients(self.optimizer) + xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size()) + # AMP: gradients need unscaling + self.scaler.unscale_(self.optimizer) + + if is_sagemaker_mp_enabled() and args.fp16: + self.optimizer.clip_master_grads(args.max_grad_norm) + elif hasattr(self.optimizer, "clip_grad_norm"): + # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping + self.optimizer.clip_grad_norm(args.max_grad_norm) + elif hasattr(model, "clip_grad_norm_"): + # Some models (like FullyShardedDDP) have a specific way to do gradient clipping + 
model.clip_grad_norm_(args.max_grad_norm) + else: + # Revert to normal clipping otherwise, handling Apex or full precision + nn.utils.clip_grad_norm_( + amp.master_params(self.optimizer) if self.use_apex else model.parameters(), + args.max_grad_norm, + ) + + # Optimizer step + optimizer_was_run = True + if self.deepspeed: + pass # called outside the loop + elif is_torch_tpu_available(): + if self.do_grad_scaling: + self.scaler.step(self.optimizer) + self.scaler.update() + else: + xm.optimizer_step(self.optimizer) + elif self.do_grad_scaling: + scale_before = self.scaler.get_scale() + self.scaler.step(self.optimizer) + self.scaler.update() + scale_after = self.scaler.get_scale() + optimizer_was_run = scale_before <= scale_after + else: + self.optimizer.step() + + if optimizer_was_run and not self.deepspeed: + self.lr_scheduler.step() + + model.zero_grad() + self.state.global_step += 1 + self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch + self.control = self.callback_handler.on_step_end(args, self.state, self.control) + + self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval) + else: + self.control = self.callback_handler.on_substep_end(args, self.state, self.control) + + if self.control.should_epoch_stop or self.control.should_training_stop: + break + if step < 0: + logger.warning( + "There seems to be not a single sample in your epoch_iterator, stopping training at step" + f" {self.state.global_step}! This is expected if you're using an IterableDataset and set" + f" num_steps ({max_steps}) higher than the number of available samples." + ) + self.control.should_training_stop = True + + self.control = self.callback_handler.on_epoch_end(args, self.state, self.control) + self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval) + + if DebugOption.TPU_METRICS_DEBUG in self.args.debug: + if is_torch_tpu_available(): + # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) + xm.master_print(met.metrics_report()) + else: + logger.warning( + "You enabled PyTorch/XLA debug metrics but you don't have a TPU " + "configured. Check your training configuration if this is unexpected." + ) + if self.control.should_training_stop: + break + + if args.past_index and hasattr(self, "_past"): + # Clean the state at the end of training + delattr(self, "_past") + + logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n") + if args.load_best_model_at_end and self.state.best_model_checkpoint is not None: + # Wait for everyone to get here so we are sur the model has been saved by process 0. 
+ if is_torch_tpu_available(): + xm.rendezvous("load_best_model_at_end") + elif args.local_rank != -1: + dist.barrier() + elif is_sagemaker_mp_enabled(): + smp.barrier() + + self._load_best_model() + + # add remaining tr_loss + self._total_loss_scalar += tr_loss.item() + train_loss = self._total_loss_scalar / self.state.global_step + + metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps) + self.store_flos() + metrics["total_flos"] = self.state.total_flos + metrics["train_loss"] = train_loss + + self.is_in_train = False + + self._memory_tracker.stop_and_update_metrics(metrics) + + self.log(metrics) + + run_dir = self._get_output_dir(trial) + checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir) + + # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save. + if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1: + for checkpoint in checkpoints_sorted: + if checkpoint != self.state.best_model_checkpoint: + logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit") + shutil.rmtree(checkpoint) + + self.control = self.callback_handler.on_train_end(args, self.state, self.control) + + return TrainOutput(self.state.global_step, train_loss, metrics) + + def _get_output_dir(self, trial): + if self.hp_search_backend is not None and trial is not None: + if self.hp_search_backend == HPSearchBackend.OPTUNA: + run_id = trial.number + elif self.hp_search_backend == HPSearchBackend.RAY: + from ray import tune + + run_id = tune.get_trial_id() + elif self.hp_search_backend == HPSearchBackend.SIGOPT: + run_id = trial.id + elif self.hp_search_backend == HPSearchBackend.WANDB: + import wandb + + run_id = wandb.run.id + run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}" + run_dir = os.path.join(self.args.output_dir, run_name) + else: + run_dir = self.args.output_dir + return run_dir + + def _load_from_checkpoint(self, resume_from_checkpoint, model=None): + if model is None: + model = self.model + + if not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)) and not os.path.isfile( + os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME) + ): + raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}") + + logger.info(f"Loading model from {resume_from_checkpoint}.") + + if os.path.isfile(os.path.join(resume_from_checkpoint, CONFIG_NAME)): + config = PretrainedConfig.from_json_file(os.path.join(resume_from_checkpoint, CONFIG_NAME)) + checkpoint_version = config.transformers_version + if checkpoint_version is not None and checkpoint_version != __version__: + logger.warning( + f"You are resuming training from a checkpoint trained with {checkpoint_version} of " + f"Transformers but your current version is {__version__}. This is not recommended and could " + "yield to errors or unwanted behaviors." + ) + + if os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)): + # If the model is on the GPU, it still works! + if is_sagemaker_mp_enabled(): + if os.path.isfile(os.path.join(resume_from_checkpoint, "user_content.pt")): + # If the 'user_content.pt' file exists, load with the new smp api. + # Checkpoint must have been saved with the new smp api. 
+ smp.resume_from_checkpoint( + path=resume_from_checkpoint, tag=WEIGHTS_NAME, partial=False, load_optimizer=False + ) + else: + # If the 'user_content.pt' file does NOT exist, load with the old smp api. + # Checkpoint must have been saved with the old smp api. + if hasattr(self.args, "fp16") and self.args.fp16 is True: + logger.warning( + "Enabling FP16 and loading from smp < 1.10 checkpoint together is not suppported." + ) + state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu") + # Required for smp to not auto-translate state_dict from hf to smp (is already smp). + state_dict["_smp_is_partial"] = False + load_result = model.load_state_dict(state_dict, strict=True) + # release memory + del state_dict + else: + # We load the model state dict on the CPU to avoid an OOM error. + state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu") + # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963 + # which takes *args instead of **kwargs + load_result = model.load_state_dict(state_dict, False) + # release memory + del state_dict + self._issue_warnings_after_load(load_result) + else: + # We load the sharded checkpoint + load_result = load_sharded_checkpoint(model, resume_from_checkpoint, strict=is_sagemaker_mp_enabled()) + if not is_sagemaker_mp_enabled(): + self._issue_warnings_after_load(load_result) + + def _load_best_model(self): + logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).") + best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME) + model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model + if os.path.exists(best_model_path): + if self.deepspeed: + if self.model_wrapped is not None: + # this removes the pre-hooks from the previous engine + self.model_wrapped.destroy() + self.model_wrapped = None + + # temp hack until Deepspeed fixes the problem with resume from an existing engine that did some stepping + deepspeed_engine, optimizer, lr_scheduler = deepspeed_init( + self, + num_training_steps=self.args.max_steps, + resume_from_checkpoint=self.state.best_model_checkpoint, + ) + self.model = deepspeed_engine.module + self.model_wrapped = deepspeed_engine + self.deepspeed = deepspeed_engine + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + else: + if is_sagemaker_mp_enabled(): + if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")): + # If the 'user_content.pt' file exists, load with the new smp api. + # Checkpoint must have been saved with the new smp api. + smp.resume_from_checkpoint( + path=self.state.best_model_checkpoint, + tag=WEIGHTS_NAME, + partial=False, + load_optimizer=False, + ) + else: + # If the 'user_content.pt' file does NOT exist, load with the old smp api. + # Checkpoint must have been saved with the old smp api. + state_dict = torch.load(best_model_path, map_location="cpu") + state_dict["_smp_is_partial"] = False + load_result = model.load_state_dict(state_dict, strict=True) + else: + # We load the model state dict on the CPU to avoid an OOM error. + state_dict = torch.load(best_model_path, map_location="cpu") + # If the model is on the GPU, it still works! 
+ # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963 + # which takes *args instead of **kwargs + load_result = model.load_state_dict(state_dict, False) + if not is_sagemaker_mp_enabled(): + self._issue_warnings_after_load(load_result) + elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)): + load_result = load_sharded_checkpoint( + model, self.state.best_model_checkpoint, strict=is_sagemaker_mp_enabled() + ) + if not is_sagemaker_mp_enabled(): + self._issue_warnings_after_load(load_result) + else: + logger.warning( + f"Could not locate the best model at {best_model_path}, if you are running a distributed training " + "on multiple nodes, you should activate `--save_on_each_node`." + ) + + def _issue_warnings_after_load(self, load_result): + if len(load_result.missing_keys) != 0: + if self.model._keys_to_ignore_on_save is not None and set(load_result.missing_keys) == set( + self.model._keys_to_ignore_on_save + ): + self.model.tie_weights() + else: + logger.warning(f"There were missing keys in the checkpoint model loaded: {load_result.missing_keys}.") + if len(load_result.unexpected_keys) != 0: + logger.warning( + f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}." + ) + + def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval): + if self.control.should_log: + if is_torch_tpu_available(): + xm.mark_step() + + logs: Dict[str, float] = {} + + # all_gather + mean() to get average loss over all processes + tr_loss_scalar = self._nested_gather(tr_loss).mean().item() + + # reset tr_loss to zero + tr_loss -= tr_loss + + logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4) + logs["learning_rate"] = self._get_learning_rate() + + self._total_loss_scalar += tr_loss_scalar + self._globalstep_last_logged = self.state.global_step + self.store_flos() + + self.log(logs) + + metrics = None + if self.control.should_evaluate: + if isinstance(self.eval_dataset, dict): + for eval_dataset_name, eval_dataset in self.eval_dataset.items(): + metrics = self.evaluate( + eval_dataset=eval_dataset, + ignore_keys=ignore_keys_for_eval, + metric_key_prefix=f"eval_{eval_dataset_name}", + ) + else: + metrics = self.evaluate(ignore_keys=ignore_keys_for_eval) + self._report_to_hp_search(trial, self.state.global_step, metrics) + + if self.control.should_save: + self._save_checkpoint(model, trial, metrics=metrics) + self.control = self.callback_handler.on_save(self.args, self.state, self.control) + + def _load_rng_state(self, checkpoint): + # Load RNG states from `checkpoint` + if checkpoint is None: + return + + if self.args.world_size > 1: + process_index = self.args.process_index + rng_file = os.path.join(checkpoint, f"rng_state_{process_index}.pth") + if not os.path.isfile(rng_file): + logger.info( + f"Didn't find an RNG file for process {process_index}, if you are resuming a training that " + "wasn't launched in a distributed fashion, reproducibility is not guaranteed." + ) + return + else: + rng_file = os.path.join(checkpoint, "rng_state.pth") + if not os.path.isfile(rng_file): + logger.info( + "Didn't find an RNG file, if you are resuming a training that was launched in a distributed " + "fashion, reproducibility is not guaranteed." 
+ ) + return + + checkpoint_rng_state = torch.load(rng_file) + random.setstate(checkpoint_rng_state["python"]) + np.random.set_state(checkpoint_rng_state["numpy"]) + torch.random.set_rng_state(checkpoint_rng_state["cpu"]) + if torch.cuda.is_available(): + if self.args.local_rank != -1: + torch.cuda.random.set_rng_state(checkpoint_rng_state["cuda"]) + else: + try: + torch.cuda.random.set_rng_state_all(checkpoint_rng_state["cuda"]) + except Exception as e: + logger.info( + f"Didn't manage to set back the RNG states of the GPU because of the following error:\n {e}" + "\nThis won't yield the same results as if the training had not been interrupted." + ) + if is_torch_tpu_available(): + xm.set_rng_state(checkpoint_rng_state["xla"]) + + def _save_checkpoint(self, model, trial, metrics=None): + # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we + # want to save except FullyShardedDDP. + # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model" + + # Save model checkpoint + checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" + + if self.hp_search_backend is None and trial is None: + self.store_flos() + + run_dir = self._get_output_dir(trial=trial) + output_dir = os.path.join(run_dir, checkpoint_folder) + self.save_model(output_dir, _internal_call=True) + if self.deepspeed: + # under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed + # config `stage3_gather_16bit_weights_on_model_save` is True + self.deepspeed.save_checkpoint(output_dir) + + # Save optimizer and scheduler + if self.sharded_ddp == ShardedDDPOption.SIMPLE: + self.optimizer.consolidate_state_dict() + + if is_torch_tpu_available(): + xm.rendezvous("saving_optimizer_states") + xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) + with warnings.catch_warnings(record=True) as caught_warnings: + xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) + reissue_pt_warnings(caught_warnings) + elif is_sagemaker_mp_enabled(): + opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False) + smp.barrier() + if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state: + smp.save( + opt_state_dict, + os.path.join(output_dir, OPTIMIZER_NAME), + partial=True, + v3=smp.state.cfg.shard_optimizer_state, + ) + if self.args.should_save: + with warnings.catch_warnings(record=True) as caught_warnings: + torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) + reissue_pt_warnings(caught_warnings) + if self.do_grad_scaling: + torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME)) + elif self.args.should_save and not self.deepspeed: + # deepspeed.save_checkpoint above saves model/optim/sched + torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) + with warnings.catch_warnings(record=True) as caught_warnings: + torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) + reissue_pt_warnings(caught_warnings) + if self.do_grad_scaling: + torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME)) + + # Determine the new best metric / best model checkpoint + if metrics is not None and self.args.metric_for_best_model is not None: + metric_to_check = self.args.metric_for_best_model + if not metric_to_check.startswith("eval_"): + metric_to_check = f"eval_{metric_to_check}" + metric_value = metrics[metric_to_check] + + operator = np.greater if 
self.args.greater_is_better else np.less + if ( + self.state.best_metric is None + or self.state.best_model_checkpoint is None + or operator(metric_value, self.state.best_metric) + ): + self.state.best_metric = metric_value + self.state.best_model_checkpoint = output_dir + + # Save the Trainer state + if self.args.should_save: + self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME)) + + # Save RNG state in non-distributed training + rng_states = { + "python": random.getstate(), + "numpy": np.random.get_state(), + "cpu": torch.random.get_rng_state(), + } + if torch.cuda.is_available(): + if self.args.local_rank == -1: + # In non distributed, we save the global CUDA RNG state (will take care of DataParallel) + rng_states["cuda"] = torch.cuda.random.get_rng_state_all() + else: + rng_states["cuda"] = torch.cuda.random.get_rng_state() + + if is_torch_tpu_available(): + rng_states["xla"] = xm.get_rng_state() + + # A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may + # not yet exist. + os.makedirs(output_dir, exist_ok=True) + + if self.args.world_size <= 1: + torch.save(rng_states, os.path.join(output_dir, "rng_state.pth")) + else: + torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth")) + + if self.args.push_to_hub: + self._push_from_checkpoint(output_dir) + + # Maybe delete some older checkpoints. + if self.args.should_save: + self._rotate_checkpoints(use_mtime=True, output_dir=run_dir) + + def _load_optimizer_and_scheduler(self, checkpoint): + """If optimizer and scheduler states exist, load them.""" + if checkpoint is None: + return + + if self.deepspeed: + # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init + return + + checkpoint_file_exists = ( + glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*") + if is_sagemaker_mp_enabled() + else os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME)) + ) + if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)): + # Load in optimizer and scheduler states + if is_torch_tpu_available(): + # On TPU we have to take some extra precautions to properly load the states on the right device. 
+                optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu")
+                with warnings.catch_warnings(record=True) as caught_warnings:
+                    lr_scheduler_state = torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu")
+                reissue_pt_warnings(caught_warnings)
+
+                xm.send_cpu_data_to_device(optimizer_state, self.args.device)
+                xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device)
+
+                self.optimizer.load_state_dict(optimizer_state)
+                self.lr_scheduler.load_state_dict(lr_scheduler_state)
+            else:
+                map_location = "cpu" if is_sagemaker_mp_enabled() else self.args.device
+                if is_sagemaker_mp_enabled():
+                    if os.path.isfile(os.path.join(checkpoint, "user_content.pt")):
+                        # Optimizer checkpoint was saved with smp >= 1.10
+                        def opt_load_hook(mod, opt):
+                            opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))
+
+                    else:
+                        # Optimizer checkpoint was saved with smp < 1.10
+                        def opt_load_hook(mod, opt):
+                            if IS_SAGEMAKER_MP_POST_1_10:
+                                opt.load_state_dict(
+                                    smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True, back_compat=True)
+                                )
+                            else:
+                                opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))
+
+                    self.model_wrapped.register_post_step_hook(opt_load_hook)
+                else:
+                    self.optimizer.load_state_dict(
+                        torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
+                    )
+                with warnings.catch_warnings(record=True) as caught_warnings:
+                    self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
+                reissue_pt_warnings(caught_warnings)
+            if self.do_grad_scaling and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)):
+                self.scaler.load_state_dict(torch.load(os.path.join(checkpoint, SCALER_NAME)))
+
+    def hyperparameter_search(
+        self,
+        hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
+        compute_objective: Optional[Callable[[Dict[str, float]], float]] = None,
+        n_trials: int = 20,
+        direction: str = "minimize",
+        backend: Optional[Union["str", HPSearchBackend]] = None,
+        hp_name: Optional[Callable[["optuna.Trial"], str]] = None,
+        **kwargs,
+    ) -> BestRun:
+        """
+        Launch a hyperparameter search using `optuna` or `Ray Tune` or `SigOpt`. The optimized quantity is determined
+        by `compute_objective`, which defaults to a function returning the evaluation loss when no metric is provided,
+        the sum of all metrics otherwise.
+
+        <Tip warning={true}>
+
+        To use this method, you need to have provided a `model_init` when initializing your [`Trainer`]: we need to
+        reinitialize the model at each new run. This is incompatible with the `optimizers` argument, so you need to
+        subclass [`Trainer`] and override the method [`~Trainer.create_optimizer_and_scheduler`] for custom
+        optimizer/scheduler.
+
+        </Tip>
+
+        Args:
+            hp_space (`Callable[["optuna.Trial"], Dict[str, float]]`, *optional*):
+                A function that defines the hyperparameter search space. Will default to
+                [`~trainer_utils.default_hp_space_optuna`] or [`~trainer_utils.default_hp_space_ray`] or
+                [`~trainer_utils.default_hp_space_sigopt`] depending on your backend.
+            compute_objective (`Callable[[Dict[str, float]], float]`, *optional*):
+                A function computing the objective to minimize or maximize from the metrics returned by the `evaluate`
+                method. Will default to [`~trainer_utils.default_compute_objective`].
+            n_trials (`int`, *optional*, defaults to 20):
+                The number of trial runs to test.
+ direction (`str`, *optional*, defaults to `"minimize"`): + Whether to optimize greater or lower objects. Can be `"minimize"` or `"maximize"`, you should pick + `"minimize"` when optimizing the validation loss, `"maximize"` when optimizing one or several metrics. + backend (`str` or [`~training_utils.HPSearchBackend`], *optional*): + The backend to use for hyperparameter search. Will default to optuna or Ray Tune or SigOpt, depending + on which one is installed. If all are installed, will default to optuna. + hp_name (`Callable[["optuna.Trial"], str]]`, *optional*): + A function that defines the trial/run name. Will default to None. + kwargs (`Dict[str, Any]`, *optional*): + Additional keyword arguments passed along to `optuna.create_study` or `ray.tune.run`. For more + information see: + + - the documentation of + [optuna.create_study](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.create_study.html) + - the documentation of [tune.run](https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run) + - the documentation of [sigopt](https://app.sigopt.com/docs/endpoints/experiments/create) + + Returns: + [`trainer_utils.BestRun`]: All the information about the best run. Experiment summary can be found in + `run_summary` attribute for Ray backend. + """ + if backend is None: + backend = default_hp_search_backend() + if backend is None: + raise RuntimeError( + "At least one of optuna or ray should be installed. " + "To install optuna run `pip install optuna`. " + "To install ray run `pip install ray[tune]`. " + "To install sigopt run `pip install sigopt`." + ) + backend = HPSearchBackend(backend) + if backend == HPSearchBackend.OPTUNA and not is_optuna_available(): + raise RuntimeError("You picked the optuna backend, but it is not installed. Use `pip install optuna`.") + if backend == HPSearchBackend.RAY and not is_ray_tune_available(): + raise RuntimeError( + "You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`." + ) + if backend == HPSearchBackend.SIGOPT and not is_sigopt_available(): + raise RuntimeError("You picked the sigopt backend, but it is not installed. Use `pip install sigopt`.") + if backend == HPSearchBackend.WANDB and not is_wandb_available(): + raise RuntimeError("You picked the wandb backend, but it is not installed. Use `pip install wandb`.") + self.hp_search_backend = backend + if self.model_init is None: + raise RuntimeError( + "To use hyperparameter search, you need to pass your model through a model_init function." + ) + + self.hp_space = default_hp_space[backend] if hp_space is None else hp_space + self.hp_name = hp_name + self.compute_objective = default_compute_objective if compute_objective is None else compute_objective + + backend_dict = { + HPSearchBackend.OPTUNA: run_hp_search_optuna, + HPSearchBackend.RAY: run_hp_search_ray, + HPSearchBackend.SIGOPT: run_hp_search_sigopt, + HPSearchBackend.WANDB: run_hp_search_wandb, + } + best_run = backend_dict[backend](self, n_trials, direction, **kwargs) + + self.hp_search_backend = None + return best_run + + def log(self, logs: Dict[str, float]) -> None: + """ + Log `logs` on the various objects watching training. + + Subclass and override this method to inject custom behavior. + + Args: + logs (`Dict[str, float]`): + The values to log. 
+ """ + if self.state.epoch is not None: + logs["epoch"] = round(self.state.epoch, 2) + + output = {**logs, **{"step": self.state.global_step}} + self.state.log_history.append(output) + self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs) + + def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]: + """ + Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors. + """ + if isinstance(data, Mapping): + return type(data)({k: self._prepare_input(v) for k, v in data.items()}) + elif isinstance(data, (tuple, list)): + return type(data)(self._prepare_input(v) for v in data) + elif isinstance(data, torch.Tensor): + kwargs = {"device": self.args.device} + if self.deepspeed and (torch.is_floating_point(data) or torch.is_complex(data)): + # NLP models inputs are int/uint and those get adjusted to the right dtype of the + # embedding. Other models such as wav2vec2's inputs are already float and thus + # may need special handling to match the dtypes of the model + kwargs.update({"dtype": self.args.hf_deepspeed_config.dtype()}) + return data.to(**kwargs) + return data + + def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]: + """ + Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and + handling potential state. + """ + inputs = self._prepare_input(inputs) + if len(inputs) == 0: + raise ValueError( + "The batch received was empty, your model won't be able to train on it. Double-check that your " + f"training dataset contains keys expected by the model: {','.join(self._signature_columns)}." + ) + if self.args.past_index >= 0 and self._past is not None: + inputs["mems"] = self._past + + return inputs + + def compute_loss_context_manager(self): + """ + A helper wrapper to group together context managers. + """ + return self.autocast_smart_context_manager() + + def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True): + """ + A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired + arguments, depending on the situation. + """ + if self.use_cuda_amp or self.use_cpu_amp: + if is_torch_greater_or_equal_than_1_10: + ctx_manager = ( + torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype) + if self.use_cpu_amp + else torch.cuda.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype) + ) + else: + ctx_manager = torch.cuda.amp.autocast() + else: + ctx_manager = contextlib.nullcontext() if sys.version_info >= (3, 7) else contextlib.suppress() + + return ctx_manager + + def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor: + """ + Perform a training step on a batch of inputs. + + Subclass and override to inject custom behavior. + + Args: + model (`nn.Module`): + The model to train. + inputs (`Dict[str, Union[torch.Tensor, Any]]`): + The inputs and targets of the model. + + The dictionary will be unpacked before being fed to the model. Most models expect the targets under the + argument `labels`. Check your model's documentation for all accepted arguments. + + Return: + `torch.Tensor`: The tensor with training loss on this batch. 
+ """ + model.train() + inputs = self._prepare_inputs(inputs) + + if is_sagemaker_mp_enabled(): + loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps) + return loss_mb.reduce_mean().detach().to(self.args.device) + + with self.compute_loss_context_manager(): + loss = self.compute_loss(model, inputs) + + if self.args.n_gpu > 1: + loss = loss.mean() # mean() to average on multi-gpu parallel training + + if self.args.gradient_accumulation_steps > 1 and not self.deepspeed: + # deepspeed handles loss scaling by gradient_accumulation_steps in its `backward` + loss = loss / self.args.gradient_accumulation_steps + + if self.do_grad_scaling: + self.scaler.scale(loss).backward() + elif self.use_apex: + with amp.scale_loss(loss, self.optimizer) as scaled_loss: + scaled_loss.backward() + elif self.deepspeed: + # loss gets scaled under gradient_accumulation_steps in deepspeed + loss = self.deepspeed.backward(loss) + else: + loss.backward() + + return loss.detach() + + def compute_loss(self, model, inputs, return_outputs=False): + """ + How the loss is computed by Trainer. By default, all models return the loss in the first element. + + Subclass and override for custom behavior. + """ + if self.label_smoother is not None and "labels" in inputs: + labels = inputs.pop("labels") + else: + labels = None + outputs = model(**inputs) + # Save past state if it exists + # TODO: this needs to be fixed and made cleaner later. + if self.args.past_index >= 0: + self._past = outputs[self.args.past_index] + + if labels is not None: + if unwrap_model(model)._get_name() in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values(): + loss = self.label_smoother(outputs, labels, shift_labels=True) + else: + loss = self.label_smoother(outputs, labels) + else: + if isinstance(outputs, dict) and "loss" not in outputs: + raise ValueError( + "The model did not return a loss from the inputs, only the following keys: " + f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}." + ) + # We don't use .loss here since the model may return tuples instead of ModelOutput. + loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] + + return (loss, outputs) if return_outputs else loss + + def is_local_process_zero(self) -> bool: + """ + Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on several + machines) main process. + """ + return self.args.local_process_index == 0 + + def is_world_process_zero(self) -> bool: + """ + Whether or not this process is the global main process (when training in a distributed fashion on several + machines, this is only going to be `True` for one process). + """ + # Special case for SageMaker ModelParallel since there process_index is dp_process_index, not the global + # process index. + if is_sagemaker_mp_enabled(): + return smp.rank() == 0 + else: + return self.args.process_index == 0 + + def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False): + """ + Will save the model, so you can reload it using `from_pretrained()`. + + Will only save from the main process. + """ + + if output_dir is None: + output_dir = self.args.output_dir + + if is_torch_tpu_available(): + self._save_tpu(output_dir) + elif is_sagemaker_mp_enabled(): + # Calling the state_dict needs to be done on the wrapped model and on all processes. 
+ os.makedirs(output_dir, exist_ok=True) + state_dict = self.model_wrapped.state_dict() + if self.args.should_save: + self._save(output_dir, state_dict=state_dict) + if IS_SAGEMAKER_MP_POST_1_10: + # 'user_content.pt' indicates model state_dict saved with smp >= 1.10 + Path(os.path.join(output_dir, "user_content.pt")).touch() + elif ( + ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp + or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp + or self.fsdp is not None + ): + state_dict = self.model.state_dict() + + if self.args.should_save: + self._save(output_dir, state_dict=state_dict) + elif self.deepspeed: + # this takes care of everything as long as we aren't under zero3 + if self.args.should_save: + self._save(output_dir) + + if is_deepspeed_zero3_enabled(): + # It's too complicated to try to override different places where the weights dump gets + # saved, so since under zero3 the file is bogus, simply delete it. The user should + # either user deepspeed checkpoint to resume or to recover full weights use + # zero_to_fp32.py stored in the checkpoint. + if self.args.should_save: + file = os.path.join(output_dir, WEIGHTS_NAME) + if os.path.isfile(file): + # logger.info(f"deepspeed zero3: removing {file}, see zero_to_fp32.py to recover weights") + os.remove(file) + + # now save the real model if stage3_gather_16bit_weights_on_model_save=True + # if false it will not be saved. + # This must be called on all ranks + if not self.deepspeed.save_16bit_model(output_dir, WEIGHTS_NAME): + logger.warning( + "deepspeed.save_16bit_model didn't save the model, since" + " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use" + " zero_to_fp32.py to recover weights" + ) + self.deepspeed.save_checkpoint(output_dir) + + elif self.args.should_save: + self._save(output_dir) + + # Push to the Hub when `save_model` is called by the user. + if self.args.push_to_hub and not _internal_call: + self.push_to_hub(commit_message="Model save") + + def _save_tpu(self, output_dir: Optional[str] = None): + output_dir = output_dir if output_dir is not None else self.args.output_dir + logger.info(f"Saving model checkpoint to {output_dir}") + + if xm.is_master_ordinal(): + os.makedirs(output_dir, exist_ok=True) + torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) + + # Save a trained model and configuration using `save_pretrained()`. + # They can then be reloaded using `from_pretrained()` + xm.rendezvous("saving_checkpoint") + if not isinstance(self.model, PreTrainedModel): + if isinstance(unwrap_model(self.model), PreTrainedModel): + unwrap_model(self.model).save_pretrained( + output_dir, + is_main_process=self.args.should_save, + state_dict=self.model.state_dict(), + save_function=xm.save, + ) + else: + logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") + state_dict = self.model.state_dict() + xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) + else: + self.model.save_pretrained(output_dir, is_main_process=self.args.should_save, save_function=xm.save) + if self.tokenizer is not None and self.args.should_save: + self.tokenizer.save_pretrained(output_dir) + + def _save(self, output_dir: Optional[str] = None, state_dict=None): + # If we are executing this function, we are the process zero, so we don't check for that. 
+        output_dir = output_dir if output_dir is not None else self.args.output_dir
+        os.makedirs(output_dir, exist_ok=True)
+        logger.info(f"Saving model checkpoint to {output_dir}")
+        # Save a trained model and configuration using `save_pretrained()`.
+        # They can then be reloaded using `from_pretrained()`
+        if not isinstance(self.model, PreTrainedModel):
+            if isinstance(unwrap_model(self.model), PreTrainedModel):
+                if state_dict is None:
+                    state_dict = self.model.state_dict()
+                unwrap_model(self.model).save_pretrained(output_dir, state_dict=state_dict)
+            else:
+                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
+                if state_dict is None:
+                    state_dict = self.model.state_dict()
+                torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
+        else:
+            if self.save_prefixencoder:
+                print("Saving PrefixEncoder")
+                state_dict = self.model.state_dict()
+                filtered_state_dict = {}
+                for k, v in self.model.named_parameters():
+                    if v.requires_grad:
+                        filtered_state_dict[k] = state_dict[k]
+                self.model.save_pretrained(output_dir, state_dict=filtered_state_dict)
+            else:
+                print("Saving the whole model")
+                self.model.save_pretrained(output_dir, state_dict=state_dict)
+        if self.tokenizer is not None:
+            self.tokenizer.save_pretrained(output_dir)
+
+        # Good practice: save your training arguments together with the trained model
+        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
+
+    def store_flos(self):
+        # Storing the number of floating-point operations that went into the model
+        if self.args.local_rank != -1:
+            self.state.total_flos += (
+                distributed_broadcast_scalars([self.current_flos], device=self.args.device).sum().item()
+            )
+            self.current_flos = 0
+        else:
+            self.state.total_flos += self.current_flos
+            self.current_flos = 0
+
+    def _sorted_checkpoints(
+        self, output_dir=None, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False
+    ) -> List[str]:
+        ordering_and_checkpoint_path = []
+
+        glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)]
+
+        for path in glob_checkpoints:
+            if use_mtime:
+                ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
+            else:
+                regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
+                if regex_match is not None and regex_match.groups() is not None:
+                    ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
+
+        checkpoints_sorted = sorted(ordering_and_checkpoint_path)
+        checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
+        # Make sure we don't delete the best model.
+        if self.state.best_model_checkpoint is not None:
+            best_model_index = checkpoints_sorted.index(str(Path(self.state.best_model_checkpoint)))
+            for i in range(best_model_index, len(checkpoints_sorted) - 2):
+                checkpoints_sorted[i], checkpoints_sorted[i + 1] = checkpoints_sorted[i + 1], checkpoints_sorted[i]
+        return checkpoints_sorted
+
+    def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:
+        if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
+            return
+
+        # Check if we should delete older checkpoint(s)
+        checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime, output_dir=output_dir)
+        if len(checkpoints_sorted) <= self.args.save_total_limit:
+            return
+
+        # If save_total_limit=1 with load_best_model_at_end=True, we could end up deleting the last checkpoint, which
+        # we don't do to allow resuming.
+ save_total_limit = self.args.save_total_limit + if ( + self.state.best_model_checkpoint is not None + and self.args.save_total_limit == 1 + and checkpoints_sorted[-1] != self.state.best_model_checkpoint + ): + save_total_limit = 2 + + number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit) + checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete] + for checkpoint in checkpoints_to_be_deleted: + logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit") + shutil.rmtree(checkpoint, ignore_errors=True) + + def evaluate( + self, + eval_dataset: Optional[Dataset] = None, + ignore_keys: Optional[List[str]] = None, + metric_key_prefix: str = "eval", + ) -> Dict[str, float]: + """ + Run evaluation and returns metrics. + + The calling script will be responsible for providing a method to compute metrics, as they are task-dependent + (pass it to the init `compute_metrics` argument). + + You can also subclass and override this method to inject custom behavior. + + Args: + eval_dataset (`Dataset`, *optional*): + Pass a dataset if you wish to override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns + not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__` + method. + ignore_keys (`Lst[str]`, *optional*): + A list of keys in the output of your model (if it is a dictionary) that should be ignored when + gathering predictions. + metric_key_prefix (`str`, *optional*, defaults to `"eval"`): + An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named + "eval_bleu" if the prefix is "eval" (default) + + Returns: + A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The + dictionary also contains the epoch number which comes from the training state. + """ + # memory metrics - must set up as early as possible + self._memory_tracker.start() + + eval_dataloader = self.get_eval_dataloader(eval_dataset) + start_time = time.time() + + eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop + output = eval_loop( + eval_dataloader, + description="Evaluation", + # No point gathering the predictions if there are no metrics, otherwise we defer to + # self.args.prediction_loss_only + prediction_loss_only=True if self.compute_metrics is None else None, + ignore_keys=ignore_keys, + metric_key_prefix=metric_key_prefix, + ) + + total_batch_size = self.args.eval_batch_size * self.args.world_size + if f"{metric_key_prefix}_jit_compilation_time" in output.metrics: + start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"] + output.metrics.update( + speed_metrics( + metric_key_prefix, + start_time, + num_samples=output.num_samples, + num_steps=math.ceil(output.num_samples / total_batch_size), + ) + ) + + self.log(output.metrics) + + if DebugOption.TPU_METRICS_DEBUG in self.args.debug: + # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) + xm.master_print(met.metrics_report()) + + self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics) + + self._memory_tracker.stop_and_update_metrics(output.metrics) + + return output.metrics + + def predict( + self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "test" + ) -> PredictionOutput: + """ + Run prediction and returns predictions and potential metrics. 
+ + Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method + will also return metrics, like in `evaluate()`. + + Args: + test_dataset (`Dataset`): + Dataset to run the predictions on. If it is an `datasets.Dataset`, columns not accepted by the + `model.forward()` method are automatically removed. Has to implement the method `__len__` + ignore_keys (`Lst[str]`, *optional*): + A list of keys in the output of your model (if it is a dictionary) that should be ignored when + gathering predictions. + metric_key_prefix (`str`, *optional*, defaults to `"test"`): + An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named + "test_bleu" if the prefix is "test" (default) + + <Tip> + + If your predictions or labels have different sequence length (for instance because you're doing dynamic padding + in a token classification task) the predictions will be padded (on the right) to allow for concatenation into + one array. The padding index is -100. + + </Tip> + + Returns: *NamedTuple* A namedtuple with the following keys: + + - predictions (`np.ndarray`): The predictions on `test_dataset`. + - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some). + - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained + labels). + """ + # memory metrics - must set up as early as possible + self._memory_tracker.start() + + test_dataloader = self.get_test_dataloader(test_dataset) + start_time = time.time() + + eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop + output = eval_loop( + test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix + ) + total_batch_size = self.args.eval_batch_size * self.args.world_size + if f"{metric_key_prefix}_jit_compilation_time" in output.metrics: + start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"] + output.metrics.update( + speed_metrics( + metric_key_prefix, + start_time, + num_samples=output.num_samples, + num_steps=math.ceil(output.num_samples / total_batch_size), + ) + ) + + self.control = self.callback_handler.on_predict(self.args, self.state, self.control, output.metrics) + self._memory_tracker.stop_and_update_metrics(output.metrics) + + return PredictionOutput(predictions=output.predictions, label_ids=output.label_ids, metrics=output.metrics) + + def evaluation_loop( + self, + dataloader: DataLoader, + description: str, + prediction_loss_only: Optional[bool] = None, + ignore_keys: Optional[List[str]] = None, + metric_key_prefix: str = "eval", + ) -> EvalLoopOutput: + """ + Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`. + + Works both with or without labels. 
+ """ + args = self.args + + prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only + + # if eval is called w/o train init deepspeed here + if args.deepspeed and not self.deepspeed: + # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval + # from the checkpoint eventually + deepspeed_engine, _, _ = deepspeed_init( + self, num_training_steps=0, resume_from_checkpoint=None, inference=True + ) + self.model = deepspeed_engine.module + self.model_wrapped = deepspeed_engine + self.deepspeed = deepspeed_engine + + model = self._wrap_model(self.model, training=False, dataloader=dataloader) + + # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called + # while ``train`` is running, cast it to the right dtype first and then put on device + if not self.is_in_train: + if args.fp16_full_eval: + model = model.to(dtype=torch.float16, device=args.device) + elif args.bf16_full_eval: + model = model.to(dtype=torch.bfloat16, device=args.device) + + batch_size = self.args.eval_batch_size + + logger.info(f"***** Running {description} *****") + if has_length(dataloader): + logger.info(f" Num examples = {self.num_examples(dataloader)}") + else: + logger.info(" Num examples: Unknown") + logger.info(f" Batch size = {batch_size}") + + model.eval() + + self.callback_handler.eval_dataloader = dataloader + # Do this before wrapping. + eval_dataset = getattr(dataloader, "dataset", None) + + if is_torch_tpu_available(): + dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device) + + if args.past_index >= 0: + self._past = None + + # Initialize containers + # losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps) + losses_host = None + preds_host = None + labels_host = None + inputs_host = None + + # losses/preds/labels on CPU (final containers) + all_losses = None + all_preds = None + all_labels = None + all_inputs = None + # Will be useful when we have an iterable dataset so don't know its length. + + observed_num_examples = 0 + # Main evaluation loop + for step, inputs in enumerate(dataloader): + # Update the observed num examples + observed_batch_size = find_batch_size(inputs) + if observed_batch_size is not None: + observed_num_examples += observed_batch_size + # For batch samplers, batch_size is not known by the dataloader in advance. 
+ if batch_size is None: + batch_size = observed_batch_size + + # Prediction step + loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) + inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None + + if is_torch_tpu_available(): + xm.mark_step() + + # Update containers on host + if loss is not None: + losses = self._nested_gather(loss.repeat(batch_size)) + losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0) + if labels is not None: + labels = self._pad_across_processes(labels) + labels = self._nested_gather(labels) + labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100) + if inputs_decode is not None: + inputs_decode = self._pad_across_processes(inputs_decode) + inputs_decode = self._nested_gather(inputs_decode) + inputs_host = ( + inputs_decode + if inputs_host is None + else nested_concat(inputs_host, inputs_decode, padding_index=-100) + ) + if logits is not None: + logits = self._pad_across_processes(logits) + logits = self._nested_gather(logits) + if self.preprocess_logits_for_metrics is not None: + logits = self.preprocess_logits_for_metrics(logits, labels) + preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100) + self.control = self.callback_handler.on_prediction_step(args, self.state, self.control) + + # Gather all tensors and put them back on the CPU if we have done enough accumulation steps. + if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0: + if losses_host is not None: + losses = nested_numpify(losses_host) + all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0) + if preds_host is not None: + logits = nested_numpify(preds_host) + all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100) + if inputs_host is not None: + inputs_decode = nested_numpify(inputs_host) + all_inputs = ( + inputs_decode + if all_inputs is None + else nested_concat(all_inputs, inputs_decode, padding_index=-100) + ) + if labels_host is not None: + labels = nested_numpify(labels_host) + all_labels = ( + labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100) + ) + + # Set back to None to begin a new accumulation + losses_host, preds_host, inputs_host, labels_host = None, None, None, None + + if args.past_index and hasattr(self, "_past"): + # Clean the state at the end of the evaluation loop + delattr(self, "_past") + + # Gather all remaining tensors and put them back on the CPU + if losses_host is not None: + losses = nested_numpify(losses_host) + all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0) + if preds_host is not None: + logits = nested_numpify(preds_host) + all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100) + if inputs_host is not None: + inputs_decode = nested_numpify(inputs_host) + all_inputs = ( + inputs_decode if all_inputs is None else nested_concat(all_inputs, inputs_decode, padding_index=-100) + ) + if labels_host is not None: + labels = nested_numpify(labels_host) + all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100) + + # Number of samples + if has_length(eval_dataset): + num_samples = len(eval_dataset) + # The instance check is weird and does not actually check for the type, 
but whether the dataset has the right + # methods. Therefore we need to make sure it also has the attribute. + elif isinstance(eval_dataset, IterableDatasetShard) and getattr(eval_dataset, "num_examples", 0) > 0: + num_samples = eval_dataset.num_examples + else: + if has_length(dataloader): + num_samples = self.num_examples(dataloader) + else: # both len(dataloader.dataset) and len(dataloader) fail + num_samples = observed_num_examples + if num_samples == 0 and observed_num_examples > 0: + num_samples = observed_num_examples + + # Number of losses has been rounded to a multiple of batch_size and in a distributed training, the number of + # samplers has been rounded to a multiple of batch_size, so we truncate. + if all_losses is not None: + all_losses = all_losses[:num_samples] + if all_preds is not None: + all_preds = nested_truncate(all_preds, num_samples) + if all_labels is not None: + all_labels = nested_truncate(all_labels, num_samples) + if all_inputs is not None: + all_inputs = nested_truncate(all_inputs, num_samples) + + # Metrics! + if self.compute_metrics is not None and all_preds is not None and all_labels is not None: + if args.include_inputs_for_metrics: + metrics = self.compute_metrics( + EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs) + ) + else: + metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels)) + else: + metrics = {} + + # To be JSON-serializable, we need to remove numpy types or zero-d tensors + metrics = denumpify_detensorize(metrics) + + if all_losses is not None: + metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item() + if hasattr(self, "jit_compilation_time"): + metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time + + # Prefix all keys with metric_key_prefix + '_' + for key in list(metrics.keys()): + if not key.startswith(f"{metric_key_prefix}_"): + metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) + + return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples) + + def _nested_gather(self, tensors, name=None): + """ + Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before + concatenating them to `gathered` + """ + if tensors is None: + return + if is_torch_tpu_available(): + if name is None: + name = "nested_gather" + tensors = nested_xla_mesh_reduce(tensors, name) + elif is_sagemaker_mp_enabled(): + tensors = smp_gather(tensors) + elif self.args.local_rank != -1: + tensors = distributed_concat(tensors) + return tensors + + # Copied from Accelerate. + def _pad_across_processes(self, tensor, pad_index=-100): + """ + Recursively pad the tensors in a nested list/tuple/dictionary of tensors from all devices to the same size so + they can safely be gathered. + """ + if isinstance(tensor, (list, tuple)): + return type(tensor)(self._pad_across_processes(t, pad_index=pad_index) for t in tensor) + elif isinstance(tensor, dict): + return type(tensor)({k: self._pad_across_processes(v, pad_index=pad_index) for k, v in tensor.items()}) + elif not isinstance(tensor, torch.Tensor): + raise TypeError( + f"Can't pad the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors." 
+ ) + + if len(tensor.shape) < 2: + return tensor + # Gather all sizes + size = torch.tensor(tensor.shape, device=tensor.device)[None] + sizes = self._nested_gather(size).cpu() + + max_size = max(s[1] for s in sizes) + # When extracting XLA graphs for compilation, max_size is 0, + # so use inequality to avoid errors. + if tensor.shape[1] >= max_size: + return tensor + + # Then pad to the maximum size + old_size = tensor.shape + new_size = list(old_size) + new_size[1] = max_size + new_tensor = tensor.new_zeros(tuple(new_size)) + pad_index + new_tensor[:, : old_size[1]] = tensor + return new_tensor + + def prediction_step( + self, + model: nn.Module, + inputs: Dict[str, Union[torch.Tensor, Any]], + prediction_loss_only: bool, + ignore_keys: Optional[List[str]] = None, + ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Perform an evaluation step on `model` using `inputs`. + + Subclass and override to inject custom behavior. + + Args: + model (`nn.Module`): + The model to evaluate. + inputs (`Dict[str, Union[torch.Tensor, Any]]`): + The inputs and targets of the model. + + The dictionary will be unpacked before being fed to the model. Most models expect the targets under the + argument `labels`. Check your model's documentation for all accepted arguments. + prediction_loss_only (`bool`): + Whether or not to return the loss only. + ignore_keys (`Lst[str]`, *optional*): + A list of keys in the output of your model (if it is a dictionary) that should be ignored when + gathering predictions. + + Return: + Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, + logits and labels (each being optional). + """ + has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names) + # For CLIP-like models capable of returning loss values. + # If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss` + # is `True` in `model.forward`. + return_loss = inputs.get("return_loss", None) + if return_loss is None: + return_loss = self.can_return_loss + loss_without_labels = True if len(self.label_names) == 0 and return_loss else False + + inputs = self._prepare_inputs(inputs) + if ignore_keys is None: + if hasattr(self.model, "config"): + ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", []) + else: + ignore_keys = [] + + # labels may be popped when computing the loss (label smoothing for instance) so we grab them first. 
+ if has_labels or loss_without_labels: + labels = nested_detach(tuple(inputs.get(name) for name in self.label_names)) + if len(labels) == 1: + labels = labels[0] + else: + labels = None + + with torch.no_grad(): + if is_sagemaker_mp_enabled(): + raw_outputs = smp_forward_only(model, inputs) + if has_labels or loss_without_labels: + if isinstance(raw_outputs, dict): + loss_mb = raw_outputs["loss"] + logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"]) + else: + loss_mb = raw_outputs[0] + logits_mb = raw_outputs[1:] + + loss = loss_mb.reduce_mean().detach().cpu() + logits = smp_nested_concat(logits_mb) + else: + loss = None + if isinstance(raw_outputs, dict): + logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys) + else: + logits_mb = raw_outputs + logits = smp_nested_concat(logits_mb) + else: + if has_labels or loss_without_labels: + with self.compute_loss_context_manager(): + loss, outputs = self.compute_loss(model, inputs, return_outputs=True) + loss = loss.mean().detach() + + if isinstance(outputs, dict): + logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"]) + else: + logits = outputs[1:] + else: + loss = None + with self.compute_loss_context_manager(): + outputs = model(**inputs) + if isinstance(outputs, dict): + logits = tuple(v for k, v in outputs.items() if k not in ignore_keys) + else: + logits = outputs + # TODO: this needs to be fixed and made cleaner later. + if self.args.past_index >= 0: + self._past = outputs[self.args.past_index - 1] + + if prediction_loss_only: + return (loss, None, None) + + logits = nested_detach(logits) + if len(logits) == 1: + logits = logits[0] + + return (loss, logits, labels) + + def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]): + """ + For models that inherit from [`PreTrainedModel`], uses that method to compute the number of floating point + operations for every backward + forward pass. If using another model, either implement such a method in the + model or subclass and override this method. + + Args: + inputs (`Dict[str, Union[torch.Tensor, Any]]`): + The inputs and targets of the model. + + Returns: + `int`: The number of floating-point operations. + """ + if hasattr(self.model, "floating_point_ops"): + return self.model.floating_point_ops(inputs) + else: + return 0 + + def init_git_repo(self, at_init: bool = False): + """ + Initializes a git repo in `self.args.hub_model_id`. + + Args: + at_init (`bool`, *optional*, defaults to `False`): + Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is + `True` and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped + out. + """ + if not self.is_world_process_zero(): + return + if self.args.hub_model_id is None: + repo_name = Path(self.args.output_dir).absolute().name + else: + repo_name = self.args.hub_model_id + if "/" not in repo_name: + repo_name = get_full_repo_name(repo_name, token=self.args.hub_token) + + # Make sure the repo exists. 
+ create_repo(repo_name, token=self.args.hub_token, private=self.args.hub_private_repo, exist_ok=True) + try: + self.repo = Repository(self.args.output_dir, clone_from=repo_name, token=self.args.hub_token) + except EnvironmentError: + if self.args.overwrite_output_dir and at_init: + # Try again after wiping output_dir + shutil.rmtree(self.args.output_dir) + self.repo = Repository(self.args.output_dir, clone_from=repo_name, token=self.args.hub_token) + else: + raise + + self.repo.git_pull() + + # By default, ignore the checkpoint folders + if ( + not os.path.exists(os.path.join(self.args.output_dir, ".gitignore")) + and self.args.hub_strategy != HubStrategy.ALL_CHECKPOINTS + ): + with open(os.path.join(self.args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer: + writer.writelines(["checkpoint-*/"]) + + # Add "*.sagemaker" to .gitignore if using SageMaker + if os.environ.get("SM_TRAINING_ENV"): + self._add_sm_patterns_to_gitignore() + + self.push_in_progress = None + + def create_model_card( + self, + language: Optional[str] = None, + license: Optional[str] = None, + tags: Union[str, List[str], None] = None, + model_name: Optional[str] = None, + finetuned_from: Optional[str] = None, + tasks: Union[str, List[str], None] = None, + dataset_tags: Union[str, List[str], None] = None, + dataset: Union[str, List[str], None] = None, + dataset_args: Union[str, List[str], None] = None, + ): + """ + Creates a draft of a model card using the information available to the `Trainer`. + + Args: + language (`str`, *optional*): + The language of the model (if applicable) + license (`str`, *optional*): + The license of the model. Will default to the license of the pretrained model used, if the original + model given to the `Trainer` comes from a repo on the Hub. + tags (`str` or `List[str]`, *optional*): + Some tags to be included in the metadata of the model card. + model_name (`str`, *optional*): + The name of the model. + finetuned_from (`str`, *optional*): + The name of the model used to fine-tune this one (if applicable). Will default to the name of the repo + of the original model given to the `Trainer` (if it comes from the Hub). + tasks (`str` or `List[str]`, *optional*): + One or several task identifiers, to be included in the metadata of the model card. + dataset_tags (`str` or `List[str]`, *optional*): + One or several dataset tags, to be included in the metadata of the model card. + dataset (`str` or `List[str]`, *optional*): + One or several dataset identifiers, to be included in the metadata of the model card. + dataset_args (`str` or `List[str]`, *optional*): + One or several dataset arguments, to be included in the metadata of the model card. + """ + if not self.is_world_process_zero(): + return + + training_summary = TrainingSummary.from_trainer( + self, + language=language, + license=license, + tags=tags, + model_name=model_name, + finetuned_from=finetuned_from, + tasks=tasks, + dataset_tags=dataset_tags, + dataset=dataset, + dataset_args=dataset_args, + ) + model_card = training_summary.to_model_card() + with open(os.path.join(self.args.output_dir, "README.md"), "w") as f: + f.write(model_card) + + def _push_from_checkpoint(self, checkpoint_folder): + # Only push from one node. + if not self.is_world_process_zero() or self.args.hub_strategy == HubStrategy.END: + return + # If we haven't finished the last push, we don't do this one. 
+ if self.push_in_progress is not None and not self.push_in_progress.is_done: + return + + output_dir = self.args.output_dir + # To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder + modeling_files = [CONFIG_NAME, WEIGHTS_NAME] + for modeling_file in modeling_files: + if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)): + shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file)) + # Saving the tokenizer is fast and we don't know how many files it may have spawned, so we resave it to be sure. + if self.tokenizer is not None: + self.tokenizer.save_pretrained(output_dir) + # Same for the training arguments + torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) + + try: + if self.args.hub_strategy == HubStrategy.CHECKPOINT: + # Temporarily move the checkpoint just saved for the push + tmp_checkpoint = os.path.join(output_dir, "last-checkpoint") + # We have to remove the "last-checkpoint" dir if it exists, otherwise the checkpoint is moved as a + # subfolder. + if os.path.isdir(tmp_checkpoint): + shutil.rmtree(tmp_checkpoint) + shutil.move(checkpoint_folder, tmp_checkpoint) + + if self.args.save_strategy == IntervalStrategy.STEPS: + commit_message = f"Training in progress, step {self.state.global_step}" + else: + commit_message = f"Training in progress, epoch {int(self.state.epoch)}" + _, self.push_in_progress = self.repo.push_to_hub( + commit_message=commit_message, blocking=False, auto_lfs_prune=True + ) + finally: + if self.args.hub_strategy == HubStrategy.CHECKPOINT: + # Move back the checkpoint to its place + shutil.move(tmp_checkpoint, checkpoint_folder) + + def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str: + """ + Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*. + + Parameters: + commit_message (`str`, *optional*, defaults to `"End of training"`): + Message to commit while pushing. + blocking (`bool`, *optional*, defaults to `True`): + Whether the function should return only when the `git push` has finished. + kwargs: + Additional keyword arguments passed along to [`~Trainer.create_model_card`]. + + Returns: + The url of the commit of your model in the given repository if `blocking=False`, a tuple with the url of + the commit and an object to track the progress of the commit if `blocking=True` + """ + # If a user calls manually `push_to_hub` with `self.args.push_to_hub = False`, we try to create the repo but + # it might fail. + if not hasattr(self, "repo"): + self.init_git_repo() + + model_name = kwargs.pop("model_name", None) + if model_name is None and self.args.should_save: + if self.args.hub_model_id is None: + model_name = Path(self.args.output_dir).name + else: + model_name = self.args.hub_model_id.split("/")[-1] + + # Needs to be executed on all processes for TPU training, but will only save on the processed determined by + # self.args.should_save. + self.save_model(_internal_call=True) + + # Only push from one node. + if not self.is_world_process_zero(): + return + + # Cancel any async push in progress if blocking=True. The commits will all be pushed together. 
+ if blocking and self.push_in_progress is not None and not self.push_in_progress.is_done: + self.push_in_progress._process.kill() + self.push_in_progress = None + + git_head_commit_url = self.repo.push_to_hub( + commit_message=commit_message, blocking=blocking, auto_lfs_prune=True + ) + # push separately the model card to be independant from the rest of the model + if self.args.should_save: + self.create_model_card(model_name=model_name, **kwargs) + try: + self.repo.push_to_hub( + commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True + ) + except EnvironmentError as exc: + logger.error(f"Error pushing update to the model card. Please read logs and retry.\n${exc}") + + return git_head_commit_url + + # + # Deprecated code + # + + def prediction_loop( + self, + dataloader: DataLoader, + description: str, + prediction_loss_only: Optional[bool] = None, + ignore_keys: Optional[List[str]] = None, + metric_key_prefix: str = "eval", + ) -> EvalLoopOutput: + """ + Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`. + + Works both with or without labels. + """ + args = self.args + + if not has_length(dataloader): + raise ValueError("dataloader must implement a working __len__") + + prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only + + # if eval is called w/o train init deepspeed here + if args.deepspeed and not self.deepspeed: + # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval + # from the checkpoint eventually + deepspeed_engine, _, _ = deepspeed_init(self, num_training_steps=0, resume_from_checkpoint=None) + self.model = deepspeed_engine.module + self.model_wrapped = deepspeed_engine + self.deepspeed = deepspeed_engine + # XXX: we don't need optim/sched for inference, but this needs to be sorted out, since + # for example the Z3-optimizer is a must for zero3 to work even for inference - what we + # don't need is the deepspeed basic optimizer which is self.optimizer.optimizer + deepspeed_engine.optimizer.optimizer = None + deepspeed_engine.lr_scheduler = None + + model = self._wrap_model(self.model, training=False, dataloader=dataloader) + + # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called + # while ``train`` is running, cast it to the right dtype first and then put on device + if not self.is_in_train: + if args.fp16_full_eval: + model = model.to(dtype=torch.float16, device=args.device) + elif args.bf16_full_eval: + model = model.to(dtype=torch.bfloat16, device=args.device) + + batch_size = dataloader.batch_size + num_examples = self.num_examples(dataloader) + logger.info(f"***** Running {description} *****") + logger.info(f" Num examples = {num_examples}") + logger.info(f" Batch size = {batch_size}") + losses_host: torch.Tensor = None + preds_host: Union[torch.Tensor, List[torch.Tensor]] = None + labels_host: Union[torch.Tensor, List[torch.Tensor]] = None + inputs_host: Union[torch.Tensor, List[torch.Tensor]] = None + + world_size = max(1, args.world_size) + + eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size) + if not prediction_loss_only: + # The actual number of eval_sample can be greater than num_examples in distributed settings (when we pass + # a batch size to the sampler) + make_multiple_of = None + if hasattr(dataloader, "sampler") and isinstance(dataloader.sampler, SequentialDistributedSampler): + make_multiple_of = 
dataloader.sampler.batch_size + preds_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of) + labels_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of) + inputs_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of) + + model.eval() + + if is_torch_tpu_available(): + dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device) + + if args.past_index >= 0: + self._past = None + + self.callback_handler.eval_dataloader = dataloader + + for step, inputs in enumerate(dataloader): + loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) + inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None + + if loss is not None: + losses = loss.repeat(batch_size) + losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0) + if logits is not None: + preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100) + if labels is not None: + labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100) + if inputs_decode is not None: + inputs_host = ( + inputs_decode + if inputs_host is None + else nested_concat(inputs_host, inputs_decode, padding_index=-100) + ) + self.control = self.callback_handler.on_prediction_step(args, self.state, self.control) + + # Gather all tensors and put them back on the CPU if we have done enough accumulation steps. + if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0: + eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses")) + if not prediction_loss_only: + preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds")) + labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids")) + inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids")) + + # Set back to None to begin a new accumulation + losses_host, preds_host, labels_host, inputs_host = None, None, None, None + + if args.past_index and hasattr(self, "_past"): + # Clean the state at the end of the evaluation loop + delattr(self, "_past") + + # Gather all remaining tensors and put them back on the CPU + eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses")) + if not prediction_loss_only: + preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds")) + labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids")) + inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids")) + + eval_loss = eval_losses_gatherer.finalize() + preds = preds_gatherer.finalize() if not prediction_loss_only else None + label_ids = labels_gatherer.finalize() if not prediction_loss_only else None + inputs_ids = inputs_gatherer.finalize() if not prediction_loss_only else None + + if self.compute_metrics is not None and preds is not None and label_ids is not None: + if args.include_inputs_for_metrics: + metrics = self.compute_metrics( + EvalPrediction(predictions=preds, label_ids=label_ids, inputs=inputs_ids) + ) + else: + metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids)) + else: + metrics = {} + + # To be JSON-serializable, we need to remove numpy types or zero-d tensors + metrics = denumpify_detensorize(metrics) + + if 
eval_loss is not None: + metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item() + + # Prefix all keys with metric_key_prefix + '_' + for key in list(metrics.keys()): + if not key.startswith(f"{metric_key_prefix}_"): + metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) + + return EvalLoopOutput(predictions=preds, label_ids=label_ids, metrics=metrics, num_samples=num_examples) + + def _gather_and_numpify(self, tensors, name): + """ + Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before + concatenating them to `gathered` + """ + if tensors is None: + return + if is_torch_tpu_available(): + tensors = nested_xla_mesh_reduce(tensors, name) + elif is_sagemaker_mp_enabled(): + tensors = smp_gather(tensors) + elif self.args.local_rank != -1: + tensors = distributed_concat(tensors) + + return nested_numpify(tensors) + + def _add_sm_patterns_to_gitignore(self) -> None: + """Add SageMaker Checkpointing patterns to .gitignore file.""" + # Make sure we only do this on the main process + if not self.is_world_process_zero(): + return + + patterns = ["*.sagemaker-uploading", "*.sagemaker-uploaded"] + + # Get current .gitignore content + if os.path.exists(os.path.join(self.repo.local_dir, ".gitignore")): + with open(os.path.join(self.repo.local_dir, ".gitignore"), "r") as f: + current_content = f.read() + else: + current_content = "" + + # Add the patterns to .gitignore + content = current_content + for pattern in patterns: + if pattern not in content: + if content.endswith("\n"): + content += pattern + else: + content += f"\n{pattern}" + + # Write the .gitignore file if it has changed + if content != current_content: + with open(os.path.join(self.repo.local_dir, ".gitignore"), "w") as f: + logger.debug(f"Writing .gitignore file. Content: {content}") + f.write(content) + + self.repo.git_add(".gitignore") + + # avoid race condition with git status + time.sleep(0.5) + + if not self.repo.is_repo_clean(): + self.repo.git_commit("Add *.sagemaker patterns to .gitignore.") + self.repo.git_push() diff --git a/ptuning/trainer_seq2seq.py b/ptuning/trainer_seq2seq.py new file mode 100644 index 0000000000000000000000000000000000000000..19d5cf12a274944a3ea3ce689414eab72636e0bd --- /dev/null +++ b/ptuning/trainer_seq2seq.py @@ -0,0 +1,247 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from torch import nn +from torch.utils.data import Dataset + +from transformers.deepspeed import is_deepspeed_zero3_enabled +from trainer import Trainer +from transformers.trainer_utils import PredictionOutput +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +class Seq2SeqTrainer(Trainer): + def evaluate( + self, + eval_dataset: Optional[Dataset] = None, + ignore_keys: Optional[List[str]] = None, + metric_key_prefix: str = "eval", + **gen_kwargs + ) -> Dict[str, float]: + """ + Run evaluation and returns metrics. + + The calling script will be responsible for providing a method to compute metrics, as they are task-dependent + (pass it to the init `compute_metrics` argument). + + You can also subclass and override this method to inject custom behavior. + + Args: + eval_dataset (`Dataset`, *optional*): + Pass a dataset if you wish to override `self.eval_dataset`. If it is an [`~datasets.Dataset`], columns + not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__` + method. + ignore_keys (`List[str]`, *optional*): + A list of keys in the output of your model (if it is a dictionary) that should be ignored when + gathering predictions. + metric_key_prefix (`str`, *optional*, defaults to `"eval"`): + An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named + "eval_bleu" if the prefix is `"eval"` (default) + max_length (`int`, *optional*): + The maximum target length to use when predicting with the generate method. + num_beams (`int`, *optional*): + Number of beams for beam search that will be used when predicting with the generate method. 1 means no + beam search. + gen_kwargs: + Additional `generate` specific kwargs. + + Returns: + A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The + dictionary also contains the epoch number which comes from the training state. + """ + + gen_kwargs = gen_kwargs.copy() + if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None: + gen_kwargs["max_length"] = self.args.generation_max_length + gen_kwargs["num_beams"] = ( + gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams + ) + self._gen_kwargs = gen_kwargs + + return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix) + + def predict( + self, + test_dataset: Dataset, + ignore_keys: Optional[List[str]] = None, + metric_key_prefix: str = "test", + **gen_kwargs + ) -> PredictionOutput: + """ + Run prediction and returns predictions and potential metrics. + + Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method + will also return metrics, like in `evaluate()`. + + Args: + test_dataset (`Dataset`): + Dataset to run the predictions on. If it is a [`~datasets.Dataset`], columns not accepted by the + `model.forward()` method are automatically removed. Has to implement the method `__len__` + ignore_keys (`List[str]`, *optional*): + A list of keys in the output of your model (if it is a dictionary) that should be ignored when + gathering predictions. + metric_key_prefix (`str`, *optional*, defaults to `"eval"`): + An optional prefix to be used as the metrics key prefix. 
For example the metrics "bleu" will be named + "eval_bleu" if the prefix is `"eval"` (default) + max_length (`int`, *optional*): + The maximum target length to use when predicting with the generate method. + num_beams (`int`, *optional*): + Number of beams for beam search that will be used when predicting with the generate method. 1 means no + beam search. + gen_kwargs: + Additional `generate` specific kwargs. + + <Tip> + + If your predictions or labels have different sequence lengths (for instance because you're doing dynamic + padding in a token classification task) the predictions will be padded (on the right) to allow for + concatenation into one array. The padding index is -100. + + </Tip> + + Returns: *NamedTuple* A namedtuple with the following keys: + + - predictions (`np.ndarray`): The predictions on `test_dataset`. + - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some). + - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained + labels). + """ + + gen_kwargs = gen_kwargs.copy() + if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None: + gen_kwargs["max_length"] = self.args.generation_max_length + gen_kwargs["num_beams"] = ( + gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams + ) + self._gen_kwargs = gen_kwargs + + + return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix) + + def prediction_step( + self, + model: nn.Module, + inputs: Dict[str, Union[torch.Tensor, Any]], + prediction_loss_only: bool, + ignore_keys: Optional[List[str]] = None, + ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Perform an evaluation step on `model` using `inputs`. + + Subclass and override to inject custom behavior. + + Args: + model (`nn.Module`): + The model to evaluate. + inputs (`Dict[str, Union[torch.Tensor, Any]]`): + The inputs and targets of the model. + + The dictionary will be unpacked before being fed to the model. Most models expect the targets under the + argument `labels`. Check your model's documentation for all accepted arguments. + prediction_loss_only (`bool`): + Whether or not to return the loss only. + + Return: + Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and + labels (each being optional). 
+ """ + + if not self.args.predict_with_generate or prediction_loss_only: + return super().prediction_step( + model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys + ) + + has_labels = "labels" in inputs + inputs = self._prepare_inputs(inputs) + + # XXX: adapt synced_gpus for fairscale as well + gen_kwargs = self._gen_kwargs.copy() + if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None: + gen_kwargs["max_length"] = self.model.config.max_length + gen_kwargs["num_beams"] = ( + gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams + ) + default_synced_gpus = True if is_deepspeed_zero3_enabled() else False + gen_kwargs["synced_gpus"] = ( + gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus + ) + + if "attention_mask" in inputs: + gen_kwargs["attention_mask"] = inputs.get("attention_mask", None) + if "position_ids" in inputs: + gen_kwargs["position_ids"] = inputs.get("position_ids", None) + if "global_attention_mask" in inputs: + gen_kwargs["global_attention_mask"] = inputs.get("global_attention_mask", None) + + # prepare generation inputs + # some encoder-decoder models can have varying encoder's and thus + # varying model input names + if hasattr(self.model, "encoder") and self.model.encoder.main_input_name != self.model.main_input_name: + generation_inputs = inputs[self.model.encoder.main_input_name] + else: + generation_inputs = inputs[self.model.main_input_name] + + gen_kwargs["input_ids"] = generation_inputs + generated_tokens = self.model.generate(**gen_kwargs) + generated_tokens = generated_tokens[:, generation_inputs.size()[-1]:] + + # in case the batch is shorter than max length, the output should be padded + if gen_kwargs.get("max_length") is not None and generated_tokens.shape[-1] < gen_kwargs["max_length"]: + generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"]) + elif gen_kwargs.get("max_new_tokens") is not None and generated_tokens.shape[-1] < ( + gen_kwargs["max_new_tokens"] + 1 + ): + generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_new_tokens"] + 1) + + loss = None + + if self.args.prediction_loss_only: + return (loss, None, None) + + if has_labels: + labels = inputs["labels"] + if gen_kwargs.get("max_length") is not None and labels.shape[-1] < gen_kwargs["max_length"]: + labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"]) + elif gen_kwargs.get("max_new_tokens") is not None and labels.shape[-1] < ( + gen_kwargs["max_new_tokens"] + 1 + ): + labels = self._pad_tensors_to_max_len(labels, (gen_kwargs["max_new_tokens"] + 1)) + else: + labels = None + + return (loss, generated_tokens, labels) + + def _pad_tensors_to_max_len(self, tensor, max_length): + if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"): + # If PAD token is not defined at least EOS token has to be defined + pad_token_id = ( + self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id + ) + else: + if self.model.config.pad_token_id is not None: + pad_token_id = self.model.config.pad_token_id + else: + raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors") + + padded_tensor = pad_token_id * torch.ones( + (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device + ) + padded_tensor[:, : tensor.shape[-1]] = tensor + return padded_tensor diff --git 
a/ptuning/web_demo.py b/ptuning/web_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..43d0c826e7190b11b6fe50e1a6243e9fceead2d2 --- /dev/null +++ b/ptuning/web_demo.py @@ -0,0 +1,166 @@
+import os, sys
+
+import gradio as gr
+import mdtex2html
+
+import torch
+import transformers
+from transformers import (
+    AutoConfig,
+    AutoModel,
+    AutoTokenizer,
+    AutoTokenizer,
+    DataCollatorForSeq2Seq,
+    HfArgumentParser,
+    Seq2SeqTrainingArguments,
+    set_seed,
+)
+
+from arguments import ModelArguments, DataTrainingArguments
+
+
+model = None
+tokenizer = None
+
+"""Override Chatbot.postprocess"""
+
+
+def postprocess(self, y):
+    if y is None:
+        return []
+    for i, (message, response) in enumerate(y):
+        y[i] = (
+            None if message is None else mdtex2html.convert((message)),
+            None if response is None else mdtex2html.convert(response),
+        )
+    return y
+
+
+gr.Chatbot.postprocess = postprocess
+
+
+def parse_text(text):
+    """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
+    lines = text.split("\n")
+    lines = [line for line in lines if line != ""]
+    count = 0
+    for i, line in enumerate(lines):
+        if "```" in line:
+            count += 1
+            items = line.split('`')
+            if count % 2 == 1:
+                lines[i] = f'<pre><code class="language-{items[-1]}">'
+            else:
+                lines[i] = f'<br></code></pre>'
+        else:
+            if i > 0:
+                if count % 2 == 1:
+                    line = line.replace("`", "\`")
+                    line = line.replace("<", "&lt;")
+                    line = line.replace(">", "&gt;")
+                    line = line.replace(" ", "&nbsp;")
+                    line = line.replace("*", "&ast;")
+                    line = line.replace("_", "&lowbar;")
+                    line = line.replace("-", "&#45;")
+                    line = line.replace(".", "&#46;")
+                    line = line.replace("!", "&#33;")
+                    line = line.replace("(", "&#40;")
+                    line = line.replace(")", "&#41;")
+                    line = line.replace("$", "&#36;")
+                lines[i] = "<br>"+line
+    text = "".join(lines)
+    return text
+
+
+def predict(input, chatbot, max_length, top_p, temperature, history):
+    chatbot.append((parse_text(input), ""))
+    for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
+                                               temperature=temperature):
+        chatbot[-1] = (parse_text(input), parse_text(response))
+
+        yield chatbot, history
+
+
+def reset_user_input():
+    return gr.update(value='')
+
+
+def reset_state():
+    return [], []
+
+
+with gr.Blocks() as demo:
+    gr.HTML("""<h1 align="center">ChatGLM</h1>""")
+
+    chatbot = gr.Chatbot()
+    with gr.Row():
+        with gr.Column(scale=4):
+            with gr.Column(scale=12):
+                user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
+                    container=False)
+            with gr.Column(min_width=32, scale=1):
+                submitBtn = gr.Button("Submit", variant="primary")
+        with gr.Column(scale=1):
+            emptyBtn = gr.Button("Clear History")
+            max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
+            top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
+            temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
+
+    history = gr.State([])
+
+    submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history],
+                    show_progress=True)
+    submitBtn.click(reset_user_input, [], [user_input])
+
+    emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
+
+
+
+def main():
+    global model, tokenizer
+
+    parser = HfArgumentParser((
+        ModelArguments))
+    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+        # If we pass only one argument to the script and it's the path to a json file,
+        # let's parse it to get our arguments.
+ model_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))[0] + else: + model_args = parser.parse_args_into_dataclasses()[0] + + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, trust_remote_code=True) + config = AutoConfig.from_pretrained( + model_args.model_name_or_path, trust_remote_code=True) + + config.pre_seq_len = model_args.pre_seq_len + config.prefix_projection = model_args.prefix_projection + + if model_args.ptuning_checkpoint is not None: + print(f"Loading prefix_encoder weight from {model_args.ptuning_checkpoint}") + model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True) + prefix_state_dict = torch.load(os.path.join(model_args.ptuning_checkpoint, "pytorch_model.bin")) + new_prefix_state_dict = {} + for k, v in prefix_state_dict.items(): + if k.startswith("transformer.prefix_encoder."): + new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v + model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict) + else: + model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True) + + if model_args.quantization_bit is not None: + print(f"Quantized to {model_args.quantization_bit} bit") + model = model.quantize(model_args.quantization_bit) + + if model_args.pre_seq_len is not None: + # P-tuning v2 + model = model.half().cuda() + model.transformer.prefix_encoder.float().cuda() + + model = model.eval() + demo.queue().launch(share=False, inbrowser=True) + + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/ptuning/web_demo.sh b/ptuning/web_demo.sh new file mode 100644 index 0000000000000000000000000000000000000000..87bf9e9fc24b716b5d48bcfc7904a1e05d466318 --- /dev/null +++ b/ptuning/web_demo.sh @@ -0,0 +1,7 @@ +PRE_SEQ_LEN=128 + +CUDA_VISIBLE_DEVICES=0 python3 web_demo.py \ + --model_name_or_path THUDM/chatglm-6b \ + --ptuning_checkpoint output/adgen-chatglm-6b-pt-128-2e-2/checkpoint-3000 \ + --pre_seq_len $PRE_SEQ_LEN + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..fb8d79f7519d03793281b87b132d43bd17b85784 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,8 @@ +protobuf +transformers==4.27.1 +cpm_kernels +torch>=1.10 +gradio +mdtex2html +sentencepiece +accelerate \ No newline at end of file diff --git a/resources/WECHAT.md b/resources/WECHAT.md new file mode 100644 index 0000000000000000000000000000000000000000..c9ee867ead5d818a0b4e2ba46103a6454537d143 --- /dev/null +++ b/resources/WECHAT.md @@ -0,0 +1,7 @@ +<div align="center"> +<img src=wechat.jpg width="60%"/> + +<p> 扫码关注公众号,加入「ChatGLM交流群」 </p> +<p> Scan the QR code to follow the official account and join the "ChatGLM Discussion Group" </p> +</div> + diff --git a/resources/cli-demo.png b/resources/cli-demo.png new file mode 100644 index 0000000000000000000000000000000000000000..3d489b5bd6d7e9d6b1db198307b8204f6fc2f80e Binary files /dev/null and b/resources/cli-demo.png differ diff --git a/resources/english-q1-new.png b/resources/english-q1-new.png new file mode 100644 index 0000000000000000000000000000000000000000..798b9e0e12783d47620feb4708c5c3e5957b6d77 Binary files /dev/null and b/resources/english-q1-new.png differ diff --git a/resources/english-q1-old.png b/resources/english-q1-old.png new file mode 100644 index 0000000000000000000000000000000000000000..81b4a706c6e394797d1a69c947790de2fdcbb017 Binary files /dev/null and b/resources/english-q1-old.png differ diff --git 
a/resources/english-q2-new.png b/resources/english-q2-new.png new file mode 100644 index 0000000000000000000000000000000000000000..36cd6e9057b4f0a4f66861726a2fad66c31cd91a Binary files /dev/null and b/resources/english-q2-new.png differ diff --git a/resources/english-q2-old.png b/resources/english-q2-old.png new file mode 100644 index 0000000000000000000000000000000000000000..fcffd513945491024cf7bc121d515ae835df930f Binary files /dev/null and b/resources/english-q2-old.png differ diff --git a/resources/english-q3-new.png b/resources/english-q3-new.png new file mode 100644 index 0000000000000000000000000000000000000000..52d712eb42f3879ef4d4720244e03421fa016605 Binary files /dev/null and b/resources/english-q3-new.png differ diff --git a/resources/english-q3-old.png b/resources/english-q3-old.png new file mode 100644 index 0000000000000000000000000000000000000000..7388119fbf375e8d38906da80b0076567966f6aa Binary files /dev/null and b/resources/english-q3-old.png differ diff --git a/resources/english-q4-new.png b/resources/english-q4-new.png new file mode 100644 index 0000000000000000000000000000000000000000..af3eabce8e20ae5cb8fad31f3fcdfa12ea49b51d Binary files /dev/null and b/resources/english-q4-new.png differ diff --git a/resources/english-q4-old.png b/resources/english-q4-old.png new file mode 100644 index 0000000000000000000000000000000000000000..56d2913fb7a3de560cf484da430457f1e1d4cfdc Binary files /dev/null and b/resources/english-q4-old.png differ diff --git a/resources/visualglm.png b/resources/visualglm.png new file mode 100644 index 0000000000000000000000000000000000000000..d94c9b70199aa3aa2dcd4a54ed35b6ae63ebfbb8 Binary files /dev/null and b/resources/visualglm.png differ diff --git a/resources/web-demo.gif b/resources/web-demo.gif new file mode 100644 index 0000000000000000000000000000000000000000..c775716f5a9608281dc3306022249afa121d5105 --- /dev/null +++ b/resources/web-demo.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ba8ff042bbd879cbb4dd3795081b2e4e3713d3a4d2d5d7d61a027c389324cbbc +size 2284452 diff --git a/resources/web-demo.png b/resources/web-demo.png new file mode 100644 index 0000000000000000000000000000000000000000..7711fd698bdb15432e21085b365510e95d55fd88 Binary files /dev/null and b/resources/web-demo.png differ diff --git a/resources/webglm.jpg b/resources/webglm.jpg new file mode 100644 index 0000000000000000000000000000000000000000..199ea0e45cc77dd7e42d86f31165a82ee52d8e65 Binary files /dev/null and b/resources/webglm.jpg differ diff --git a/resources/wechat.jpg b/resources/wechat.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8ea08f2423d294a9c50efad144fdc1058b079fde Binary files /dev/null and b/resources/wechat.jpg differ diff --git a/utils.py b/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..45015f97b439bce90a813c67dfd304d3dc68cbf5 --- /dev/null +++ b/utils.py @@ -0,0 +1,54 @@ +import os +from typing import Dict, Tuple, Union, Optional + +from torch.nn import Module +from transformers import AutoModel + + +def auto_configure_device_map(num_gpus: int) -> Dict[str, int]: + # transformer.word_embeddings 占用1层 + # transformer.final_layernorm 和 lm_head 占用1层 + # transformer.layers 占用 28 层 + # 总共30层分配到num_gpus张卡上 + num_trans_layers = 28 + per_gpu_layers = 30 / num_gpus + + # bugfix: 在linux中调用torch.embedding传入的weight,input不在同一device上,导致RuntimeError + # windows下 model.device 会被设置成 transformer.word_embeddings.device + # linux下 model.device 会被设置成 lm_head.device + # 
在调用chat或者stream_chat时,input_ids会被放到model.device上
+    # 如果transformer.word_embeddings.device和model.device不同,则会导致RuntimeError
+    # 因此这里将transformer.word_embeddings,transformer.final_layernorm,lm_head都放到第一张卡上
+    device_map = {'transformer.word_embeddings': 0,
+                  'transformer.final_layernorm': 0, 'lm_head': 0}
+
+    used = 2
+    gpu_target = 0
+    for i in range(num_trans_layers):
+        if used >= per_gpu_layers:
+            gpu_target += 1
+            used = 0
+        assert gpu_target < num_gpus
+        device_map[f'transformer.layers.{i}'] = gpu_target
+        used += 1
+
+    return device_map
+
+
+def load_model_on_gpus(checkpoint_path: Union[str, os.PathLike], num_gpus: int = 2,
+                       device_map: Optional[Dict[str, int]] = None, **kwargs) -> Module:
+    if num_gpus < 2 and device_map is None:
+        model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs).half().cuda()
+    else:
+        from accelerate import dispatch_model
+
+        model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs).half()
+
+        if device_map is None:
+            device_map = auto_configure_device_map(num_gpus)
+
+        model = dispatch_model(model, device_map=device_map)
+
+    return model
+
+
diff --git a/web_demo.py b/web_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..475dcfad555189a544e0565f7c2c07bc742a38e5 --- /dev/null +++ b/web_demo.py @@ -0,0 +1,101 @@
+from transformers import AutoModel, AutoTokenizer
+import gradio as gr
+import mdtex2html
+
+tokenizer = AutoTokenizer.from_pretrained("../chatglm", trust_remote_code=True)
+model = AutoModel.from_pretrained("../chatglm", trust_remote_code=True).float()
+model = model.eval()
+
+"""Override Chatbot.postprocess"""
+
+
+def postprocess(self, y):
+    if y is None:
+        return []
+    for i, (message, response) in enumerate(y):
+        y[i] = (
+            None if message is None else mdtex2html.convert((message)),
+            None if response is None else mdtex2html.convert(response),
+        )
+    return y
+
+
+gr.Chatbot.postprocess = postprocess
+
+
+def parse_text(text):
+    """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
+    lines = text.split("\n")
+    lines = [line for line in lines if line != ""]
+    count = 0
+    for i, line in enumerate(lines):
+        if "```" in line:
+            count += 1
+            items = line.split('`')
+            if count % 2 == 1:
+                lines[i] = f'<pre><code class="language-{items[-1]}">'
+            else:
+                lines[i] = f'<br></code></pre>'
+        else:
+            if i > 0:
+                if count % 2 == 1:
+                    line = line.replace("`", "\`")
+                    line = line.replace("<", "&lt;")
+                    line = line.replace(">", "&gt;")
+                    line = line.replace(" ", "&nbsp;")
+                    line = line.replace("*", "&ast;")
+                    line = line.replace("_", "&lowbar;")
+                    line = line.replace("-", "&#45;")
+                    line = line.replace(".", "&#46;")
+                    line = line.replace("!", "&#33;")
+                    line = line.replace("(", "&#40;")
+                    line = line.replace(")", "&#41;")
+                    line = line.replace("$", "&#36;")
+                lines[i] = "<br>"+line
+    text = "".join(lines)
+    return text
+
+
+def predict(input, chatbot, max_length, top_p, temperature, history):
+    chatbot.append((parse_text(input), ""))
+    for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
+                                               temperature=temperature):
+        chatbot[-1] = (parse_text(input), parse_text(response))
+
+        yield chatbot, history
+
+
+def reset_user_input():
+    return gr.update(value='')
+
+
+def reset_state():
+    return [], []
+
+
+with gr.Blocks() as demo:
+    gr.HTML("""<h1 align="center">ChatGLM</h1>""")
+
+    chatbot = gr.Chatbot()
+    with gr.Row():
+        with gr.Column(scale=4):
+            with gr.Column(scale=12):
+                user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
+
container=False) + with gr.Column(min_width=32, scale=1): + submitBtn = gr.Button("Submit", variant="primary") + with gr.Column(scale=1): + emptyBtn = gr.Button("Clear History") + max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True) + top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True) + temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True) + + history = gr.State([]) + + submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history], + show_progress=True) + submitBtn.click(reset_user_input, [], [user_input]) + + emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True) + +demo.queue().launch(share=True, inbrowser=True) diff --git a/web_demo2.py b/web_demo2.py new file mode 100644 index 0000000000000000000000000000000000000000..ce976b3fb90508b772e4c63d91f679d88f2cf939 --- /dev/null +++ b/web_demo2.py @@ -0,0 +1,71 @@ +from transformers import AutoModel, AutoTokenizer +import streamlit as st +from streamlit_chat import message + + +st.set_page_config( + page_title="ChatGLM-6b 演示", + page_icon=":robot:" +) + + +@st.cache_resource +def get_model(): + tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) + model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda() + model = model.eval() + return tokenizer, model + + +MAX_TURNS = 20 +MAX_BOXES = MAX_TURNS * 2 + + +def predict(input, max_length, top_p, temperature, history=None): + tokenizer, model = get_model() + if history is None: + history = [] + + with container: + if len(history) > 0: + if len(history)>MAX_BOXES: + history = history[-MAX_TURNS:] + for i, (query, response) in enumerate(history): + message(query, avatar_style="big-smile", key=str(i) + "_user") + message(response, avatar_style="bottts", key=str(i)) + + message(input, avatar_style="big-smile", key=str(len(history)) + "_user") + st.write("AI正在回复:") + with st.empty(): + for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p, + temperature=temperature): + query, response = history[-1] + st.write(response) + + return history + + +container = st.container() + +# create a prompt text for the text generation +prompt_text = st.text_area(label="用户命令输入", + height = 100, + placeholder="请在这儿输入您的命令") + +max_length = st.sidebar.slider( + 'max_length', 0, 4096, 2048, step=1 +) +top_p = st.sidebar.slider( + 'top_p', 0.0, 1.0, 0.6, step=0.01 +) +temperature = st.sidebar.slider( + 'temperature', 0.0, 1.0, 0.95, step=0.01 +) + +if 'state' not in st.session_state: + st.session_state['state'] = [] + +if st.button("发送", key="predict"): + with st.spinner("AI正在思考,请稍等........"): + # text generation + st.session_state["state"] = predict(prompt_text, max_length, top_p, temperature, st.session_state["state"]) diff --git a/web_demo_old.py b/web_demo_old.py new file mode 100644 index 0000000000000000000000000000000000000000..88a6dc88aa3df7e3d4107f0e11a50d525b4dde05 --- /dev/null +++ b/web_demo_old.py @@ -0,0 +1,45 @@ +from transformers import AutoModel, AutoTokenizer +import gradio as gr + +tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) +model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda() +model = model.eval() + +MAX_TURNS = 20 +MAX_BOXES = MAX_TURNS * 2 + + +def predict(input, max_length, top_p, temperature, history=None): + if history is None: + 
history = []
+    for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
+                                               temperature=temperature):
+        updates = []
+        for query, response in history:
+            updates.append(gr.update(visible=True, value="用户:" + query))
+            updates.append(gr.update(visible=True, value="ChatGLM-6B:" + response))
+        if len(updates) < MAX_BOXES:
+            updates = updates + [gr.Textbox.update(visible=False)] * (MAX_BOXES - len(updates))
+        yield [history] + updates
+
+
+with gr.Blocks() as demo:
+    state = gr.State([])
+    text_boxes = []
+    for i in range(MAX_BOXES):
+        if i % 2 == 0:
+            text_boxes.append(gr.Markdown(visible=False, label="提问:"))
+        else:
+            text_boxes.append(gr.Markdown(visible=False, label="回复:"))
+
+    with gr.Row():
+        with gr.Column(scale=4):
+            txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter", lines=11).style(
+                container=False)
+        with gr.Column(scale=1):
+            max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
+            top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
+            temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
+            button = gr.Button("Generate")
+    button.click(predict, [txt, max_length, top_p, temperature, state], [state] + text_boxes)
+demo.queue().launch(share=False, inbrowser=True)
diff --git a/web_demo_vision.py b/web_demo_vision.py new file mode 100644 index 0000000000000000000000000000000000000000..79f6b52eb5669b8c86e6a371c4aaedeb7b15ab9e --- /dev/null +++ b/web_demo_vision.py @@ -0,0 +1,120 @@
+from transformers import AutoModel, AutoTokenizer
+import gradio as gr
+import mdtex2html
+
+tokenizer = AutoTokenizer.from_pretrained("THUDM/visualglm-6b", trust_remote_code=True)
+model = AutoModel.from_pretrained("THUDM/visualglm-6b", trust_remote_code=True).half().cuda()
+model = model.eval()
+
+"""Override Chatbot.postprocess"""
+
+
+def postprocess(self, y):
+    if y is None:
+        return []
+    for i, (message, response) in enumerate(y):
+        y[i] = (
+            None if message is None else mdtex2html.convert((message)),
+            None if response is None else mdtex2html.convert(response),
+        )
+    return y
+
+
+gr.Chatbot.postprocess = postprocess
+
+
+def parse_text(text):
+    """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
+    lines = text.split("\n")
+    lines = [line for line in lines if line != ""]
+    count = 0
+    for i, line in enumerate(lines):
+        if "```" in line:
+            count += 1
+            items = line.split('`')
+            if count % 2 == 1:
+                lines[i] = f'<pre><code class="language-{items[-1]}">'
+            else:
+                lines[i] = f'<br></code></pre>'
+        else:
+            if i > 0:
+                if count % 2 == 1:
+                    line = line.replace("`", "\`")
+                    line = line.replace("<", "&lt;")
+                    line = line.replace(">", "&gt;")
+                    line = line.replace(" ", "&nbsp;")
+                    line = line.replace("*", "&ast;")
+                    line = line.replace("_", "&lowbar;")
+                    line = line.replace("-", "&#45;")
+                    line = line.replace(".", "&#46;")
+                    line = line.replace("!", "&#33;")
+                    line = line.replace("(", "&#40;")
+                    line = line.replace(")", "&#41;")
+                    line = line.replace("$", "&#36;")
+                lines[i] = "<br>"+line
+    text = "".join(lines)
+    return text
+
+
+def predict(input, image_path, chatbot, max_length, top_p, temperature, history):
+    if image_path is None:
+        return [(input, "图片为空!请重新上传图片并重试。")]
+    chatbot.append((parse_text(input), ""))
+    for response, history in model.stream_chat(tokenizer, image_path, input, history, max_length=max_length, top_p=top_p,
+                                               temperature=temperature):
+        chatbot[-1] = (parse_text(input), parse_text(response))
+
+        yield chatbot, history
+
+
+def 
predict_new_image(image_path, chatbot, max_length, top_p, temperature): + input, history = "描述这张图片。", [] + chatbot.append((parse_text(input), "")) + for response, history in model.stream_chat(tokenizer, image_path, input, history, max_length=max_length, + top_p=top_p, + temperature=temperature): + chatbot[-1] = (parse_text(input), parse_text(response)) + + yield chatbot, history + + +def reset_user_input(): + return gr.update(value='') + + +def reset_state(): + return None, [], [] + + +with gr.Blocks() as demo: + gr.HTML("""<h1 align="center">VisualGLM</h1>""") + + image_path = gr.Image(type="filepath", label="Image Prompt", value=None) + chatbot = gr.Chatbot() + with gr.Row(): + with gr.Column(scale=4): + with gr.Column(scale=12): + user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style( + container=False) + with gr.Column(min_width=32, scale=1): + submitBtn = gr.Button("Submit", variant="primary") + with gr.Column(scale=1): + emptyBtn = gr.Button("Clear History") + max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True) + top_p = gr.Slider(0, 1, value=0.4, step=0.01, label="Top P", interactive=True) + temperature = gr.Slider(0, 1, value=0.8, step=0.01, label="Temperature", interactive=True) + + history = gr.State([]) + + submitBtn.click(predict, [user_input, image_path, chatbot, max_length, top_p, temperature, history], [chatbot, history], + show_progress=True) + + image_path.upload(predict_new_image, [image_path, chatbot, max_length, top_p, temperature], [chatbot, history], + show_progress=True) + image_path.clear(reset_state, outputs=[image_path, chatbot, history], show_progress=True) + + submitBtn.click(reset_user_input, [], [user_input]) + + emptyBtn.click(reset_state, outputs=[image_path, chatbot, history], show_progress=True) + +demo.queue().launch(share=False, inbrowser=True)
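
A minimal usage sketch for the multi-GPU helper defined in `utils.py` above, assuming the `THUDM/chatglm-6b` checkpoint and at least two visible CUDA devices; the checkpoint name and GPU count should be adapted to the local setup:

```python
from transformers import AutoTokenizer

from utils import load_model_on_gpus

# Load the tokenizer the same way the single-GPU demos do.
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)

# Split the model across 2 GPUs. auto_configure_device_map keeps
# transformer.word_embeddings, transformer.final_layernorm and lm_head on GPU 0
# and spreads the 28 transformer layers over the available capacity.
model = load_model_on_gpus("THUDM/chatglm-6b", num_gpus=2)
model = model.eval()

response, history = model.chat(tokenizer, "你好", history=[])
print(response)
```

As the comments in `auto_configure_device_map` note, the embedding, final layernorm and `lm_head` stay on GPU 0 so that `input_ids` and these weights end up on the same device during `chat`/`stream_chat`.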