HaotianHu committed on
Commit cd32184 · 1 Parent(s): b370fd5

extra setup so llamafactory can run

Files changed (50)
  1. .gitignore +1 -0
  2. llama-factory/README.md +645 -0
  3. llama-factory/pyproject.toml +33 -0
  4. llama-factory/requirements.txt +21 -0
  5. llama-factory/setup.py +92 -0
  6. llama-factory/src/api.py +33 -0
  7. llama-factory/src/llamafactory/__init__.py +41 -0
  8. llama-factory/src/llamafactory/api/__init__.py +0 -0
  9. llama-factory/src/llamafactory/api/app.py +122 -0
  10. llama-factory/src/llamafactory/api/chat.py +237 -0
  11. llama-factory/src/llamafactory/api/common.py +34 -0
  12. llama-factory/src/llamafactory/api/protocol.py +153 -0
  13. llama-factory/src/llamafactory/chat/__init__.py +19 -0
  14. llama-factory/src/llamafactory/chat/base_engine.py +78 -0
  15. llama-factory/src/llamafactory/chat/chat_model.py +155 -0
  16. llama-factory/src/llamafactory/chat/hf_engine.py +343 -0
  17. llama-factory/src/llamafactory/chat/vllm_engine.py +242 -0
  18. llama-factory/src/llamafactory/cli.py +121 -0
  19. llama-factory/src/llamafactory/data/__init__.py +31 -0
  20. llama-factory/src/llamafactory/data/aligner.py +239 -0
  21. llama-factory/src/llamafactory/data/collator.py +155 -0
  22. llama-factory/src/llamafactory/data/data_utils.py +87 -0
  23. llama-factory/src/llamafactory/data/formatter.py +140 -0
  24. llama-factory/src/llamafactory/data/loader.py +276 -0
  25. llama-factory/src/llamafactory/data/parser.py +153 -0
  26. llama-factory/src/llamafactory/data/preprocess.py +110 -0
  27. llama-factory/src/llamafactory/data/processors/__init__.py +0 -0
  28. llama-factory/src/llamafactory/data/processors/feedback.py +143 -0
  29. llama-factory/src/llamafactory/data/processors/pairwise.py +139 -0
  30. llama-factory/src/llamafactory/data/processors/pretrain.py +54 -0
  31. llama-factory/src/llamafactory/data/processors/processor_utils.py +95 -0
  32. llama-factory/src/llamafactory/data/processors/supervised.py +202 -0
  33. llama-factory/src/llamafactory/data/processors/unsupervised.py +106 -0
  34. llama-factory/src/llamafactory/data/template.py +905 -0
  35. llama-factory/src/llamafactory/data/tool_utils.py +140 -0
  36. llama-factory/src/llamafactory/eval/__init__.py +0 -0
  37. llama-factory/src/llamafactory/eval/evaluator.py +154 -0
  38. llama-factory/src/llamafactory/eval/template.py +81 -0
  39. llama-factory/src/llamafactory/extras/__init__.py +0 -0
  40. llama-factory/src/llamafactory/extras/constants.py +1590 -0
  41. llama-factory/src/llamafactory/extras/env.py +75 -0
  42. llama-factory/src/llamafactory/extras/logging.py +82 -0
  43. llama-factory/src/llamafactory/extras/misc.py +228 -0
  44. llama-factory/src/llamafactory/extras/packages.py +88 -0
  45. llama-factory/src/llamafactory/extras/ploting.py +101 -0
  46. llama-factory/src/llamafactory/hparams/__init__.py +32 -0
  47. llama-factory/src/llamafactory/hparams/data_args.py +143 -0
  48. llama-factory/src/llamafactory/hparams/evaluation_args.py +62 -0
  49. llama-factory/src/llamafactory/hparams/finetuning_args.py +400 -0
  50. llama-factory/src/llamafactory/hparams/generating_args.py +74 -0
.gitignore CHANGED
@@ -150,3 +150,4 @@ dmypy.json
  /huggingface_tokenizers_cache
  /llama-factory/huggingface_tokenizers_cache
  **/Icon?
+ llama-factory/data/mgtv_train.json
llama-factory/README.md ADDED
@@ -0,0 +1,645 @@
+ ![# LLaMA Factory](assets/logo.png)
+
+ [![GitHub Repo stars](https://img.shields.io/github/stars/hiyouga/LLaMA-Factory?style=social)](https://github.com/hiyouga/LLaMA-Factory/stargazers)
+ [![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE)
+ [![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
+ [![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/)
+ [![Citation](https://img.shields.io/badge/citation-72-green)](#projects-using-llama-factory)
+ [![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
+ [![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
+ [![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
+ [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)
+ [![Open in DSW](https://gallery.pai-ml.com/assets/open-in-dsw.svg)](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
+ [![Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
+ [![Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
+
+ [![GitHub Trend](https://trendshift.io/api/badge/repositories/4535)](https://trendshift.io/repositories/4535)
+
+ 👋 Join our [WeChat](assets/wechat.jpg) or [NPU user group](assets/wechat_npu.jpg).
+
+ \[ English | [中文](README_zh.md) \]
+
+ **Fine-tuning a large language model can be as easy as...**
+
+ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/9840a653-7e9c-41c8-ae89-7ace5698baf6
+
+ Choose your path:
+
+ - **Colab**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
+ - **PAI-DSW**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
+ - **Local machine**: Please refer to [usage](#getting-started)
+
+ ## Table of Contents
+
+ - [Features](#features)
+ - [Benchmark](#benchmark)
+ - [Changelog](#changelog)
+ - [Supported Models](#supported-models)
+ - [Supported Training Approaches](#supported-training-approaches)
+ - [Provided Datasets](#provided-datasets)
+ - [Requirement](#requirement)
+ - [Getting Started](#getting-started)
+ - [Projects using LLaMA Factory](#projects-using-llama-factory)
+ - [License](#license)
+ - [Citation](#citation)
+ - [Acknowledgement](#acknowledgement)
+
+ ## Features
+
+ - **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
+ - **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
+ - **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ.
+ - **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning.
+ - **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
+ - **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
+ - **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.
+
+ ## Benchmark
+
+ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ptuning), LLaMA Factory's LoRA tuning offers up to **3.7 times faster** training speed with a better Rouge score on the advertising text generation task. By leveraging the 4-bit quantization technique, LLaMA Factory's QLoRA further improves GPU memory efficiency.
+
+ ![benchmark](assets/benchmark.svg)
+
+ <details><summary>Definitions</summary>
+
+ - **Training Speed**: the number of training samples processed per second during training. (bs=4, cutoff_len=1024)
+ - **Rouge Score**: Rouge-2 score on the development set of the [advertising text generation](https://aclanthology.org/D19-1321.pdf) task. (bs=4, cutoff_len=1024)
+ - **GPU Memory**: peak GPU memory usage in 4-bit quantized training. (bs=1, cutoff_len=1024)
+ - We adopt `pre_seq_len=128` for ChatGLM's P-Tuning and `lora_rank=32` for LLaMA Factory's LoRA tuning.
+
+ </details>
+
+ ## Changelog
+
+ [24/06/16] We supported the **[PiSSA](https://arxiv.org/abs/2404.02948)** algorithm. See [examples](examples/README.md) for usage.
+
+ [24/06/07] We supported fine-tuning the **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** and **[GLM-4](https://github.com/THUDM/GLM-4)** models.
+
+ [24/05/26] We supported the **[SimPO](https://arxiv.org/abs/2405.14734)** algorithm for preference learning. See [examples](examples/README.md) for usage.
+
+ <details><summary>Full Changelog</summary>
+
+ [24/05/20] We supported fine-tuning the **PaliGemma** series models. Note that the PaliGemma models are pre-trained models; you need to fine-tune them with the `gemma` template for chat completion.
+
+ [24/05/18] We supported the **[KTO](https://arxiv.org/abs/2402.01306)** algorithm for preference learning. See [examples](examples/README.md) for usage.
+
+ [24/05/14] We supported training and inference on Ascend NPU devices. Check the [installation](#installation) section for details.
+
+ [24/04/26] We supported fine-tuning the **LLaVA-1.5** multimodal LLMs. See [examples](examples/README.md) for usage.
+
+ [24/04/22] We provided a **[Colab notebook](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)** for fine-tuning the Llama-3 model on a free T4 GPU. Two Llama-3-derived models fine-tuned using LLaMA Factory are available on Hugging Face; check [Llama3-8B-Chinese-Chat](https://huggingface.co/shenzhi-wang/Llama3-8B-Chinese-Chat) and [Llama3-Chinese](https://huggingface.co/zhichen/Llama3-Chinese) for details.
+
+ [24/04/21] We supported **[Mixture-of-Depths](https://arxiv.org/abs/2404.02258)** according to [AstraMindAI's implementation](https://github.com/astramind-ai/Mixture-of-depths). See [examples](examples/README.md) for usage.
+
+ [24/04/16] We supported **[BAdam](https://arxiv.org/abs/2404.02827)**. See [examples](examples/README.md) for usage.
+
+ [24/04/16] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s long-sequence training (Llama-2-7B-56k within 24GB). It achieves **117%** speed and **50%** memory compared with FlashAttention-2; more benchmarks can be found on [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison).
+
+ [24/03/31] We supported **[ORPO](https://arxiv.org/abs/2403.07691)**. See [examples](examples/README.md) for usage.
+
+ [24/03/21] Our paper "[LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models](https://arxiv.org/abs/2403.13372)" is available on arXiv!
+
+ [24/03/20] We supported **FSDP+QLoRA**, which fine-tunes a 70B model on 2x24GB GPUs. See [examples](examples/README.md) for usage.
+
+ [24/03/13] We supported **[LoRA+](https://arxiv.org/abs/2402.12354)**. See [examples](examples/README.md) for usage.
+
+ [24/03/07] We supported the gradient low-rank projection (**[GaLore](https://arxiv.org/abs/2403.03507)**) algorithm. See [examples](examples/README.md) for usage.
+
+ [24/03/07] We integrated **[vLLM](https://github.com/vllm-project/vllm)** for faster and concurrent inference. Try `infer_backend: vllm` to enjoy **270%** inference speed.
+
+ [24/02/28] We supported weight-decomposed LoRA (**[DoRA](https://arxiv.org/abs/2402.09353)**). Try `use_dora: true` to activate DoRA training.
+
+ [24/02/15] We supported **block expansion** proposed by [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro). See [examples](examples/README.md) for usage.
+
+ [24/02/05] Qwen1.5 (Qwen2 beta version) series models are supported in LLaMA-Factory. Check this [blog post](https://qwenlm.github.io/blog/qwen1.5/) for details.
+
+ [24/01/18] We supported **agent tuning** for most models, equipping models with tool-using abilities by fine-tuning with `dataset: glaive_toolcall_en`.
+
+ [23/12/23] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s implementation to boost LoRA tuning for the LLaMA, Mistral and Yi models. Try the `use_unsloth: true` argument to activate the unsloth patch. It achieves **170%** speed in our benchmark; check [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison) for details.
+
+ [23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework. See the hardware requirements [here](#hardware-requirement).
+
+ [23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)** for Chinese mainland users. See [this tutorial](#download-from-modelscope-hub) for usage.
+
+ [23/10/21] We supported the **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try the `neftune_noise_alpha: 5` argument to activate NEFTune.
+
+ [23/09/27] We supported **$S^2$-Attn** proposed by [LongLoRA](https://github.com/dvlab-research/LongLoRA) for the LLaMA models. Try the `shift_attn: true` argument to enable shift short attention.
+
+ [23/09/23] We integrated the MMLU, C-Eval and CMMLU benchmarks in this repo. See [examples](examples/README.md) for usage.
+
+ [23/09/10] We supported **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**. Try the `flash_attn: fa2` argument to enable FlashAttention-2 if you are using RTX 4090, A100 or H100 GPUs.
+
+ [23/08/12] We supported **RoPE scaling** to extend the context length of the LLaMA models. Try the `rope_scaling: linear` argument in training and the `rope_scaling: dynamic` argument at inference to extrapolate the position embeddings.
+
+ [23/08/11] We supported **[DPO training](https://arxiv.org/abs/2305.18290)** for instruction-tuned models. See [examples](examples/README.md) for usage.
+
+ [23/07/31] We supported **dataset streaming**. Try the `streaming: true` and `max_steps: 10000` arguments to load your dataset in streaming mode.
+
+ [23/07/29] We released two instruction-tuned 13B models on Hugging Face. See these Hugging Face repos ([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/Baichuan-13B-sft)) for details.
+
+ [23/07/18] We developed an **all-in-one Web UI** for training, evaluation and inference. Try `train_web.py` to fine-tune models in your Web browser. Thanks to [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development.
+
+ [23/07/09] We released **[FastEdit](https://github.com/hiyouga/FastEdit)** ⚡🩹, an easy-to-use package for efficiently editing the factual knowledge of large language models. Please follow [FastEdit](https://github.com/hiyouga/FastEdit) if you are interested.
+
+ [23/06/29] We provided a **reproducible example** of training a chat model using instruction-following datasets; see [Baichuan-7B-sft](https://huggingface.co/hiyouga/Baichuan-7B-sft) for details.
+
+ [23/06/22] We aligned the [demo API](src/api_demo.py) with the [OpenAI](https://platform.openai.com/docs/api-reference/chat) format so you can use the fine-tuned model in **arbitrary ChatGPT-based applications**.
+
+ [23/06/03] We supported quantized training and inference (aka **[QLoRA](https://github.com/artidoro/qlora)**). See [examples](examples/README.md) for usage.
+
+ </details>
+
+ ## Supported Models
+
+ | Model                                                        | Model size                       | Template  |
+ | ------------------------------------------------------------ | -------------------------------- | --------- |
+ | [Baichuan 2](https://huggingface.co/baichuan-inc)            | 7B/13B                           | baichuan2 |
+ | [BLOOM/BLOOMZ](https://huggingface.co/bigscience)            | 560M/1.1B/1.7B/3B/7.1B/176B      | -         |
+ | [ChatGLM3](https://huggingface.co/THUDM)                     | 6B                               | chatglm3  |
+ | [Command R](https://huggingface.co/CohereForAI)              | 35B/104B                         | cohere    |
+ | [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai)    | 7B/16B/67B/236B                  | deepseek  |
+ | [Falcon](https://huggingface.co/tiiuae)                      | 7B/11B/40B/180B                  | falcon    |
+ | [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google)     | 2B/7B/9B/27B                     | gemma     |
+ | [GLM-4](https://huggingface.co/THUDM)                        | 9B                               | glm4      |
+ | [InternLM2](https://huggingface.co/internlm)                 | 7B/20B                           | intern2   |
+ | [Llama](https://github.com/facebookresearch/llama)           | 7B/13B/33B/65B                   | -         |
+ | [Llama 2](https://huggingface.co/meta-llama)                 | 7B/13B/70B                       | llama2    |
+ | [Llama 3](https://huggingface.co/meta-llama)                 | 8B/70B                           | llama3    |
+ | [LLaVA-1.5](https://huggingface.co/llava-hf)                 | 7B/13B                           | vicuna    |
+ | [Mistral/Mixtral](https://huggingface.co/mistralai)          | 7B/8x7B/8x22B                    | mistral   |
+ | [OLMo](https://huggingface.co/allenai)                       | 1B/7B                            | -         |
+ | [PaliGemma](https://huggingface.co/google)                   | 3B                               | gemma     |
+ | [Phi-1.5/Phi-2](https://huggingface.co/microsoft)            | 1.3B/2.7B                        | -         |
+ | [Phi-3](https://huggingface.co/microsoft)                    | 4B/7B/14B                        | phi       |
+ | [Qwen/Qwen1.5/Qwen2 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/4B/7B/14B/32B/72B/110B | qwen      |
+ | [StarCoder 2](https://huggingface.co/bigcode)                | 3B/7B/15B                        | -         |
+ | [XVERSE](https://huggingface.co/xverse)                      | 7B/13B/65B                       | xverse    |
+ | [Yi/Yi-1.5](https://huggingface.co/01-ai)                    | 6B/9B/34B                        | yi        |
+ | [Yi-VL](https://huggingface.co/01-ai)                        | 6B/34B                           | yi_vl     |
+ | [Yuan 2](https://huggingface.co/IEITYuan)                    | 2B/51B/102B                      | yuan      |
+
+ > [!NOTE]
+ > For the "base" models, the `template` argument can be chosen from `default`, `alpaca`, `vicuna`, etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
+ >
+ > Remember to use the **SAME** template in training and inference.
+
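+ For example, a Llama-3-Instruct checkpoint pairs with the `llama3` template from the table above in every stage; a minimal YAML sketch (the model path is illustrative):
+
+ ```yaml
+ # shared by both the training and the inference configs
+ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
+ template: llama3
+ ```
+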
+ Please refer to [constants.py](src/llamafactory/extras/constants.py) for a full list of the models we support.
+
+ You can also add a custom chat template to [template.py](src/llamafactory/data/template.py).
+
+ ## Supported Training Approaches
+
+ | Approach               | Full-tuning        | Freeze-tuning      | LoRA               | QLoRA              |
+ | ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
+ | Pre-Training           | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+ | Supervised Fine-Tuning | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+ | Reward Modeling        | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+ | PPO Training           | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+ | DPO Training           | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+ | KTO Training           | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+ | ORPO Training          | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+ | SimPO Training         | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+
+ ## Provided Datasets
+
+ <details><summary>Pre-training datasets</summary>
+
+ - [Wiki Demo (en)](data/wiki_demo.txt)
+ - [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
+ - [RedPajama V2 (en)](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-V2)
+ - [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220)
+ - [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
+ - [Pile (en)](https://huggingface.co/datasets/EleutherAI/pile)
+ - [SkyPile (zh)](https://huggingface.co/datasets/Skywork/SkyPile-150B)
+ - [FineWeb (en)](https://huggingface.co/datasets/HuggingFaceFW/fineweb)
+ - [FineWeb-Edu (en)](https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu)
+ - [The Stack (en)](https://huggingface.co/datasets/bigcode/the-stack)
+ - [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
+
+ </details>
+
+ <details><summary>Supervised fine-tuning datasets</summary>
+
+ - [Identity (en&zh)](data/identity.json)
+ - [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
+ - [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca-3)
+ - [Alpaca GPT4 (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
+ - [Glaive Function Calling V2 (en&zh)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
+ - [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
+ - [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
+ - [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
+ - [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
+ - [BELLE 0.5M (zh)](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
+ - [BELLE Dialogue 0.4M (zh)](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
+ - [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
+ - [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
+ - [UltraChat (en)](https://github.com/thunlp/UltraChat)
+ - [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus)
+ - [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
+ - [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
+ - [OpenOrca (en)](https://huggingface.co/datasets/Open-Orca/OpenOrca)
+ - [SlimOrca (en)](https://huggingface.co/datasets/Open-Orca/SlimOrca)
+ - [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
+ - [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
+ - [Wiki QA (en)](https://huggingface.co/datasets/wiki_qa)
+ - [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
+ - [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
+ - [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
+ - [deepctrl (en&zh)](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data)
+ - [Advertise Generating (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
+ - [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
+ - [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
+ - [UltraChat 200k (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)
+ - [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
+ - [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
+ - [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
+ - [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
+ - [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction)
+ - [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo)
+ - [Neo-sft (zh)](https://huggingface.co/datasets/m-a-p/neo_sft_phase2)
+ - [WebInstructSub (en)](https://huggingface.co/datasets/TIGER-Lab/WebInstructSub)
+ - [Magpie-Pro-300K-Filtered (en)](https://huggingface.co/datasets/Magpie-Align/Magpie-Pro-300K-Filtered)
+ - [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
+ - [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
+ - [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
+ - [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)
+ - [OpenSchnabeltier (de)](https://huggingface.co/datasets/mayflowergmbh/openschnabeltier_de)
+ - [Evol Instruct (de)](https://huggingface.co/datasets/mayflowergmbh/evol-instruct_de)
+ - [Dolphin (de)](https://huggingface.co/datasets/mayflowergmbh/dolphin_de)
+ - [Booksum (de)](https://huggingface.co/datasets/mayflowergmbh/booksum_de)
+ - [Airoboros (de)](https://huggingface.co/datasets/mayflowergmbh/airoboros-3.0_de)
+ - [Ultrachat (de)](https://huggingface.co/datasets/mayflowergmbh/ultra-chat_de)
+
+ </details>
+
+ <details><summary>Preference datasets</summary>
+
+ - [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
+ - [UltraFeedback (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)
+ - [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
+ - [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
+ - [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
+ - [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)
+ - [KTO mixed (en)](https://huggingface.co/datasets/argilla/kto-mix-15k)
+
+ </details>
+
+ Some datasets require confirmation before using them, so we recommend logging in with your Hugging Face account using these commands.
+
+ ```bash
+ pip install --upgrade huggingface_hub
+ huggingface-cli login
+ ```
+
+ ## Requirement
+
+ | Mandatory    | Minimum | Recommended |
+ | ------------ | ------- | ----------- |
+ | python       | 3.8     | 3.11        |
+ | torch        | 1.13.1  | 2.3.0       |
+ | transformers | 4.41.2  | 4.41.2      |
+ | datasets     | 2.16.0  | 2.19.2      |
+ | accelerate   | 0.30.1  | 0.30.1      |
+ | peft         | 0.11.1  | 0.11.1      |
+ | trl          | 0.8.6   | 0.9.4       |
+
+ | Optional     | Minimum | Recommended |
+ | ------------ | ------- | ----------- |
+ | CUDA         | 11.6    | 12.2        |
+ | deepspeed    | 0.10.0  | 0.14.0      |
+ | bitsandbytes | 0.39.0  | 0.43.1      |
+ | vllm         | 0.4.3   | 0.4.3       |
+ | flash-attn   | 2.3.0   | 2.5.9       |
+
+ ### Hardware Requirement
+
+ \* *estimated*
+
+ | Method            | Bits |   7B  |  13B  |  30B  |   70B  |  110B  |  8x7B |  8x22B |
+ | ----------------- | ---- | ----- | ----- | ----- | ------ | ------ | ----- | ------ |
+ | Full              | AMP  | 120GB | 240GB | 600GB | 1200GB | 2000GB | 900GB | 2400GB |
+ | Full              | 16   |  60GB | 120GB | 300GB |  600GB |  900GB | 400GB | 1200GB |
+ | Freeze            | 16   |  20GB |  40GB |  80GB |  200GB |  360GB | 160GB |  400GB |
+ | LoRA/GaLore/BAdam | 16   |  16GB |  32GB |  64GB |  160GB |  240GB | 120GB |  320GB |
+ | QLoRA             | 8    |  10GB |  20GB |  40GB |   80GB |  140GB |  60GB |  160GB |
+ | QLoRA             | 4    |   6GB |  12GB |  24GB |   48GB |   72GB |  30GB |   96GB |
+ | QLoRA             | 2    |   4GB |   8GB |  16GB |   24GB |   48GB |  18GB |   48GB |
+
+ ## Getting Started
+
+ ### Installation
+
+ > [!IMPORTANT]
+ > Installation is mandatory.
+
+ ```bash
+ git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
+ cd LLaMA-Factory
+ pip install -e ".[torch,metrics]"
+ ```
+
+ Extra dependencies available: torch, torch-npu, metrics, deepspeed, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, qwen, modelscope, quality
+
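+ For example, to also pull in the DeepSpeed and bitsandbytes extras (a sketch; pick whichever extras from the list above you actually need):
+
+ ```bash
+ pip install -e ".[torch,metrics,deepspeed,bitsandbytes]"
+ ```
+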
+ > [!TIP]
+ > Use `pip install --no-deps -e .` to resolve package conflicts.
+
+ <details><summary>For Windows users</summary>
+
+ If you want to enable quantized LoRA (QLoRA) on the Windows platform, you need to install a pre-built version of the `bitsandbytes` library, which supports CUDA 11.1 to 12.2. Please select the appropriate [release version](https://github.com/jllllll/bitsandbytes-windows-webui/releases/tag/wheels) based on your CUDA version.
+
+ ```bash
+ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.2.post2-py3-none-win_amd64.whl
+ ```
+
+ To enable FlashAttention-2 on the Windows platform, you need to install the precompiled `flash-attn` library, which supports CUDA 12.1 to 12.2. Please download the corresponding version from [flash-attention](https://github.com/bdashore3/flash-attention/releases) based on your requirements.
+
+ </details>
+
+
+ <details><summary>For Ascend NPU users</summary>
+
+ To install LLaMA Factory on Ascend NPU devices, please specify extra dependencies: `pip install -e ".[torch-npu,metrics]"`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands:
+
+ ```bash
+ # replace the url according to your CANN version and devices
+ # install CANN Toolkit
+ wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run
+ bash Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run --install
+
+ # install CANN Kernels
+ wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run
+ bash Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run --install
+
+ # set env variables
+ source /usr/local/Ascend/ascend-toolkit/set_env.sh
+ ```
+
+ | Requirement  | Minimum | Recommended |
+ | ------------ | ------- | ----------- |
+ | CANN         | 8.0.RC1 | 8.0.RC1     |
+ | torch        | 2.1.0   | 2.1.0       |
+ | torch-npu    | 2.1.0   | 2.1.0.post3 |
+ | deepspeed    | 0.13.2  | 0.13.2      |
+
+ Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use.
+
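+ For example (a sketch reusing the quickstart config; the device index is illustrative):
+
+ ```bash
+ ASCEND_RT_VISIBLE_DEVICES=0 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+ ```
+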
+ If you cannot run inference on NPU devices, try setting `do_sample: false` in the configurations.
+
+ Download the pre-built Docker images: [32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
+
+ </details>
+
+ ### Data Preparation
+
+ Please refer to [data/README.md](data/README.md) for details about the format of the dataset files. You can either use datasets from the HuggingFace / ModelScope hub or load datasets from your local disk.
+
+ > [!NOTE]
+ > Please update `data/dataset_info.json` to use your custom dataset. An example entry is sketched below.
+
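+ As a hedged illustration, a local alpaca-style file could be registered with an entry like the following (the dataset name, file name and column mapping are placeholders; see [data/README.md](data/README.md) for the authoritative schema):
+
+ ```json
+ {
+   "my_dataset": {
+     "file_name": "my_dataset.json",
+     "columns": {
+       "prompt": "instruction",
+       "query": "input",
+       "response": "output"
+     }
+   }
+ }
+ ```
+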
+ ### Quickstart
+
+ Use the following 3 commands to run LoRA **fine-tuning**, **inference** and **merging** of the Llama3-8B-Instruct model, respectively.
+
+ ```bash
+ llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+ llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
+ llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
+ ```
+
+ See [examples/README.md](examples/README.md) for advanced usage (including distributed training).
+
+ > [!TIP]
+ > Use `llamafactory-cli help` to show help information.
+
+ ### Fine-Tuning with LLaMA Board GUI (powered by [Gradio](https://github.com/gradio-app/gradio))
+
+ ```bash
+ llamafactory-cli webui
+ ```
+
+ ### Build Docker
+
+ For CUDA users:
+
+ ```bash
+ cd docker/docker-cuda/
+ docker-compose up -d
+ docker-compose exec llamafactory bash
+ ```
+
+ For Ascend NPU users:
+
+ ```bash
+ cd docker/docker-npu/
+ docker-compose up -d
+ docker-compose exec llamafactory bash
+ ```
+
+ <details><summary>Build without Docker Compose</summary>
+
+ For CUDA users:
+
+ ```bash
+ docker build -f ./docker/docker-cuda/Dockerfile \
+     --build-arg INSTALL_BNB=false \
+     --build-arg INSTALL_VLLM=false \
+     --build-arg INSTALL_DEEPSPEED=false \
+     --build-arg INSTALL_FLASHATTN=false \
+     --build-arg PIP_INDEX=https://pypi.org/simple \
+     -t llamafactory:latest .
+
+ docker run -dit --gpus=all \
+     -v ./hf_cache:/root/.cache/huggingface \
+     -v ./ms_cache:/root/.cache/modelscope \
+     -v ./data:/app/data \
+     -v ./output:/app/output \
+     -p 7860:7860 \
+     -p 8000:8000 \
+     --shm-size 16G \
+     --name llamafactory \
+     llamafactory:latest
+
+ docker exec -it llamafactory bash
+ ```
+
+ For Ascend NPU users:
+
+ ```bash
+ # Choose the docker image according to your environment
+ docker build -f ./docker/docker-npu/Dockerfile \
+     --build-arg INSTALL_DEEPSPEED=false \
+     --build-arg PIP_INDEX=https://pypi.org/simple \
+     -t llamafactory:latest .
+
+ # Change `device` according to your resources
+ docker run -dit \
+     -v ./hf_cache:/root/.cache/huggingface \
+     -v ./ms_cache:/root/.cache/modelscope \
+     -v ./data:/app/data \
+     -v ./output:/app/output \
+     -v /usr/local/dcmi:/usr/local/dcmi \
+     -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
+     -v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
+     -v /etc/ascend_install.info:/etc/ascend_install.info \
+     -p 7860:7860 \
+     -p 8000:8000 \
+     --device /dev/davinci0 \
+     --device /dev/davinci_manager \
+     --device /dev/devmm_svm \
+     --device /dev/hisi_hdc \
+     --shm-size 16G \
+     --name llamafactory \
+     llamafactory:latest
+
+ docker exec -it llamafactory bash
+ ```
+
+ </details>
+
+ <details><summary>Details about volumes</summary>
+
+ - hf_cache: Utilize the Hugging Face cache on the host machine. Reassignable if a cache already exists in a different directory.
+ - data: Place datasets in this directory on the host machine so that they can be selected in the LLaMA Board GUI.
+ - output: Set the export dir to this location so that the merged result can be accessed directly on the host machine.
+
+ </details>
+
+ ### Deploy with OpenAI-style API and vLLM
+
+ ```bash
+ API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
+ ```
+
+ > [!TIP]
+ > Visit https://platform.openai.com/docs/api-reference/chat/create for the API documentation.
+
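+ Once the server is up, any OpenAI-compatible client can talk to it. A minimal `curl` sketch, assuming the standard OpenAI chat-completions path (the model name and port are illustrative):
+
+ ```bash
+ curl http://localhost:8000/v1/chat/completions \
+   -H "Content-Type: application/json" \
+   -d '{"model": "llama3", "messages": [{"role": "user", "content": "Hello!"}]}'
+ ```
+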
+ ### Download from ModelScope Hub
+
+ If you have trouble downloading models and datasets from Hugging Face, you can use ModelScope.
+
+ ```bash
+ export USE_MODELSCOPE_HUB=1 # `set USE_MODELSCOPE_HUB=1` for Windows
+ ```
+
+ Train the model by specifying a model ID from the ModelScope Hub as the `model_name_or_path`. You can find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models), e.g., `LLM-Research/Meta-Llama-3-8B-Instruct`.
+
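+ For example, in a training YAML (a sketch reusing the model ID above):
+
+ ```yaml
+ model_name_or_path: LLM-Research/Meta-Llama-3-8B-Instruct
+ ```
+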
+ ### Use W&B Logger
+
+ To use [Weights & Biases](https://wandb.ai) for logging experimental results, you need to add the following arguments to the YAML files.
+
+ ```yaml
+ report_to: wandb
+ run_name: test_run # optional
+ ```
+
+ Set `WANDB_API_KEY` to [your key](https://wandb.ai/authorize) when launching training tasks to log in with your W&B account.
+
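+ For example (the key value is a placeholder):
+
+ ```bash
+ WANDB_API_KEY=your_api_key llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+ ```
+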
+ ## Projects using LLaMA Factory
+
+ If you have a project that should be incorporated, please contact us via email or create a pull request.
+
+ <details><summary>Click to show</summary>
+
+ 1. Wang et al. ESRL: Efficient Sampling-based Reinforcement Learning for Sequence Generation. 2023. [[arxiv]](https://arxiv.org/abs/2308.02223)
+ 1. Yu et al. Open, Closed, or Small Language Models for Text Classification? 2023. [[arxiv]](https://arxiv.org/abs/2308.10092)
+ 1. Wang et al. UbiPhysio: Support Daily Functioning, Fitness, and Rehabilitation with Action Understanding and Feedback in Natural Language. 2023. [[arxiv]](https://arxiv.org/abs/2308.10526)
+ 1. Luceri et al. Leveraging Large Language Models to Detect Influence Campaigns in Social Media. 2023. [[arxiv]](https://arxiv.org/abs/2311.07816)
+ 1. Zhang et al. Alleviating Hallucinations of Large Language Models through Induced Hallucinations. 2023. [[arxiv]](https://arxiv.org/abs/2312.15710)
+ 1. Wang et al. Know Your Needs Better: Towards Structured Understanding of Marketer Demands with Analogical Reasoning Augmented LLMs. KDD 2024. [[arxiv]](https://arxiv.org/abs/2401.04319)
+ 1. Wang et al. CANDLE: Iterative Conceptualization and Instantiation Distillation from Large Language Models for Commonsense Reasoning. ACL 2024. [[arxiv]](https://arxiv.org/abs/2401.07286)
+ 1. Choi et al. FACT-GPT: Fact-Checking Augmentation via Claim Matching with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2402.05904)
+ 1. Zhang et al. AutoMathText: Autonomous Data Selection with Language Models for Mathematical Texts. 2024. [[arxiv]](https://arxiv.org/abs/2402.07625)
+ 1. Lyu et al. KnowTuning: Knowledge-aware Fine-tuning for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11176)
+ 1. Yang et al. LaCo: Large Language Model Pruning via Layer Collapse. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187)
+ 1. Bhardwaj et al. Language Models are Homer Simpson! Safety Re-Alignment of Fine-tuned Language Models through Task Arithmetic. 2024. [[arxiv]](https://arxiv.org/abs/2402.11746)
+ 1. Yang et al. Enhancing Empathetic Response Generation by Augmenting LLMs with Small-scale Empathetic Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11801)
+ 1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2402.11809)
+ 1. Cao et al. Head-wise Shareable Attention for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11819)
+ 1. Zhang et al. Enhancing Multilingual Capabilities of Large Language Models through Self-Distillation from Resource-Rich Languages. 2024. [[arxiv]](https://arxiv.org/abs/2402.12204)
+ 1. Kim et al. Efficient and Effective Vocabulary Expansion Towards Multilingual Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.14714)
+ 1. Yu et al. KIEval: A Knowledge-grounded Interactive Evaluation Framework for Large Language Models. ACL 2024. [[arxiv]](https://arxiv.org/abs/2402.15043)
+ 1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333)
+ 1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419)
+ 1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228)
+ 1. Wu et al. Large Language Models are Parallel Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2403.09073)
+ 1. Zhang et al. EDT: Improving Large Language Models' Generation by Entropy-based Dynamic Temperature Sampling. 2024. [[arxiv]](https://arxiv.org/abs/2403.14541)
+ 1. Weller et al. FollowIR: Evaluating and Teaching Information Retrieval Models to Follow Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2403.15246)
+ 1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. COLING 2024. [[arxiv]](https://arxiv.org/abs/2403.16008)
+ 1. Zan et al. CodeS: Natural Language to Code Repository via Multi-Layer Sketch. 2024. [[arxiv]](https://arxiv.org/abs/2403.16443)
+ 1. Liu et al. Extensive Self-Contrast Enables Feedback-Free Language Model Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2404.00604)
+ 1. Luo et al. BAdam: A Memory Efficient Full Parameter Training Method for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.02827)
+ 1. Du et al. Chinese Tiny LLM: Pretraining a Chinese-Centric Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2404.04167)
+ 1. Ma et al. Parameter Efficient Quasi-Orthogonal Fine-Tuning via Givens Rotation. ICML 2024. [[arxiv]](https://arxiv.org/abs/2404.04316)
+ 1. Liu et al. Dynamic Generation of Personalities with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.07084)
+ 1. Shang et al. How Far Have We Gone in Stripped Binary Code Understanding Using Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.09836)
+ 1. Huang et al. LLMTune: Accelerate Database Knob Tuning with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.11581)
+ 1. Deng et al. Text-Tuple-Table: Towards Information Integration in Text-to-Table Generation via Global Tuple Extraction. 2024. [[arxiv]](https://arxiv.org/abs/2404.14215)
+ 1. Acikgoz et al. Hippocrates: An Open-Source Framework for Advancing Large Language Models in Healthcare. 2024. [[arxiv]](https://arxiv.org/abs/2404.16621)
+ 1. Zhang et al. Small Language Models Need Strong Verifiers to Self-Correct Reasoning. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2404.17140)
+ 1. Zhou et al. FREB-TQA: A Fine-Grained Robustness Evaluation Benchmark for Table Question Answering. NAACL 2024. [[arxiv]](https://arxiv.org/abs/2404.18585)
+ 1. Xu et al. Large Language Models for Cyber Security: A Systematic Literature Review. 2024. [[arxiv]](https://arxiv.org/abs/2405.04760)
+ 1. Dammu et al. "They are uncultured": Unveiling Covert Harms and Social Threats in LLM Generated Conversations. 2024. [[arxiv]](https://arxiv.org/abs/2405.05378)
+ 1. Yi et al. A safety realignment framework via subspace-oriented model fusion for large language models. 2024. [[arxiv]](https://arxiv.org/abs/2405.09055)
+ 1. Lou et al. SPO: Multi-Dimensional Preference Sequential Alignment With Implicit Reward Modeling. 2024. [[arxiv]](https://arxiv.org/abs/2405.12739)
+ 1. Zhang et al. Getting More from Less: Large Language Models are Good Spontaneous Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2405.13816)
+ 1. Zhang et al. TS-Align: A Teacher-Student Collaborative Framework for Scalable Iterative Finetuning of Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2405.20215)
+ 1. Zihong Chen. Sentence Segmentation and Sentence Punctuation Based on XunziALLM. 2024. [[paper]](https://aclanthology.org/2024.lt4hala-1.30)
+ 1. Gao et al. The Best of Both Worlds: Toward an Honest and Helpful Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2406.00380)
+ 1. Wang and Song. MARS: Benchmarking the Metaphysical Reasoning Abilities of Language Models with a Multi-task Evaluation Dataset. 2024. [[arxiv]](https://arxiv.org/abs/2406.02106)
+ 1. Hu et al. Computational Limits of Low-Rank Adaptation (LoRA) for Transformer-Based Models. 2024. [[arxiv]](https://arxiv.org/abs/2406.03136)
+ 1. Ge et al. Time Sensitive Knowledge Editing through Efficient Finetuning. ACL 2024. [[arxiv]](https://arxiv.org/abs/2406.04496)
+ 1. Tan et al. Peer Review as A Multi-Turn and Long-Context Dialogue with Role-Based Interactions. 2024. [[arxiv]](https://arxiv.org/abs/2406.05688)
+ 1. Song et al. Turbo Sparse: Achieving LLM SOTA Performance with Minimal Activated Parameters. 2024. [[arxiv]](https://arxiv.org/abs/2406.05955)
+ 1. Gu et al. RWKV-CLIP: A Robust Vision-Language Representation Learner. 2024. [[arxiv]](https://arxiv.org/abs/2406.06973)
+ 1. Chen et al. Advancing Tool-Augmented Large Language Models: Integrating Insights from Errors in Inference Trees. 2024. [[arxiv]](https://arxiv.org/abs/2406.07115)
+ 1. Zhu et al. Are Large Language Models Good Statisticians? 2024. [[arxiv]](https://arxiv.org/abs/2406.07815)
+ 1. Li et al. Know the Unknown: An Uncertainty-Sensitive Method for LLM Instruction Tuning. 2024. [[arxiv]](https://arxiv.org/abs/2406.10099)
+ 1. Ding et al. IntentionQA: A Benchmark for Evaluating Purchase Intention Comprehension Abilities of Language Models in E-commerce. 2024. [[arxiv]](https://arxiv.org/abs/2406.10173)
+ 1. He et al. COMMUNITY-CROSS-INSTRUCT: Unsupervised Instruction Generation for Aligning Large Language Models to Online Communities. 2024. [[arxiv]](https://arxiv.org/abs/2406.12074)
+ 1. Lin et al. FVEL: Interactive Formal Verification Environment with Large Language Models via Theorem Proving. 2024. [[arxiv]](https://arxiv.org/abs/2406.14408)
+ 1. Treutlein et al. Connecting the Dots: LLMs can Infer and Verbalize Latent Structure from Disparate Training Data. 2024. [[arxiv]](https://arxiv.org/abs/2406.14546)
+ 1. Feng et al. SS-Bench: A Benchmark for Social Story Generation and Evaluation. 2024. [[arxiv]](https://arxiv.org/abs/2406.15695)
+ 1. Feng et al. Self-Constructed Context Decompilation with Fine-grained Alignment Enhancement. 2024. [[arxiv]](https://arxiv.org/abs/2406.17233)
+ 1. Liu et al. Large Language Models for Cuffless Blood Pressure Measurement From Wearable Biosignals. 2024. [[arxiv]](https://arxiv.org/abs/2406.18069)
+ 1. Iyer et al. Exploring Very Low-Resource Translation with LLMs: The University of Edinburgh’s Submission to AmericasNLP 2024 Translation Task. AmericasNLP 2024. [[paper]](https://aclanthology.org/2024.americasnlp-1.25)
+ 1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B.
+ 1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in the Chinese legal domain, based on Baichuan-13B, capable of retrieving and reasoning over legal knowledge.
+ 1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: A large language model specialized in the Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
+ 1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for the Chinese medical domain, based on LLaMA2-7B and Baichuan-13B.
+ 1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI Personality large language models, capable of giving any LLM 16 different personality types based on different datasets and training methods.
+ 1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**: A large language model specialized in generating metadata for Stable Diffusion. [[🤗Demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
+ 1. **[Chinese-LLaVA-Med](https://github.com/BUAADreamer/Chinese-LLaVA-Med)**: A multimodal large language model specialized in the Chinese medical domain, based on LLaVA-1.5-7B.
+ 1. **[AutoRE](https://github.com/THUDM/AutoRE)**: A document-level relation extraction system based on large language models.
+ 1. **[NVIDIA RTX AI Toolkit](https://github.com/NVIDIA/RTX-AI-Toolkit)**: SDKs for fine-tuning LLMs on Windows PCs with NVIDIA RTX GPUs.
+ 1. **[LazyLLM](https://github.com/LazyAGI/LazyLLM)**: An easy and lazy way to build multi-agent LLM applications; it supports model fine-tuning via LLaMA Factory.
+
+ </details>
+
+ ## License
+
+ This repository is licensed under the [Apache-2.0 License](LICENSE).
+
+ Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
+
+ ## Citation
+
+ If this work is helpful, please kindly cite it as:
+
+ ```bibtex
+ @inproceedings{zheng2024llamafactory,
+   title={LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models},
+   author={Yaowei Zheng and Richong Zhang and Junhao Zhang and Yanhan Ye and Zheyan Luo and Zhangchi Feng and Yongqiang Ma},
+   booktitle={Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 3: System Demonstrations)},
+   address={Bangkok, Thailand},
+   publisher={Association for Computational Linguistics},
+   year={2024},
+   url={http://arxiv.org/abs/2403.13372}
+ }
+ ```
+
+ ## Acknowledgement
+
+ This repo benefits from [PEFT](https://github.com/huggingface/peft), [TRL](https://github.com/huggingface/trl), [QLoRA](https://github.com/artidoro/qlora) and [FastChat](https://github.com/lm-sys/FastChat). Thanks for their wonderful work.
+
+ ## Star History
+
+ ![Star History Chart](https://api.star-history.com/svg?repos=hiyouga/LLaMA-Factory&type=Date)
llama-factory/pyproject.toml ADDED
@@ -0,0 +1,33 @@
+ [build-system]
+ requires = ["setuptools>=61.0"]
+ build-backend = "setuptools.build_meta"
+
+ [tool.ruff]
+ target-version = "py38"
+ line-length = 119
+ indent-width = 4
+
+ [tool.ruff.lint]
+ ignore = ["C408", "C901", "E501", "E731", "E741", "W605"]
+ select = ["C", "E", "F", "I", "W"]
+
+ [tool.ruff.lint.isort]
+ lines-after-imports = 2
+ known-first-party = ["llamafactory"]
+ known-third-party = [
+     "accelerate",
+     "datasets",
+     "gradio",
+     "numpy",
+     "peft",
+     "torch",
+     "transformers",
+     "trl"
+ ]
+
+ [tool.ruff.format]
+ quote-style = "double"
+ indent-style = "space"
+ docstring-code-format = true
+ skip-magic-trailing-comma = false
+ line-ending = "auto"
llama-factory/requirements.txt ADDED
@@ -0,0 +1,21 @@
+ transformers>=4.41.2
+ datasets>=2.16.0
+ accelerate>=0.30.1
+ peft>=0.11.1
+ trl>=0.8.6
+ gradio>=4.0.0
+ pandas>=2.0.0
+ scipy
+ einops
+ sentencepiece
+ tiktoken
+ protobuf
+ uvicorn
+ pydantic
+ fastapi
+ sse-starlette
+ matplotlib>=3.7.0
+ fire
+ packaging
+ pyyaml
+ numpy<2.0.0
llama-factory/setup.py ADDED
@@ -0,0 +1,92 @@
+ # Copyright 2024 the LlamaFactory team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+ import re
+
+ from setuptools import find_packages, setup
+
+
+ def get_version():
+     with open(os.path.join("src", "llamafactory", "extras", "env.py"), "r", encoding="utf-8") as f:
+         file_content = f.read()
+         pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")
+         (version,) = re.findall(pattern, file_content)
+         return version
+
+
+ def get_requires():
+     with open("requirements.txt", "r", encoding="utf-8") as f:
+         file_content = f.read()
+         lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
+         return lines
+
+
+ extra_require = {
+     "torch": ["torch>=1.13.1"],
+     "torch-npu": ["torch==2.1.0", "torch-npu==2.1.0.post3", "decorator"],
+     "metrics": ["nltk", "jieba", "rouge-chinese"],
+     "deepspeed": ["deepspeed>=0.10.0"],
+     "bitsandbytes": ["bitsandbytes>=0.39.0"],
+     "hqq": ["hqq"],
+     "eetq": ["eetq"],
+     "gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
+     "awq": ["autoawq"],
+     "aqlm": ["aqlm[gpu]>=1.1.0"],
+     "vllm": ["vllm>=0.4.3"],
+     "galore": ["galore-torch"],
+     "badam": ["badam>=1.2.1"],
+     "qwen": ["transformers_stream_generator"],
+     "modelscope": ["modelscope"],
+     "dev": ["ruff", "pytest"],
+ }
+
+
+ def main():
+     setup(
+         name="llamafactory",
+         version=get_version(),
+         author="hiyouga",
+         author_email="hiyouga" "@" "buaa.edu.cn",
+         description="Easy-to-use LLM fine-tuning framework",
+         long_description=open("README.md", "r", encoding="utf-8").read(),
+         long_description_content_type="text/markdown",
+         keywords=["LLaMA", "BLOOM", "Falcon", "LLM", "ChatGPT", "transformer", "pytorch", "deep learning"],
+         license="Apache 2.0 License",
+         url="https://github.com/hiyouga/LLaMA-Factory",
+         package_dir={"": "src"},
+         packages=find_packages("src"),
+         python_requires=">=3.8.0",
+         install_requires=get_requires(),
+         extras_require=extra_require,
+         entry_points={"console_scripts": ["llamafactory-cli = llamafactory.cli:main"]},
+         classifiers=[
+             "Development Status :: 4 - Beta",
+             "Intended Audience :: Developers",
+             "Intended Audience :: Education",
+             "Intended Audience :: Science/Research",
+             "License :: OSI Approved :: Apache Software License",
+             "Operating System :: OS Independent",
+             "Programming Language :: Python :: 3",
+             "Programming Language :: Python :: 3.8",
+             "Programming Language :: Python :: 3.9",
+             "Programming Language :: Python :: 3.10",
+             "Programming Language :: Python :: 3.11",
+             "Topic :: Scientific/Engineering :: Artificial Intelligence",
+         ],
+     )
+
+
+ if __name__ == "__main__":
+     main()
llama-factory/src/api.py ADDED
@@ -0,0 +1,33 @@
1
+ # Copyright 2024 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+
17
+ import uvicorn
18
+
19
+ from llamafactory.api.app import create_app
20
+ from llamafactory.chat import ChatModel
21
+
22
+
23
+ def main():
24
+ chat_model = ChatModel()
25
+ app = create_app(chat_model)
26
+ api_host = os.environ.get("API_HOST", "0.0.0.0")
27
+ api_port = int(os.environ.get("API_PORT", "8000"))
28
+ print("Visit http://localhost:{}/docs for API document.".format(api_port))
29
+ uvicorn.run(app, host=api_host, port=api_port)
30
+
31
+
32
+ if __name__ == "__main__":
33
+ main()
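Once this entry point is running, the server can be exercised with a plain HTTP request. A hedged smoke-test sketch, assuming the default host/port above and the placeholder model id served by the app:

import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "gpt-3.5-turbo",  # placeholder id; see list_models() in api/app.py
        "messages": [{"role": "user", "content": "Hello!"}],
    },
)
print(resp.json()["choices"][0]["message"]["content"])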
llama-factory/src/llamafactory/__init__.py ADDED
@@ -0,0 +1,41 @@
+ # Copyright 2024 the LlamaFactory team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ r"""
+ Efficient fine-tuning of large language models.
+
+ Level:
+   api, webui > chat, eval, train > data, model > hparams > extras
+
+ Dependency graph:
+   main:
+     transformers>=4.41.2
+     datasets>=2.16.0
+     accelerate>=0.30.1
+     peft>=0.11.1
+     trl>=0.8.6
+   attention:
+     transformers>=4.42.4 (gemma+fa2)
+   longlora:
+     transformers>=4.41.2,<=4.42.4
+   packing:
+     transformers>=4.41.2,<=4.42.4
+   patcher:
+     transformers==4.41.2 (chatglm)
+ """
+
+ from .cli import VERSION
+
+
+ __version__ = VERSION
llama-factory/src/llamafactory/api/__init__.py ADDED
File without changes
llama-factory/src/llamafactory/api/app.py ADDED
@@ -0,0 +1,122 @@
+ # Copyright 2024 the LlamaFactory team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+ from contextlib import asynccontextmanager
+ from typing import Optional
+
+ from typing_extensions import Annotated
+
+ from ..chat import ChatModel
+ from ..extras.misc import torch_gc
+ from ..extras.packages import is_fastapi_available, is_starlette_available, is_uvicorn_available
+ from .chat import (
+     create_chat_completion_response,
+     create_score_evaluation_response,
+     create_stream_chat_completion_response,
+ )
+ from .protocol import (
+     ChatCompletionRequest,
+     ChatCompletionResponse,
+     ModelCard,
+     ModelList,
+     ScoreEvaluationRequest,
+     ScoreEvaluationResponse,
+ )
+
+
+ if is_fastapi_available():
+     from fastapi import Depends, FastAPI, HTTPException, status
+     from fastapi.middleware.cors import CORSMiddleware
+     from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
+
+
+ if is_starlette_available():
+     from sse_starlette import EventSourceResponse
+
+
+ if is_uvicorn_available():
+     import uvicorn
+
+
+ @asynccontextmanager
+ async def lifespan(app: "FastAPI"):  # collects GPU memory
+     yield
+     torch_gc()
+
+
+ def create_app(chat_model: "ChatModel") -> "FastAPI":
+     app = FastAPI(lifespan=lifespan)
+     app.add_middleware(
+         CORSMiddleware,
+         allow_origins=["*"],
+         allow_credentials=True,
+         allow_methods=["*"],
+         allow_headers=["*"],
+     )
+     api_key = os.environ.get("API_KEY")
+     security = HTTPBearer(auto_error=False)
+
+     async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
+         if api_key and (auth is None or auth.credentials != api_key):
+             raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key.")
+
+     @app.get(
+         "/v1/models",
+         response_model=ModelList,
+         status_code=status.HTTP_200_OK,
+         dependencies=[Depends(verify_api_key)],
+     )
+     async def list_models():
+         model_card = ModelCard(id="gpt-3.5-turbo")
+         return ModelList(data=[model_card])
+
+     @app.post(
+         "/v1/chat/completions",
+         response_model=ChatCompletionResponse,
+         status_code=status.HTTP_200_OK,
+         dependencies=[Depends(verify_api_key)],
+     )
+     async def create_chat_completion(request: ChatCompletionRequest):
+         if not chat_model.engine.can_generate:
+             raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
+
+         if request.stream:
+             generate = create_stream_chat_completion_response(request, chat_model)
+             return EventSourceResponse(generate, media_type="text/event-stream")
+         else:
+             return await create_chat_completion_response(request, chat_model)
+
+     @app.post(
+         "/v1/score/evaluation",
+         response_model=ScoreEvaluationResponse,
+         status_code=status.HTTP_200_OK,
+         dependencies=[Depends(verify_api_key)],
+     )
+     async def create_score_evaluation(request: ScoreEvaluationRequest):
+         if chat_model.engine.can_generate:
+             raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
+
+         return await create_score_evaluation_response(request, chat_model)
+
+     return app
+
+
+ def run_api() -> None:
+     chat_model = ChatModel()
+     app = create_app(chat_model)
+     api_host = os.environ.get("API_HOST", "0.0.0.0")
+     api_port = int(os.environ.get("API_PORT", "8000"))
+     print("Visit http://localhost:{}/docs for the API documentation.".format(api_port))
+     uvicorn.run(app, host=api_host, port=api_port)
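When the API_KEY environment variable is set before launch, every route above requires a matching Bearer token. A minimal client-side sketch (host, port, and key value are assumptions based on the defaults in this file):

import requests

headers = {"Authorization": "Bearer {}".format("my-secret-key")}  # must match API_KEY
resp = requests.get("http://localhost:8000/v1/models", headers=headers)
resp.raise_for_status()  # raises on the 401 returned when the key does not match
print(resp.json())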
llama-factory/src/llamafactory/api/chat.py ADDED
@@ -0,0 +1,237 @@
+ # Copyright 2024 the LlamaFactory team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import base64
+ import io
+ import json
+ import os
+ import uuid
+ from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
+
+ from ..data import Role as DataRole
+ from ..extras.logging import get_logger
+ from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
+ from .common import dictify, jsonify
+ from .protocol import (
+     ChatCompletionMessage,
+     ChatCompletionResponse,
+     ChatCompletionResponseChoice,
+     ChatCompletionResponseUsage,
+     ChatCompletionStreamResponse,
+     ChatCompletionStreamResponseChoice,
+     Finish,
+     Function,
+     FunctionCall,
+     Role,
+     ScoreEvaluationResponse,
+ )
+
+
+ if is_fastapi_available():
+     from fastapi import HTTPException, status
+
+
+ if is_pillow_available():
+     from PIL import Image
+
+
+ if is_requests_available():
+     import requests
+
+
+ if TYPE_CHECKING:
+     from numpy.typing import NDArray
+
+     from ..chat import ChatModel
+     from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
+
+
+ logger = get_logger(__name__)
+ ROLE_MAPPING = {
+     Role.USER: DataRole.USER.value,
+     Role.ASSISTANT: DataRole.ASSISTANT.value,
+     Role.SYSTEM: DataRole.SYSTEM.value,
+     Role.FUNCTION: DataRole.FUNCTION.value,
+     Role.TOOL: DataRole.OBSERVATION.value,
+ }
+
+
+ def _process_request(
+     request: "ChatCompletionRequest",
+ ) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["NDArray"]]:
+     logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False)))
+
+     if len(request.messages) == 0:
+         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
+
+     if request.messages[0].role == Role.SYSTEM:
+         system = request.messages.pop(0).content
+     else:
+         system = None
+
+     if len(request.messages) % 2 == 0:
+         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
+
+     input_messages = []
+     image = None
+     for i, message in enumerate(request.messages):
+         if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
+             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
+         elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]:
+             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
+
+         if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
+             tool_calls = [
+                 {"name": tool_call.function.name, "arguments": tool_call.function.arguments}
+                 for tool_call in message.tool_calls
+             ]
+             content = json.dumps(tool_calls, ensure_ascii=False)
+             input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
+         elif isinstance(message.content, list):
+             for input_item in message.content:
+                 if input_item.type == "text":
+                     input_messages.append({"role": ROLE_MAPPING[message.role], "content": input_item.text})
+                 else:
+                     image_url = input_item.image_url.url
+                     if image_url.startswith("data:image"):  # base64 image
+                         image_data = base64.b64decode(image_url.split(",", maxsplit=1)[1])
+                         image_path = io.BytesIO(image_data)
+                     elif os.path.isfile(image_url):  # local file
+                         image_path = open(image_url, "rb")
+                     else:  # web uri
+                         image_path = requests.get(image_url, stream=True).raw
+
+                     image = Image.open(image_path).convert("RGB")
+         else:
+             input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
+
+     tool_list = request.tools
+     if isinstance(tool_list, list) and len(tool_list):
+         try:
+             tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False)
+         except json.JSONDecodeError:
+             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
+     else:
+         tools = None
+
+     return input_messages, system, tools, image
+
+
+ def _create_stream_chat_completion_chunk(
+     completion_id: str,
+     model: str,
+     delta: "ChatCompletionMessage",
+     index: Optional[int] = 0,
+     finish_reason: Optional["Finish"] = None,
+ ) -> str:
+     choice_data = ChatCompletionStreamResponseChoice(index=index, delta=delta, finish_reason=finish_reason)
+     chunk = ChatCompletionStreamResponse(id=completion_id, model=model, choices=[choice_data])
+     return jsonify(chunk)
+
+
+ async def create_chat_completion_response(
+     request: "ChatCompletionRequest", chat_model: "ChatModel"
+ ) -> "ChatCompletionResponse":
+     completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
+     input_messages, system, tools, image = _process_request(request)
+     responses = await chat_model.achat(
+         input_messages,
+         system,
+         tools,
+         image,
+         do_sample=request.do_sample,
+         temperature=request.temperature,
+         top_p=request.top_p,
+         max_new_tokens=request.max_tokens,
+         num_return_sequences=request.n,
+         stop=request.stop,
+     )
+
+     prompt_length, response_length = 0, 0
+     choices = []
+     for i, response in enumerate(responses):
+         if tools:
+             result = chat_model.engine.template.extract_tool(response.response_text)
+         else:
+             result = response.response_text
+
+         if isinstance(result, list):
+             tool_calls = []
+             for tool in result:
+                 function = Function(name=tool[0], arguments=tool[1])
+                 tool_calls.append(FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function))
+
+             response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
+             finish_reason = Finish.TOOL
+         else:
+             response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result)
+             finish_reason = Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
+
+         choices.append(ChatCompletionResponseChoice(index=i, message=response_message, finish_reason=finish_reason))
+         prompt_length = response.prompt_length
+         response_length += response.response_length
+
+     usage = ChatCompletionResponseUsage(
+         prompt_tokens=prompt_length,
+         completion_tokens=response_length,
+         total_tokens=prompt_length + response_length,
+     )
+
+     return ChatCompletionResponse(id=completion_id, model=request.model, choices=choices, usage=usage)
+
+
+ async def create_stream_chat_completion_response(
+     request: "ChatCompletionRequest", chat_model: "ChatModel"
+ ) -> AsyncGenerator[str, None]:
+     completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
+     input_messages, system, tools, image = _process_request(request)
+     if tools:
+         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
+
+     if request.n > 1:
+         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream multiple responses.")
+
+     yield _create_stream_chat_completion_chunk(
+         completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(role=Role.ASSISTANT, content="")
+     )
+     async for new_token in chat_model.astream_chat(
+         input_messages,
+         system,
+         tools,
+         image,
+         do_sample=request.do_sample,
+         temperature=request.temperature,
+         top_p=request.top_p,
+         max_new_tokens=request.max_tokens,
+         stop=request.stop,
+     ):
+         if len(new_token) != 0:
+             yield _create_stream_chat_completion_chunk(
+                 completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(content=new_token)
+             )
+
+     yield _create_stream_chat_completion_chunk(
+         completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(), finish_reason=Finish.STOP
+     )
+     yield "[DONE]"
+
+
+ async def create_score_evaluation_response(
+     request: "ScoreEvaluationRequest", chat_model: "ChatModel"
+ ) -> "ScoreEvaluationResponse":
+     if len(request.messages) == 0:
+         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
+
+     scores = await chat_model.aget_scores(request.messages, max_length=request.max_length)
+     return ScoreEvaluationResponse(model=request.model, scores=scores)
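_process_request above unpacks OpenAI-style multimodal content parts, including the base64 data-URL branch. A sample payload it would accept (the base64 string is truncated and purely illustrative):

payload = {
    "model": "gpt-3.5-turbo",
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image."},
                {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBORw0KGgo..."}},
            ],
        }
    ],
}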
llama-factory/src/llamafactory/api/common.py ADDED
@@ -0,0 +1,34 @@
+ # Copyright 2024 the LlamaFactory team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import json
+ from typing import TYPE_CHECKING, Any, Dict
+
+
+ if TYPE_CHECKING:
+     from pydantic import BaseModel
+
+
+ def dictify(data: "BaseModel") -> Dict[str, Any]:
+     try:  # pydantic v2
+         return data.model_dump(exclude_unset=True)
+     except AttributeError:  # pydantic v1
+         return data.dict(exclude_unset=True)
+
+
+ def jsonify(data: "BaseModel") -> str:
+     try:  # pydantic v2
+         return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
+     except AttributeError:  # pydantic v1
+         return data.json(exclude_unset=True, ensure_ascii=False)
+
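A tiny illustration of these version-agnostic helpers; the Ping model is made up for the example:

from pydantic import BaseModel

from llamafactory.api.common import dictify, jsonify


class Ping(BaseModel):  # hypothetical model, for illustration only
    message: str


print(dictify(Ping(message="pong")))  # {'message': 'pong'}
print(jsonify(Ping(message="pong")))  # {"message": "pong"}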
llama-factory/src/llamafactory/api/protocol.py ADDED
@@ -0,0 +1,153 @@
+ # Copyright 2024 the LlamaFactory team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import time
+ from enum import Enum, unique
+ from typing import Any, Dict, List, Optional, Union
+
+ from pydantic import BaseModel, Field
+ from typing_extensions import Literal
+
+
+ @unique
+ class Role(str, Enum):
+     USER = "user"
+     ASSISTANT = "assistant"
+     SYSTEM = "system"
+     FUNCTION = "function"
+     TOOL = "tool"
+
+
+ @unique
+ class Finish(str, Enum):
+     STOP = "stop"
+     LENGTH = "length"
+     TOOL = "tool_calls"
+
+
+ class ModelCard(BaseModel):
+     id: str
+     object: Literal["model"] = "model"
+     created: int = Field(default_factory=lambda: int(time.time()))
+     owned_by: Literal["owner"] = "owner"
+
+
+ class ModelList(BaseModel):
+     object: Literal["list"] = "list"
+     data: List[ModelCard] = []
+
+
+ class Function(BaseModel):
+     name: str
+     arguments: str
+
+
+ class FunctionDefinition(BaseModel):
+     name: str
+     description: str
+     parameters: Dict[str, Any]
+
+
+ class FunctionAvailable(BaseModel):
+     type: Literal["function", "code_interpreter"] = "function"
+     function: Optional[FunctionDefinition] = None
+
+
+ class FunctionCall(BaseModel):
+     id: str
+     type: Literal["function"] = "function"
+     function: Function
+
+
+ class ImageURL(BaseModel):
+     url: str
+
+
+ class MultimodalInputItem(BaseModel):
+     type: Literal["text", "image_url"]
+     text: Optional[str] = None
+     image_url: Optional[ImageURL] = None
+
+
+ class ChatMessage(BaseModel):
+     role: Role
+     content: Optional[Union[str, List[MultimodalInputItem]]] = None
+     tool_calls: Optional[List[FunctionCall]] = None
+
+
+ class ChatCompletionMessage(BaseModel):
+     role: Optional[Role] = None
+     content: Optional[str] = None
+     tool_calls: Optional[List[FunctionCall]] = None
+
+
+ class ChatCompletionRequest(BaseModel):
+     model: str
+     messages: List[ChatMessage]
+     tools: Optional[List[FunctionAvailable]] = None
+     do_sample: Optional[bool] = None
+     temperature: Optional[float] = None
+     top_p: Optional[float] = None
+     n: int = 1
+     max_tokens: Optional[int] = None
+     stop: Optional[Union[str, List[str]]] = None
+     stream: bool = False
+
+
+ class ChatCompletionResponseChoice(BaseModel):
+     index: int
+     message: ChatCompletionMessage
+     finish_reason: Finish
+
+
+ class ChatCompletionStreamResponseChoice(BaseModel):
+     index: int
+     delta: ChatCompletionMessage
+     finish_reason: Optional[Finish] = None
+
+
+ class ChatCompletionResponseUsage(BaseModel):
+     prompt_tokens: int
+     completion_tokens: int
+     total_tokens: int
+
+
+ class ChatCompletionResponse(BaseModel):
+     id: str
+     object: Literal["chat.completion"] = "chat.completion"
+     created: int = Field(default_factory=lambda: int(time.time()))
+     model: str
+     choices: List[ChatCompletionResponseChoice]
+     usage: ChatCompletionResponseUsage
+
+
+ class ChatCompletionStreamResponse(BaseModel):
+     id: str
+     object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
+     created: int = Field(default_factory=lambda: int(time.time()))
+     model: str
+     choices: List[ChatCompletionStreamResponseChoice]
+
+
+ class ScoreEvaluationRequest(BaseModel):
+     model: str
+     messages: List[str]
+     max_length: Optional[int] = None
+
+
+ class ScoreEvaluationResponse(BaseModel):
+     id: str
+     object: Literal["score.evaluation"] = "score.evaluation"
+     model: str
+     scores: List[float]
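Since the request schema above is plain pydantic, an incoming JSON body can be validated directly. A sketch using the pydantic v2 API, with a fallback to the v1 method, mirroring the style of api/common.py:

from llamafactory.api.protocol import ChatCompletionRequest

payload = {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hi"}]}
try:  # pydantic v2
    request = ChatCompletionRequest.model_validate(payload)
except AttributeError:  # pydantic v1
    request = ChatCompletionRequest.parse_obj(payload)
print(request.n, request.stream)  # defaults: 1 False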
llama-factory/src/llamafactory/chat/__init__.py ADDED
@@ -0,0 +1,19 @@
+ # Copyright 2024 the LlamaFactory team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from .base_engine import BaseEngine
+ from .chat_model import ChatModel
+
+
+ __all__ = ["BaseEngine", "ChatModel"]
llama-factory/src/llamafactory/chat/base_engine.py ADDED
@@ -0,0 +1,78 @@
+ # Copyright 2024 the LlamaFactory team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from abc import ABC, abstractmethod
+ from dataclasses import dataclass
+ from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Union
+
+
+ if TYPE_CHECKING:
+     from numpy.typing import NDArray
+     from transformers import PreTrainedModel, PreTrainedTokenizer
+     from vllm import AsyncLLMEngine
+
+     from ..data import Template
+     from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
+
+
+ @dataclass
+ class Response:
+     response_text: str
+     response_length: int
+     prompt_length: int
+     finish_reason: Literal["stop", "length"]
+
+
+ class BaseEngine(ABC):
+     model: Union["PreTrainedModel", "AsyncLLMEngine"]
+     tokenizer: "PreTrainedTokenizer"
+     can_generate: bool
+     template: "Template"
+     generating_args: Dict[str, Any]
+
+     @abstractmethod
+     def __init__(
+         self,
+         model_args: "ModelArguments",
+         data_args: "DataArguments",
+         finetuning_args: "FinetuningArguments",
+         generating_args: "GeneratingArguments",
+     ) -> None: ...
+
+     @abstractmethod
+     async def chat(
+         self,
+         messages: Sequence[Dict[str, str]],
+         system: Optional[str] = None,
+         tools: Optional[str] = None,
+         image: Optional["NDArray"] = None,
+         **input_kwargs,
+     ) -> List["Response"]: ...
+
+     @abstractmethod
+     async def stream_chat(
+         self,
+         messages: Sequence[Dict[str, str]],
+         system: Optional[str] = None,
+         tools: Optional[str] = None,
+         image: Optional["NDArray"] = None,
+         **input_kwargs,
+     ) -> AsyncGenerator[str, None]: ...
+
+     @abstractmethod
+     async def get_scores(
+         self,
+         batch_input: List[str],
+         **input_kwargs,
+     ) -> List[float]: ...
llama-factory/src/llamafactory/chat/chat_model.py ADDED
@@ -0,0 +1,155 @@
+ # Copyright 2024 THUDM and the LlamaFactory team.
+ #
+ # This code is inspired by THUDM's ChatGLM implementation.
+ # https://github.com/THUDM/ChatGLM-6B/blob/main/cli_demo.py
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import asyncio
+ import os
+ from threading import Thread
+ from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
+
+ from ..extras.misc import torch_gc
+ from ..hparams import get_infer_args
+ from .hf_engine import HuggingfaceEngine
+ from .vllm_engine import VllmEngine
+
+
+ if TYPE_CHECKING:
+     from numpy.typing import NDArray
+
+     from .base_engine import BaseEngine, Response
+
+
+ def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
+     asyncio.set_event_loop(loop)
+     loop.run_forever()
+
+
+ class ChatModel:
+     def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
+         model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
+         if model_args.infer_backend == "huggingface":
+             self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
+         elif model_args.infer_backend == "vllm":
+             self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
+         else:
+             raise NotImplementedError("Unknown backend: {}".format(model_args.infer_backend))
+
+         self._loop = asyncio.new_event_loop()
+         self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
+         self._thread.start()
+
+     def chat(
+         self,
+         messages: Sequence[Dict[str, str]],
+         system: Optional[str] = None,
+         tools: Optional[str] = None,
+         image: Optional["NDArray"] = None,
+         **input_kwargs,
+     ) -> List["Response"]:
+         task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, image, **input_kwargs), self._loop)
+         return task.result()
+
+     async def achat(
+         self,
+         messages: Sequence[Dict[str, str]],
+         system: Optional[str] = None,
+         tools: Optional[str] = None,
+         image: Optional["NDArray"] = None,
+         **input_kwargs,
+     ) -> List["Response"]:
+         return await self.engine.chat(messages, system, tools, image, **input_kwargs)
+
+     def stream_chat(
+         self,
+         messages: Sequence[Dict[str, str]],
+         system: Optional[str] = None,
+         tools: Optional[str] = None,
+         image: Optional["NDArray"] = None,
+         **input_kwargs,
+     ) -> Generator[str, None, None]:
+         generator = self.astream_chat(messages, system, tools, image, **input_kwargs)
+         while True:
+             try:
+                 task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
+                 yield task.result()
+             except StopAsyncIteration:
+                 break
+
+     async def astream_chat(
+         self,
+         messages: Sequence[Dict[str, str]],
+         system: Optional[str] = None,
+         tools: Optional[str] = None,
+         image: Optional["NDArray"] = None,
+         **input_kwargs,
+     ) -> AsyncGenerator[str, None]:
+         async for new_token in self.engine.stream_chat(messages, system, tools, image, **input_kwargs):
+             yield new_token
+
+     def get_scores(
+         self,
+         batch_input: List[str],
+         **input_kwargs,
+     ) -> List[float]:
+         task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
+         return task.result()
+
+     async def aget_scores(
+         self,
+         batch_input: List[str],
+         **input_kwargs,
+     ) -> List[float]:
+         return await self.engine.get_scores(batch_input, **input_kwargs)
+
+
+ def run_chat() -> None:
+     if os.name != "nt":
+         try:
+             import readline  # noqa: F401
+         except ImportError:
+             print("Install `readline` for a better experience.")
+
+     chat_model = ChatModel()
+     messages = []
+     print("Welcome to the CLI application. Use `clear` to remove the history, and `exit` to exit the application.")
+
+     while True:
+         try:
+             query = input("\nUser: ")
+         except UnicodeDecodeError:
+             print("Detected a decoding error in the input; please set the terminal encoding to utf-8.")
+             continue
+         except Exception:
+             raise
+
+         if query.strip() == "exit":
+             break
+
+         if query.strip() == "clear":
+             messages = []
+             torch_gc()
+             print("History has been removed.")
+             continue
+
+         messages.append({"role": "user", "content": query})
+         print("Assistant: ", end="", flush=True)
+
+         response = ""
+         for new_text in chat_model.stream_chat(messages):
+             print(new_text, end="", flush=True)
+             response += new_text
+         print()
+         messages.append({"role": "assistant", "content": response})
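ChatModel can also be used programmatically, outside the CLI loop above. A minimal sketch; the model path and template name are placeholders to be replaced with a real checkpoint and its matching template:

from llamafactory.chat import ChatModel

args = {
    "model_name_or_path": "path/to/model",  # placeholder
    "template": "llama3",                   # must match the model
}
chat_model = ChatModel(args)
responses = chat_model.chat([{"role": "user", "content": "Hello!"}])
print(responses[0].response_text)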
llama-factory/src/llamafactory/chat/hf_engine.py ADDED
@@ -0,0 +1,343 @@
+ # Copyright 2024 the LlamaFactory team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import asyncio
+ import concurrent.futures
+ import os
+ from threading import Thread
+ from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple, Union
+
+ import torch
+ from transformers import GenerationConfig, TextIteratorStreamer
+
+ from ..data import get_template_and_fix_tokenizer
+ from ..extras.logging import get_logger
+ from ..extras.misc import get_logits_processor
+ from ..model import load_model, load_tokenizer
+ from .base_engine import BaseEngine, Response
+
+
+ if TYPE_CHECKING:
+     from numpy.typing import NDArray
+     from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
+     from transformers.image_processing_utils import BaseImageProcessor
+     from trl import PreTrainedModelWrapper
+
+     from ..data import Template
+     from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
+
+
+ logger = get_logger(__name__)
+
+
+ class HuggingfaceEngine(BaseEngine):
+     def __init__(
+         self,
+         model_args: "ModelArguments",
+         data_args: "DataArguments",
+         finetuning_args: "FinetuningArguments",
+         generating_args: "GeneratingArguments",
+     ) -> None:
+         self.can_generate = finetuning_args.stage == "sft"
+         tokenizer_module = load_tokenizer(model_args)
+         self.tokenizer = tokenizer_module["tokenizer"]
+         self.processor = tokenizer_module["processor"]
+         self.tokenizer.padding_side = "left" if self.can_generate else "right"
+         self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template, data_args.tool_format)
+         self.model = load_model(
+             self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
+         )  # must be done after fixing the tokenizer, so that the vocab is resized
+         self.generating_args = generating_args.to_dict()
+         try:
+             asyncio.get_event_loop()
+         except RuntimeError:
+             logger.warning("There is no current event loop, creating a new one.")
+             loop = asyncio.new_event_loop()
+             asyncio.set_event_loop(loop)
+
+         self.semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", "1")))
+
+     @staticmethod
+     def _process_args(
+         model: "PreTrainedModel",
+         tokenizer: "PreTrainedTokenizer",
+         processor: Optional["ProcessorMixin"],
+         template: "Template",
+         generating_args: Dict[str, Any],
+         messages: Sequence[Dict[str, str]],
+         system: Optional[str] = None,
+         tools: Optional[str] = None,
+         image: Optional["NDArray"] = None,
+         input_kwargs: Optional[Dict[str, Any]] = {},
+     ) -> Tuple[Dict[str, Any], int]:
+         if (
+             processor is not None
+             and image is not None
+             and not hasattr(processor, "image_seq_length")
+             and template.image_token not in messages[0]["content"]
+         ):  # llava-like models
+             messages[0]["content"] = template.image_token + messages[0]["content"]
+
+         paired_messages = messages + [{"role": "assistant", "content": ""}]
+         system = system or generating_args["default_system"]
+         pixel_values = None
+         prompt_ids, _ = template.encode_oneturn(
+             tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools
+         )
+         if processor is not None and image is not None:  # add image features
+             image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+             batch_feature = image_processor(image, return_tensors="pt")
+             pixel_values = batch_feature.to(model.device)["pixel_values"]  # shape (B, C, H, W)
+             if hasattr(processor, "image_seq_length"):  # paligemma models
+                 image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
+                 prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
+
+         prompt_length = len(prompt_ids)
+         inputs = torch.tensor([prompt_ids], device=model.device)
+         attention_mask = torch.ones_like(inputs, dtype=torch.bool)
+
+         do_sample: Optional[bool] = input_kwargs.pop("do_sample", None)
+         temperature: Optional[float] = input_kwargs.pop("temperature", None)
+         top_p: Optional[float] = input_kwargs.pop("top_p", None)
+         top_k: Optional[float] = input_kwargs.pop("top_k", None)
+         num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
+         repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
+         length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
+         max_length: Optional[int] = input_kwargs.pop("max_length", None)
+         max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
+         stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
+
+         if stop is not None:
+             logger.warning("The stop parameter is not supported by the huggingface engine yet.")
+
+         generating_args = generating_args.copy()
+         generating_args.update(
+             dict(
+                 do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
+                 temperature=temperature if temperature is not None else generating_args["temperature"],
+                 top_p=top_p if top_p is not None else generating_args["top_p"],
+                 top_k=top_k if top_k is not None else generating_args["top_k"],
+                 num_return_sequences=num_return_sequences,
+                 repetition_penalty=repetition_penalty
+                 if repetition_penalty is not None
+                 else generating_args["repetition_penalty"],
+                 length_penalty=length_penalty if length_penalty is not None else generating_args["length_penalty"],
+                 eos_token_id=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
+                 pad_token_id=tokenizer.pad_token_id,
+             )
+         )
+
+         if isinstance(num_return_sequences, int) and num_return_sequences > 1:  # do_sample needs temperature > 0
+             generating_args["do_sample"] = True
+             generating_args["temperature"] = generating_args["temperature"] or 1.0
+
+         if not generating_args["temperature"]:
+             generating_args["do_sample"] = False
+
+         if not generating_args["do_sample"]:
+             generating_args.pop("temperature", None)
+             generating_args.pop("top_p", None)
+
+         if max_length:
+             generating_args.pop("max_new_tokens", None)
+             generating_args["max_length"] = max_length
+
+         if max_new_tokens:
+             generating_args.pop("max_length", None)
+             generating_args["max_new_tokens"] = max_new_tokens
+
+         gen_kwargs = dict(
+             inputs=inputs,
+             attention_mask=attention_mask,
+             generation_config=GenerationConfig(**generating_args),
+             logits_processor=get_logits_processor(),
+         )
+
+         if pixel_values is not None:
+             gen_kwargs["pixel_values"] = pixel_values
+
+         return gen_kwargs, prompt_length
+
+     @staticmethod
+     @torch.inference_mode()
+     def _chat(
+         model: "PreTrainedModel",
+         tokenizer: "PreTrainedTokenizer",
+         processor: Optional["ProcessorMixin"],
+         template: "Template",
+         generating_args: Dict[str, Any],
+         messages: Sequence[Dict[str, str]],
+         system: Optional[str] = None,
+         tools: Optional[str] = None,
+         image: Optional["NDArray"] = None,
+         input_kwargs: Optional[Dict[str, Any]] = {},
+     ) -> List["Response"]:
+         gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
+             model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
+         )
+         generate_output = model.generate(**gen_kwargs)
+         response_ids = generate_output[:, prompt_length:]
+         response = tokenizer.batch_decode(response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+         results = []
+         for i in range(len(response)):
+             eos_index = (response_ids[i] == tokenizer.eos_token_id).nonzero()
+             response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i])
+             results.append(
+                 Response(
+                     response_text=response[i],
+                     response_length=response_length,
+                     prompt_length=prompt_length,
+                     finish_reason="stop" if len(eos_index) else "length",
+                 )
+             )
+
+         return results
+
+     @staticmethod
+     @torch.inference_mode()
+     def _stream_chat(
+         model: "PreTrainedModel",
+         tokenizer: "PreTrainedTokenizer",
+         processor: Optional["ProcessorMixin"],
+         template: "Template",
+         generating_args: Dict[str, Any],
+         messages: Sequence[Dict[str, str]],
+         system: Optional[str] = None,
+         tools: Optional[str] = None,
+         image: Optional["NDArray"] = None,
+         input_kwargs: Optional[Dict[str, Any]] = {},
+     ) -> Callable[[], str]:
+         gen_kwargs, _ = HuggingfaceEngine._process_args(
+             model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
+         )
+         streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+         gen_kwargs["streamer"] = streamer
+         thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
+         thread.start()
+
+         def stream():
+             try:
+                 return streamer.__next__()
+             except StopIteration:
+                 raise StopAsyncIteration()
+
+         return stream
+
+     @staticmethod
+     @torch.inference_mode()
+     def _get_scores(
+         model: "PreTrainedModelWrapper",
+         tokenizer: "PreTrainedTokenizer",
+         batch_input: List[str],
+         input_kwargs: Optional[Dict[str, Any]] = {},
+     ) -> List[float]:
+         max_length = input_kwargs.pop("max_length", None)
+         device = getattr(model.pretrained_model, "device", "cuda")
+         inputs = tokenizer(
+             batch_input,
+             padding=True,
+             truncation=True,
+             max_length=max_length or getattr(model.config, "max_position_embeddings", 1024),
+             return_tensors="pt",
+             add_special_tokens=True,
+         ).to(device)
+
+         input_ids: torch.Tensor = inputs["input_ids"]
+         _, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
+
+         if getattr(model.config, "model_type", None) == "chatglm":
+             values = torch.transpose(values, 0, 1)
+
+         scores = []
+         for i in range(input_ids.size(0)):
+             end_indexes = (input_ids[i] != tokenizer.pad_token_id).nonzero()
+             end_index = end_indexes[-1].item() if len(end_indexes) else 0
+             scores.append(values[i, end_index].nan_to_num().item())
+
+         return scores
+
+     async def chat(
+         self,
+         messages: Sequence[Dict[str, str]],
+         system: Optional[str] = None,
+         tools: Optional[str] = None,
+         image: Optional["NDArray"] = None,
+         **input_kwargs,
+     ) -> List["Response"]:
+         if not self.can_generate:
+             raise ValueError("The current model does not support `chat`.")
+
+         loop = asyncio.get_running_loop()
+         input_args = (
+             self.model,
+             self.tokenizer,
+             self.processor,
+             self.template,
+             self.generating_args,
+             messages,
+             system,
+             tools,
+             image,
+             input_kwargs,
+         )
+         async with self.semaphore:
+             with concurrent.futures.ThreadPoolExecutor() as pool:
+                 return await loop.run_in_executor(pool, self._chat, *input_args)
+
+     async def stream_chat(
+         self,
+         messages: Sequence[Dict[str, str]],
+         system: Optional[str] = None,
+         tools: Optional[str] = None,
+         image: Optional["NDArray"] = None,
+         **input_kwargs,
+     ) -> AsyncGenerator[str, None]:
+         if not self.can_generate:
+             raise ValueError("The current model does not support `stream_chat`.")
+
+         loop = asyncio.get_running_loop()
+         input_args = (
+             self.model,
+             self.tokenizer,
+             self.processor,
+             self.template,
+             self.generating_args,
+             messages,
+             system,
+             tools,
+             image,
+             input_kwargs,
+         )
+         async with self.semaphore:
+             with concurrent.futures.ThreadPoolExecutor() as pool:
+                 stream = self._stream_chat(*input_args)
+                 while True:
+                     try:
+                         yield await loop.run_in_executor(pool, stream)
+                     except StopAsyncIteration:
+                         break
+
+     async def get_scores(
+         self,
+         batch_input: List[str],
+         **input_kwargs,
+     ) -> List[float]:
+         if self.can_generate:
+             raise ValueError("Cannot get scores using an auto-regressive model.")
+
+         loop = asyncio.get_running_loop()
+         input_args = (self.model, self.tokenizer, batch_input, input_kwargs)
+         async with self.semaphore:
+             with concurrent.futures.ThreadPoolExecutor() as pool:
+                 return await loop.run_in_executor(pool, self._get_scores, *input_args)
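Note that get_scores above is only reachable when the engine was loaded with a value head (a stage other than "sft", so can_generate is False). A hedged usage sketch, assuming a reward-model checkpoint and placeholder paths:

from llamafactory.chat import ChatModel

args = {
    "model_name_or_path": "path/to/reward-model",  # placeholder
    "template": "llama3",                          # placeholder
    "stage": "rm",                                 # loads the model with a value head
}
chat_model = ChatModel(args)
scores = chat_model.get_scores(["A full prompt-response string to score."])
print(scores)  # one float per input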
llama-factory/src/llamafactory/chat/vllm_engine.py ADDED
@@ -0,0 +1,242 @@
+ # Copyright 2024 the LlamaFactory team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import uuid
+ from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
+
+ from ..data import get_template_and_fix_tokenizer
+ from ..extras.logging import get_logger
+ from ..extras.misc import get_device_count
+ from ..extras.packages import is_vllm_available, is_vllm_version_greater_than_0_5, is_vllm_version_greater_than_0_5_1
+ from ..model import load_config, load_tokenizer
+ from ..model.model_utils.quantization import QuantizationMethod
+ from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
+ from .base_engine import BaseEngine, Response
+
+
+ if is_vllm_available():
+     from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
+     from vllm.lora.request import LoRARequest
+
+     if is_vllm_version_greater_than_0_5_1():
+         pass
+     elif is_vllm_version_greater_than_0_5():
+         from vllm.multimodal.image import ImagePixelData
+     else:
+         from vllm.sequence import MultiModalData
+
+
+ if TYPE_CHECKING:
+     from numpy.typing import NDArray
+     from transformers.image_processing_utils import BaseImageProcessor
+
+     from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
+
+
+ logger = get_logger(__name__)
+
+
+ class VllmEngine(BaseEngine):
+     def __init__(
+         self,
+         model_args: "ModelArguments",
+         data_args: "DataArguments",
+         finetuning_args: "FinetuningArguments",
+         generating_args: "GeneratingArguments",
+     ) -> None:
+         config = load_config(model_args)  # may download the model from the ModelScope hub
+         if getattr(config, "quantization_config", None):  # gptq models should use float16
+             quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
+             quant_method = quantization_config.get("quant_method", "")
+             if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto":
+                 model_args.infer_dtype = "float16"
+
+         self.can_generate = finetuning_args.stage == "sft"
+         tokenizer_module = load_tokenizer(model_args)
+         self.tokenizer = tokenizer_module["tokenizer"]
+         self.processor = tokenizer_module["processor"]
+         self.tokenizer.padding_side = "left"
+         self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template, data_args.tool_format)
+         self.generating_args = generating_args.to_dict()
+
+         engine_args = {
+             "model": model_args.model_name_or_path,
+             "trust_remote_code": True,
+             "download_dir": model_args.cache_dir,
+             "dtype": model_args.infer_dtype,
+             "max_model_len": model_args.vllm_maxlen,
+             "tensor_parallel_size": get_device_count() or 1,
+             "gpu_memory_utilization": model_args.vllm_gpu_util,
+             "disable_log_stats": True,
+             "disable_log_requests": True,
+             "enforce_eager": model_args.vllm_enforce_eager,
+             "enable_lora": model_args.adapter_name_or_path is not None,
+             "max_lora_rank": model_args.vllm_max_lora_rank,
+         }
+
+         if model_args.visual_inputs:
+             image_size = config.vision_config.image_size
+             patch_size = config.vision_config.patch_size
+             self.image_feature_size = (image_size // patch_size) ** 2
+             engine_args["image_input_type"] = "pixel_values"
+             engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids(self.template.image_token)
+             engine_args["image_input_shape"] = "1,3,{},{}".format(image_size, image_size)
+             engine_args["image_feature_size"] = self.image_feature_size
+             if getattr(config, "is_yi_vl_derived_model", None):
+                 import vllm.model_executor.models.llava
+
+                 logger.info("Detected Yi-VL model, applying projector patch.")
+                 vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM
+
+         self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
+         if model_args.adapter_name_or_path is not None:
+             self.lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
+         else:
+             self.lora_request = None
+
+     async def _generate(
+         self,
+         messages: Sequence[Dict[str, str]],
+         system: Optional[str] = None,
+         tools: Optional[str] = None,
+         image: Optional["NDArray"] = None,
+         **input_kwargs,
+     ) -> AsyncIterator["RequestOutput"]:
+         request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
+
+         if (
+             self.processor is not None
+             and image is not None
+             and not hasattr(self.processor, "image_seq_length")
+             and self.template.image_token not in messages[0]["content"]
+         ):  # llava-like models (TODO: paligemma models)
+             messages[0]["content"] = self.template.image_token * self.image_feature_size + messages[0]["content"]
+
+         paired_messages = messages + [{"role": "assistant", "content": ""}]
+         system = system or self.generating_args["default_system"]
+         prompt_ids, _ = self.template.encode_oneturn(
+             tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
+         )
+
+         if self.processor is not None and image is not None:  # add image features
+             image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
+             pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]
+             if is_vllm_version_greater_than_0_5_1():
+                 multi_modal_data = {"image": pixel_values}
+             elif is_vllm_version_greater_than_0_5():
+                 multi_modal_data = ImagePixelData(image=pixel_values)
+             else:  # TODO: remove vllm 0.4.3 support
+                 multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
+         else:
+             multi_modal_data = None
+
+         prompt_length = len(prompt_ids)
+
+         use_beam_search: bool = self.generating_args["num_beams"] > 1
+         temperature: Optional[float] = input_kwargs.pop("temperature", None)
+         top_p: Optional[float] = input_kwargs.pop("top_p", None)
+         top_k: Optional[float] = input_kwargs.pop("top_k", None)
+         num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
+         repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
+         length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
+         max_length: Optional[int] = input_kwargs.pop("max_length", None)
+         max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
+         stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
+
+         if "max_new_tokens" in self.generating_args:
+             max_tokens = self.generating_args["max_new_tokens"]
+         elif "max_length" in self.generating_args:
+             if self.generating_args["max_length"] > prompt_length:
+                 max_tokens = self.generating_args["max_length"] - prompt_length
+             else:
+                 max_tokens = 1
+
+         if max_length:
+             max_tokens = max_length - prompt_length if max_length > prompt_length else 1
+
+         if max_new_tokens:
+             max_tokens = max_new_tokens
+
+         sampling_params = SamplingParams(
+             n=num_return_sequences,
+             repetition_penalty=(
+                 repetition_penalty if repetition_penalty is not None else self.generating_args["repetition_penalty"]
+             )
+             or 1.0,  # repetition_penalty must be > 0
+             temperature=temperature if temperature is not None else self.generating_args["temperature"],
+             top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0,  # top_p must be > 0
+             top_k=top_k if top_k is not None else self.generating_args["top_k"],
+             use_beam_search=use_beam_search,
+             length_penalty=length_penalty if length_penalty is not None else self.generating_args["length_penalty"],
+             stop=stop,
+             stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
+             max_tokens=max_tokens,
+             skip_special_tokens=True,
+         )
+
+         result_generator = self.model.generate(
+             inputs={"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
+             sampling_params=sampling_params,
+             request_id=request_id,
+             lora_request=self.lora_request,
+         )
+         return result_generator
+
+     async def chat(
+         self,
+         messages: Sequence[Dict[str, str]],
+         system: Optional[str] = None,
+         tools: Optional[str] = None,
+         image: Optional["NDArray"] = None,
+         **input_kwargs,
+     ) -> List["Response"]:
+         final_output = None
+         generator = await self._generate(messages, system, tools, image, **input_kwargs)
+         async for request_output in generator:
+             final_output = request_output
+
+         results = []
+         for output in final_output.outputs:
+             results.append(
+                 Response(
+                     response_text=output.text,
+                     response_length=len(output.token_ids),
+                     prompt_length=len(final_output.prompt_token_ids),
+                     finish_reason=output.finish_reason,
+                 )
+             )
+
+         return results
+
+     async def stream_chat(
+         self,
+         messages: Sequence[Dict[str, str]],
+         system: Optional[str] = None,
+         tools: Optional[str] = None,
+         image: Optional["NDArray"] = None,
+         **input_kwargs,
+     ) -> AsyncGenerator[str, None]:
+         generated_text = ""
+         generator = await self._generate(messages, system, tools, image, **input_kwargs)
+         async for result in generator:
+             delta_text = result.outputs[0].text[len(generated_text) :]
+             generated_text = result.outputs[0].text
+             yield delta_text
+
+     async def get_scores(
+         self,
+         batch_input: List[str],
+         **input_kwargs,
+     ) -> List[float]:
+         raise NotImplementedError("vLLM engine does not support get_scores.")
llama-factory/src/llamafactory/cli.py ADDED
@@ -0,0 +1,121 @@
+ # Copyright 2024 the LlamaFactory team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+ import random
+ import subprocess
+ import sys
+ from enum import Enum, unique
+
+ from . import launcher
+ from .api.app import run_api
+ from .chat.chat_model import run_chat
+ from .eval.evaluator import run_eval
+ from .extras.env import VERSION, print_env
+ from .extras.logging import get_logger
+ from .extras.misc import get_device_count
+ from .train.tuner import export_model, run_exp
+ from .webui.interface import run_web_demo, run_web_ui
+
+
+ USAGE = (
+     "-" * 70
+     + "\n"
+     + "| Usage:                                                              |\n"
+     + "| llamafactory-cli api -h: launch an OpenAI-style API server          |\n"
+     + "| llamafactory-cli chat -h: launch a chat interface in CLI            |\n"
+     + "| llamafactory-cli eval -h: evaluate models                           |\n"
+     + "| llamafactory-cli export -h: merge LoRA adapters and export model    |\n"
+     + "| llamafactory-cli train -h: train models                             |\n"
+     + "| llamafactory-cli webchat -h: launch a chat interface in Web UI      |\n"
+     + "| llamafactory-cli webui: launch LlamaBoard                           |\n"
+     + "| llamafactory-cli version: show version info                         |\n"
+     + "-" * 70
+ )
+
+ WELCOME = (
+     "-" * 58
+     + "\n"
+     + "| Welcome to LLaMA Factory, version {}".format(VERSION)
+     + " " * (21 - len(VERSION))
+     + "|\n|"
+     + " " * 56
+     + "|\n"
+     + "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n"
+     + "-" * 58
+ )
+
+ logger = get_logger(__name__)
+
+
+ @unique
+ class Command(str, Enum):
+     API = "api"
+     CHAT = "chat"
+     ENV = "env"
+     EVAL = "eval"
+     EXPORT = "export"
+     TRAIN = "train"
+     WEBDEMO = "webchat"
+     WEBUI = "webui"
+     VER = "version"
+     HELP = "help"
+
+
+ def main():
+     command = sys.argv.pop(1) if len(sys.argv) != 1 else Command.HELP
+     if command == Command.API:
+         run_api()
+     elif command == Command.CHAT:
+         run_chat()
+     elif command == Command.ENV:
+         print_env()
+     elif command == Command.EVAL:
+         run_eval()
+     elif command == Command.EXPORT:
+         export_model()
+     elif command == Command.TRAIN:
+         force_torchrun = os.environ.get("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
+         if force_torchrun or get_device_count() > 1:
+             master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
+             master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
+             logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port))
+             process = subprocess.run(
+                 (
+                     "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
+                     "--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
+                 ).format(
+                     nnodes=os.environ.get("NNODES", "1"),
+                     node_rank=os.environ.get("RANK", "0"),
+                     nproc_per_node=os.environ.get("NPROC_PER_NODE", str(get_device_count())),
+                     master_addr=master_addr,
+                     master_port=master_port,
+                     file_name=launcher.__file__,
+                     args=" ".join(sys.argv[1:]),
+                 ),
+                 shell=True,
+             )
+             sys.exit(process.returncode)
+         else:
+             run_exp()
+     elif command == Command.WEBDEMO:
+ run_web_demo()
114
+ elif command == Command.WEBUI:
115
+ run_web_ui()
116
+ elif command == Command.VER:
117
+ print(WELCOME)
118
+ elif command == Command.HELP:
119
+ print(USAGE)
120
+ else:
121
+ raise NotImplementedError("Unknown command: {}".format(command))
llama-factory/src/llamafactory/data/__init__.py ADDED
@@ -0,0 +1,31 @@
1
+ # Copyright 2024 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding, SFTDataCollatorWith4DAttentionMask
16
+ from .data_utils import Role, split_dataset
17
+ from .loader import get_dataset
18
+ from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
19
+
20
+
21
+ __all__ = [
22
+ "KTODataCollatorWithPadding",
23
+ "PairwiseDataCollatorWithPadding",
24
+ "SFTDataCollatorWith4DAttentionMask",
25
+ "Role",
26
+ "split_dataset",
27
+ "get_dataset",
28
+ "TEMPLATES",
29
+ "Template",
30
+ "get_template_and_fix_tokenizer",
31
+ ]
llama-factory/src/llamafactory/data/aligner.py ADDED
@@ -0,0 +1,239 @@
1
+ # Copyright 2024 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ from functools import partial
17
+ from typing import TYPE_CHECKING, Any, Dict, List, Union
18
+
19
+ from datasets import Features
20
+
21
+ from ..extras.logging import get_logger
22
+ from .data_utils import Role
23
+
24
+
25
+ if TYPE_CHECKING:
26
+ from datasets import Dataset, IterableDataset
27
+ from transformers import Seq2SeqTrainingArguments
28
+
29
+ from ..hparams import DataArguments
30
+ from .parser import DatasetAttr
31
+
32
+
33
+ logger = get_logger(__name__)
34
+
35
+
36
+ def _convert_images(images: List[Any], dataset_attr: "DatasetAttr", data_args: "DataArguments") -> List[Any]:
37
+ r"""
38
+ Optionally prepends the dataset directory to image paths when loading from local disk.
39
+ """
40
+ outputs = []
41
+ if dataset_attr.load_from in ["script", "file"]:
42
+ for image in images:
43
+ if isinstance(image, str) and os.path.isfile(os.path.join(data_args.dataset_dir, image)):
44
+ outputs.append(os.path.join(data_args.dataset_dir, image))
45
+ else:
46
+ outputs.append(image)
47
+
48
+ return outputs
49
+
50
+
51
+ def convert_alpaca(
52
+ examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
53
+ ) -> Dict[str, List[Any]]:
54
+ r"""
55
+ Converts alpaca format dataset to the standard format.
56
+ """
57
+ outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
58
+ convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
59
+ for i in range(len(examples[dataset_attr.prompt])):
60
+ prompt = []
61
+ if dataset_attr.history and isinstance(examples[dataset_attr.history][i], list):
62
+ for old_prompt, old_response in examples[dataset_attr.history][i]:
63
+ prompt.append({"role": Role.USER.value, "content": old_prompt})
64
+ prompt.append({"role": Role.ASSISTANT.value, "content": old_response})
65
+
66
+ content = []
67
+ if dataset_attr.prompt and examples[dataset_attr.prompt][i]:
68
+ content.append(examples[dataset_attr.prompt][i])
69
+
70
+ if dataset_attr.query and examples[dataset_attr.query][i]:
71
+ content.append(examples[dataset_attr.query][i])
72
+
73
+ prompt.append({"role": Role.USER.value, "content": "\n".join(content)}) # "prompt\nquery"
74
+
75
+ if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag][i], bool): # kto example
76
+ response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
77
+ if examples[dataset_attr.kto_tag][i]:
78
+ response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
79
+ else:
80
+ response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
81
+ elif (
82
+ dataset_attr.ranking
83
+ and isinstance(examples[dataset_attr.chosen][i], str)
84
+ and isinstance(examples[dataset_attr.rejected][i], str)
85
+ ): # pairwise example
86
+ response = [
87
+ {"role": Role.ASSISTANT.value, "content": examples[dataset_attr.chosen][i]},
88
+ {"role": Role.ASSISTANT.value, "content": examples[dataset_attr.rejected][i]},
89
+ ]
90
+ elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str): # normal example
91
+ response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
92
+ else: # unsupervised
93
+ response = []
94
+
95
+ outputs["prompt"].append(prompt)
96
+ outputs["response"].append(response)
97
+ outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
98
+ outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
99
+ outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
100
+
101
+ return outputs
102
+
103
+
104
+ def convert_sharegpt(
105
+ examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
106
+ ) -> Dict[str, List[Any]]:
107
+ r"""
108
+ Converts sharegpt format dataset to the standard format.
109
+ """
110
+ outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
111
+ convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
112
+ tag_mapping = {
113
+ dataset_attr.user_tag: Role.USER.value,
114
+ dataset_attr.assistant_tag: Role.ASSISTANT.value,
115
+ dataset_attr.observation_tag: Role.OBSERVATION.value,
116
+ dataset_attr.function_tag: Role.FUNCTION.value,
117
+ dataset_attr.system_tag: Role.SYSTEM.value,
118
+ }
119
+ odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)
120
+ even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
121
+ accept_tags = (odd_tags, even_tags)
122
+ for i, messages in enumerate(examples[dataset_attr.messages]):
123
+ if dataset_attr.system_tag and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag:
124
+ system = messages[0][dataset_attr.content_tag]
125
+ messages = messages[1:]
126
+ else:
127
+ system = examples[dataset_attr.system][i] if dataset_attr.system else ""
128
+
129
+ if len(messages) == 0:
130
+ continue
131
+
132
+ aligned_messages = []
133
+ broken_data = False
134
+ for turn_idx, message in enumerate(messages):
135
+ if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
136
+ logger.warning("Invalid role tag in {}.".format(messages))
137
+ broken_data = True
138
+
139
+ aligned_messages.append(
140
+ {"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
141
+ )
142
+
143
+ if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
144
+ dataset_attr.ranking and len(aligned_messages) % 2 == 0
145
+ ):
146
+ logger.warning("Invalid message count in {}.".format(messages))
147
+ broken_data = True
148
+
149
+ if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag][i], bool): # kto example
150
+ prompt = aligned_messages[:-1]
151
+ response = aligned_messages[-1:]
152
+ if examples[dataset_attr.kto_tag][i]:
153
+ response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
154
+ else:
155
+ response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
156
+ elif (
157
+ dataset_attr.ranking
158
+ and isinstance(examples[dataset_attr.chosen][i], dict)
159
+ and isinstance(examples[dataset_attr.rejected][i], dict)
160
+ ): # pairwise example
161
+ chosen = examples[dataset_attr.chosen][i]
162
+ rejected = examples[dataset_attr.rejected][i]
163
+ if (
164
+ chosen[dataset_attr.role_tag] not in accept_tags[-1]
165
+ or rejected[dataset_attr.role_tag] not in accept_tags[-1]
166
+ ):
167
+ logger.warning("Invalid role tag in {}.".format([chosen, rejected]))
168
+ broken_data = True
169
+
170
+ prompt = aligned_messages
171
+ response = [
172
+ {"role": tag_mapping[chosen[dataset_attr.role_tag]], "content": chosen[dataset_attr.content_tag]},
173
+ {"role": tag_mapping[rejected[dataset_attr.role_tag]], "content": rejected[dataset_attr.content_tag]},
174
+ ]
175
+ else: # normal example
176
+ prompt = aligned_messages[:-1]
177
+ response = aligned_messages[-1:]
178
+
179
+ if broken_data:
180
+ logger.warning("Skipping this abnormal example.")
181
+ continue
182
+
183
+ outputs["prompt"].append(prompt)
184
+ outputs["response"].append(response)
185
+ outputs["system"].append(system)
186
+ outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
187
+ outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
188
+
189
+ return outputs
190
+
191
+
192
+ def align_dataset(
193
+ dataset: Union["Dataset", "IterableDataset"],
194
+ dataset_attr: "DatasetAttr",
195
+ data_args: "DataArguments",
196
+ training_args: "Seq2SeqTrainingArguments",
197
+ ) -> Union["Dataset", "IterableDataset"]:
198
+ r"""
199
+ Aligned dataset:
200
+ prompt: [{"role": "user", "content": "..."}] * (2T - 1)
201
+ response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
202
+ system: "..."
203
+ tools: "...",
204
+ images: [],
205
+ """
206
+ if dataset_attr.formatting == "alpaca":
207
+ convert_func = partial(convert_alpaca, dataset_attr=dataset_attr, data_args=data_args)
208
+ else:
209
+ convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr, data_args=data_args)
210
+
211
+ column_names = list(next(iter(dataset)).keys())
212
+ features = Features.from_dict(
213
+ {
214
+ "prompt": [
215
+ {"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
216
+ ],
217
+ "response": [
218
+ {"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
219
+ ],
220
+ "system": {"dtype": "string", "_type": "Value"},
221
+ "tools": {"dtype": "string", "_type": "Value"},
222
+ "images": [{"_type": "Image"}],
223
+ }
224
+ )
225
+ kwargs = {}
226
+ if not data_args.streaming:
227
+ kwargs = dict(
228
+ num_proc=data_args.preprocessing_num_workers,
229
+ load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
230
+ desc="Converting format of dataset",
231
+ )
232
+
233
+ return dataset.map(
234
+ convert_func,
235
+ batched=True,
236
+ remove_columns=column_names,
237
+ features=features,
238
+ **kwargs,
239
+ )
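To make the alpaca branch concrete, here is the conversion replayed by hand on a toy row (pure Python, no LlamaFactory objects; the field names follow the `DatasetAttr` defaults): history turns expand into alternating user/assistant messages, and `instruction` plus `input` fuse into the final user message.

```python
alpaca_row = {
    "instruction": "Translate to French",
    "input": "Good morning",
    "output": "Bonjour",
    "history": [["Hi", "Hello!"]],
}

prompt = []
for old_prompt, old_response in alpaca_row["history"]:  # history -> alternating turns
    prompt.append({"role": "user", "content": old_prompt})
    prompt.append({"role": "assistant", "content": old_response})

# instruction and input fuse into the final user message ("prompt\nquery"):
prompt.append({"role": "user", "content": "\n".join([alpaca_row["instruction"], alpaca_row["input"]])})
response = [{"role": "assistant", "content": alpaca_row["output"]}]

assert prompt[-1]["content"] == "Translate to French\nGood morning"
assert len(prompt) % 2 == 1  # 2T - 1 messages, matching align_dataset's contract
```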
llama-factory/src/llamafactory/data/collator.py ADDED
@@ -0,0 +1,155 @@
1
+ # Copyright 2024 OpenAccess AI Collective and the LlamaFactory team.
2
+ #
3
+ # This code is inspired by the OpenAccess AI Collective's axolotl library.
4
+ # https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/monkeypatch/utils.py
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ from dataclasses import dataclass
19
+ from typing import Any, Dict, Literal, Sequence
20
+
21
+ import torch
22
+ from transformers import DataCollatorForSeq2Seq
23
+
24
+
25
+ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
26
+ r"""
27
+ Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
28
+ while handling packed sequences and transforming the mask to lower triangular form to prevent future peeking.
29
+
30
+ e.g.
31
+ ```python
32
+ # input
33
+ [[1, 1, 2, 2, 2, 0]]
34
+ # output
35
+ [
36
+ [
37
+ [
38
+ [o, x, x, x, x, x],
39
+ [o, o, x, x, x, x],
40
+ [x, x, o, x, x, x],
41
+ [x, x, o, o, x, x],
42
+ [x, x, o, o, o, x],
43
+ [x, x, x, x, x, x],
44
+ ]
45
+ ]
46
+ ]
47
+ ```
48
+ where `o` equals `0.0` and `x` equals `min_dtype`.
49
+ """
50
+ bsz, seq_len = attention_mask_with_indices.size()
51
+ min_dtype = torch.finfo(dtype).min
52
+ expanded_mask = attention_mask_with_indices[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
53
+ # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
54
+ padding_mask = torch.where(expanded_mask != 0, 1, 0)
55
+ # Create a block-diagonal mask.
56
+ attention_mask_4d = torch.eq(expanded_mask, expanded_mask.transpose(-1, -2)).int() * padding_mask
57
+ # Use the lower triangular mask to zero out the upper triangular part
58
+ attention_mask_4d *= torch.tril(torch.ones((seq_len, seq_len), dtype=torch.long))
59
+ # Invert the attention mask.
60
+ attention_mask_4d = torch.where(attention_mask_4d != 0, torch.tensor(0, dtype=dtype), min_dtype)
61
+ return attention_mask_4d
62
+
63
+
64
+ @dataclass
65
+ class SFTDataCollatorWith4DAttentionMask(DataCollatorForSeq2Seq):
66
+ r"""
67
+ Data collator for 4d attention mask.
68
+ """
69
+
70
+ block_diag_attn: bool = False
71
+ attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
72
+ compute_dtype: "torch.dtype" = torch.float32
73
+
74
+ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
75
+ features = super().__call__(features)
76
+ if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
77
+ features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
78
+
79
+ return features
80
+
81
+
82
+ @dataclass
83
+ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
84
+ r"""
85
+ Data collator for pairwise data.
86
+ """
87
+
88
+ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
89
+ r"""
90
+ Pads batched data to the longest sequence in the batch.
91
+
92
+ We generate 2 * n examples where the first n examples represent chosen examples and
93
+ the last n examples represent rejected examples.
94
+ """
95
+ concatenated_features = []
96
+ for key in ("chosen", "rejected"):
97
+ for feature in features:
98
+ target_feature = {
99
+ "input_ids": feature["{}_input_ids".format(key)],
100
+ "attention_mask": feature["{}_attention_mask".format(key)],
101
+ "labels": feature["{}_labels".format(key)],
102
+ }
103
+ if "pixel_values" in feature:
104
+ target_feature["pixel_values"] = feature["pixel_values"]
105
+
106
+ if "{}_token_type_ids".format(key) in feature:
107
+ target_feature["token_type_ids"] = feature["{}_token_type_ids".format(key)]
108
+
109
+ concatenated_features.append(target_feature)
110
+
111
+ return super().__call__(concatenated_features)
112
+
113
+
114
+ @dataclass
115
+ class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
116
+ r"""
117
+ Data collator for KTO data.
118
+ """
119
+
120
+ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
121
+ target_features = []
122
+ kl_features = []
123
+ kto_tags = []
124
+ for feature in features:
125
+ target_feature = {
126
+ "input_ids": feature["input_ids"],
127
+ "attention_mask": feature["attention_mask"],
128
+ "labels": feature["labels"],
129
+ }
130
+ kl_feature = {
131
+ "input_ids": feature["kl_input_ids"],
132
+ "attention_mask": feature["kl_attention_mask"],
133
+ "labels": feature["kl_labels"],
134
+ }
135
+ if "pixel_values" in feature:
136
+ target_feature["pixel_values"] = feature["pixel_values"]
137
+
138
+ if "token_type_ids" in feature:
139
+ target_feature["token_type_ids"] = feature["token_type_ids"]
140
+ kl_feature["token_type_ids"] = feature["kl_token_type_ids"]
141
+
142
+ target_features.append(target_feature)
143
+ kl_features.append(kl_feature)
144
+ kto_tags.append(feature["kto_tags"])
145
+
146
+ batch = super().__call__(target_features)
147
+ kl_batch = super().__call__(kl_features)
148
+ batch["kl_input_ids"] = kl_batch["input_ids"]
149
+ batch["kl_attention_mask"] = kl_batch["attention_mask"]
150
+ batch["kl_labels"] = kl_batch["labels"]
151
+ if "token_type_ids" in batch:
152
+ batch["kl_token_type_ids"] = kl_batch["token_type_ids"]
153
+
154
+ batch["kto_tags"] = torch.tensor(kto_tags)
155
+ return batch
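The docstring example in `prepare_4d_attention_mask` can be checked directly. This is a hedged sketch assuming the module resolves under the `llamafactory.data.collator` import path; the indices-style mask `[1, 1, 2, 2, 2, 0]` encodes two packed sequences followed by one padding position.

```python
import torch

from llamafactory.data.collator import prepare_4d_attention_mask

mask = torch.tensor([[1, 1, 2, 2, 2, 0]])  # two packed sequences plus padding
out = prepare_4d_attention_mask(mask, torch.float32)

print(out.shape)  # torch.Size([1, 1, 6, 6])
# Position 2 (start of the second sequence) may attend only to itself:
row = out[0, 0, 2]
print((row == 0).nonzero(as_tuple=True)[0])  # tensor([2])
```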
llama-factory/src/llamafactory/data/data_utils.py ADDED
@@ -0,0 +1,87 @@
1
+ # Copyright 2024 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from enum import Enum, unique
16
+ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, TypedDict, Union
17
+
18
+ from datasets import DatasetDict, concatenate_datasets, interleave_datasets
19
+
20
+ from ..extras.logging import get_logger
21
+
22
+
23
+ if TYPE_CHECKING:
24
+ from datasets import Dataset, IterableDataset
25
+
26
+ from ..hparams import DataArguments
27
+
28
+
29
+ logger = get_logger(__name__)
30
+
31
+
32
+ SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
33
+
34
+
35
+ @unique
36
+ class Role(str, Enum):
37
+ USER = "user"
38
+ ASSISTANT = "assistant"
39
+ SYSTEM = "system"
40
+ FUNCTION = "function"
41
+ OBSERVATION = "observation"
42
+
43
+
44
+ class DatasetModule(TypedDict):
45
+ train_dataset: Optional[Union["Dataset", "IterableDataset"]]
46
+ eval_dataset: Optional[Union["Dataset", "IterableDataset"]]
47
+
48
+
49
+ def merge_dataset(
50
+ all_datasets: List[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int
51
+ ) -> Union["Dataset", "IterableDataset"]:
52
+ if len(all_datasets) == 1:
53
+ return all_datasets[0]
54
+ elif data_args.mix_strategy == "concat":
55
+ if data_args.streaming:
56
+ logger.warning("The samples between different datasets will not be mixed in streaming mode.")
57
+
58
+ return concatenate_datasets(all_datasets)
59
+ elif data_args.mix_strategy.startswith("interleave"):
60
+ if not data_args.streaming:
61
+ logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
62
+
63
+ return interleave_datasets(
64
+ datasets=all_datasets,
65
+ probabilities=data_args.interleave_probs,
66
+ seed=seed,
67
+ stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
68
+ )
69
+ else:
70
+ raise ValueError("Unknown mixing strategy.")
71
+
72
+
73
+ def split_dataset(
74
+ dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", seed: int
75
+ ) -> "DatasetDict":
76
+ r"""
77
+ Splits the dataset and returns a dataset dict containing train set (required) and validation set (optional).
78
+ """
79
+ if data_args.streaming:
80
+ dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)
81
+ val_set = dataset.take(int(data_args.val_size))
82
+ train_set = dataset.skip(int(data_args.val_size))
83
+ return DatasetDict({"train": train_set, "validation": val_set})
84
+ else:
85
+ val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
86
+ dataset = dataset.train_test_split(test_size=val_size, seed=seed)
87
+ return DatasetDict({"train": dataset["train"], "validation": dataset["test"]})
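A minimal check of the non-streaming branch of `split_dataset` above, assuming the `datasets` library is installed; `FakeDataArgs` is a hypothetical stand-in carrying only the fields the branch reads.

```python
from dataclasses import dataclass

from datasets import Dataset


@dataclass
class FakeDataArgs:
    val_size: float = 0.25
    streaming: bool = False


data = Dataset.from_dict({"text": ["a", "b", "c", "d"]})
args = FakeDataArgs()

# val_size <= 1 is treated as a fraction, > 1 as an absolute sample count:
val_size = int(args.val_size) if args.val_size > 1 else args.val_size
split = data.train_test_split(test_size=val_size, seed=42)
print(len(split["train"]), len(split["test"]))  # 3 1
```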
llama-factory/src/llamafactory/data/formatter.py ADDED
@@ -0,0 +1,140 @@
1
+ # Copyright 2024 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import json
16
+ import re
17
+ from abc import ABC, abstractmethod
18
+ from dataclasses import dataclass, field
19
+ from typing import List, Literal, Optional, Tuple, Union
20
+
21
+ from .data_utils import SLOTS
22
+ from .tool_utils import DefaultToolUtils, GLM4ToolUtils
23
+
24
+
25
+ @dataclass
26
+ class Formatter(ABC):
27
+ slots: SLOTS = field(default_factory=list)
28
+ tool_format: Optional[Literal["default", "glm4"]] = None
29
+
30
+ @abstractmethod
31
+ def apply(self, **kwargs) -> SLOTS: ...
32
+
33
+ def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
34
+ raise NotImplementedError
35
+
36
+
37
+ @dataclass
38
+ class EmptyFormatter(Formatter):
39
+ def __post_init__(self):
40
+ has_placeholder = False
41
+ for slot in filter(lambda s: isinstance(s, str), self.slots):
42
+ if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot):
43
+ has_placeholder = True
44
+
45
+ if has_placeholder:
46
+ raise ValueError("Empty formatter should not contain any placeholder.")
47
+
48
+ def apply(self, **kwargs) -> SLOTS:
49
+ return self.slots
50
+
51
+
52
+ @dataclass
53
+ class StringFormatter(Formatter):
54
+ def __post_init__(self):
55
+ has_placeholder = False
56
+ for slot in filter(lambda s: isinstance(s, str), self.slots):
57
+ if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot):
58
+ has_placeholder = True
59
+
60
+ if not has_placeholder:
61
+ raise ValueError("A placeholder is required in the string formatter.")
62
+
63
+ def apply(self, **kwargs) -> SLOTS:
64
+ elements = []
65
+ for slot in self.slots:
66
+ if isinstance(slot, str):
67
+ for name, value in kwargs.items():
68
+ if not isinstance(value, str):
69
+ raise RuntimeError("Expected a string, got {}".format(value))
70
+
71
+ slot = slot.replace("{{" + name + "}}", value, 1)
72
+ elements.append(slot)
73
+ elif isinstance(slot, (dict, set)):
74
+ elements.append(slot)
75
+ else:
76
+ raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
77
+
78
+ return elements
79
+
80
+
81
+ @dataclass
82
+ class FunctionFormatter(Formatter):
83
+ def __post_init__(self):
84
+ if self.tool_format == "default":
85
+ self.slots = DefaultToolUtils.get_function_slots() + self.slots
86
+ elif self.tool_format == "glm4":
87
+ self.slots = GLM4ToolUtils.get_function_slots() + self.slots
88
+ else:
89
+ raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))
90
+
91
+ def apply(self, **kwargs) -> SLOTS:
92
+ content = kwargs.pop("content")
93
+ functions: List[Tuple[str, str]] = []
94
+ try:
95
+ tool_calls = json.loads(content)
96
+ if not isinstance(tool_calls, list): # a list encodes parallel function calls; wrap a single call
97
+ tool_calls = [tool_calls]
98
+
99
+ for tool_call in tool_calls:
100
+ functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
101
+
102
+ except json.JSONDecodeError:
103
+ functions = []
104
+
105
+ elements = []
106
+ for name, arguments in functions:
107
+ for slot in self.slots:
108
+ if isinstance(slot, str):
109
+ slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
110
+ elements.append(slot)
111
+ elif isinstance(slot, (dict, set)):
112
+ elements.append(slot)
113
+ else:
114
+ raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
115
+
116
+ return elements
117
+
118
+
119
+ @dataclass
120
+ class ToolFormatter(Formatter):
121
+ def __post_init__(self):
122
+ if self.tool_format == "default":
123
+ self._tool_formatter = DefaultToolUtils.tool_formatter
124
+ self._tool_extractor = DefaultToolUtils.tool_extractor
125
+ elif self.tool_format == "glm4":
126
+ self._tool_formatter = GLM4ToolUtils.tool_formatter
127
+ self._tool_extractor = GLM4ToolUtils.tool_extractor
128
+ else:
129
+ raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))
130
+
131
+ def apply(self, **kwargs) -> SLOTS:
132
+ content = kwargs.pop("content")
133
+ try:
134
+ tools = json.loads(content)
135
+ return [self._tool_formatter(tools) if len(tools) != 0 else ""]
136
+ except json.JSONDecodeError:
137
+ return [""]
138
+
139
+ def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
140
+ return self._tool_extractor(content)
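As a quick usage sketch of `StringFormatter` above (assuming the package is importable as `llamafactory`; the chat-template string is made up for illustration): each string slot has its `{{name}}` placeholders substituted, and the result is returned as a list of slots.

```python
from llamafactory.data.formatter import StringFormatter

formatter = StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"])
print(formatter.apply(content="What is 2 + 2?"))
# -> ['<|user|>\nWhat is 2 + 2?<|assistant|>\n']
```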
llama-factory/src/llamafactory/data/loader.py ADDED
@@ -0,0 +1,276 @@
1
+ # Copyright 2024 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import sys
17
+ from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union
18
+
19
+ import numpy as np
20
+ from datasets import DatasetDict, load_dataset, load_from_disk
21
+ from transformers.utils.versions import require_version
22
+
23
+ from ..extras.constants import FILEEXT2TYPE
24
+ from ..extras.logging import get_logger
25
+ from ..extras.misc import has_tokenized_data
26
+ from .aligner import align_dataset
27
+ from .data_utils import merge_dataset, split_dataset
28
+ from .parser import get_dataset_list
29
+ from .preprocess import get_preprocess_and_print_func
30
+ from .template import get_template_and_fix_tokenizer
31
+
32
+
33
+ if TYPE_CHECKING:
34
+ from datasets import Dataset, IterableDataset
35
+ from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments
36
+
37
+ from ..hparams import DataArguments, ModelArguments
38
+ from .data_utils import DatasetModule
39
+ from .parser import DatasetAttr
40
+ from .template import Template
41
+
42
+
43
+ logger = get_logger(__name__)
44
+
45
+
46
+ def _load_single_dataset(
47
+ dataset_attr: "DatasetAttr",
48
+ model_args: "ModelArguments",
49
+ data_args: "DataArguments",
50
+ training_args: "Seq2SeqTrainingArguments",
51
+ ) -> Union["Dataset", "IterableDataset"]:
52
+ logger.info("Loading dataset {}...".format(dataset_attr))
53
+ data_path, data_name, data_dir, data_files = None, None, None, None
54
+ if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
55
+ data_path = dataset_attr.dataset_name
56
+ data_name = dataset_attr.subset
57
+ data_dir = dataset_attr.folder
58
+
59
+ elif dataset_attr.load_from == "script":
60
+ data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
61
+ data_name = dataset_attr.subset
62
+ data_dir = dataset_attr.folder
63
+
64
+ elif dataset_attr.load_from == "file":
65
+ data_files = []
66
+ local_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
67
+ if os.path.isdir(local_path): # is directory
68
+ for file_name in os.listdir(local_path):
69
+ data_files.append(os.path.join(local_path, file_name))
70
+ if data_path is None:
71
+ data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
72
+ elif data_path != FILEEXT2TYPE.get(file_name.split(".")[-1], None):
73
+ raise ValueError("File types should be identical.")
74
+ elif os.path.isfile(local_path): # is file
75
+ data_files.append(local_path)
76
+ data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
77
+ else:
78
+ raise ValueError("File {} not found.".format(local_path))
79
+
80
+ if data_path is None:
81
+ raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))
82
+ else:
83
+ raise NotImplementedError("Unknown load type: {}.".format(dataset_attr.load_from))
84
+
85
+ if dataset_attr.load_from == "ms_hub":
86
+ require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
87
+ from modelscope import MsDataset
88
+ from modelscope.utils.config_ds import MS_DATASETS_CACHE
89
+
90
+ cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
91
+ dataset = MsDataset.load(
92
+ dataset_name=data_path,
93
+ subset_name=data_name,
94
+ data_dir=data_dir,
95
+ data_files=data_files,
96
+ split=dataset_attr.split,
97
+ cache_dir=cache_dir,
98
+ token=model_args.ms_hub_token,
99
+ use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
100
+ )
101
+ if isinstance(dataset, MsDataset):
102
+ dataset = dataset.to_hf_dataset()
103
+ else:
104
+ dataset = load_dataset(
105
+ path=data_path,
106
+ name=data_name,
107
+ data_dir=data_dir,
108
+ data_files=data_files,
109
+ split=dataset_attr.split,
110
+ cache_dir=model_args.cache_dir,
111
+ token=model_args.hf_hub_token,
112
+ streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
113
+ trust_remote_code=True,
114
+ )
115
+
116
+ if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
117
+ dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
118
+
119
+ if dataset_attr.num_samples is not None and not data_args.streaming:
120
+ target_num = dataset_attr.num_samples
121
+ indexes = np.random.permutation(len(dataset))[:target_num]
122
+ target_num -= len(indexes)
123
+ if target_num > 0:
124
+ expand_indexes = np.random.choice(len(dataset), target_num)
125
+ indexes = np.concatenate((indexes, expand_indexes), axis=0)
126
+
127
+ assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched."
128
+ dataset = dataset.select(indexes)
129
+ logger.info("Sampled {} examples from dataset {}.".format(dataset_attr.num_samples, dataset_attr))
130
+
131
+ if data_args.max_samples is not None: # truncate dataset
132
+ max_samples = min(data_args.max_samples, len(dataset))
133
+ dataset = dataset.select(range(max_samples))
134
+
135
+ return align_dataset(dataset, dataset_attr, data_args, training_args)
136
+
137
+
138
+ def _get_merged_dataset(
139
+ dataset_names: Optional[Sequence[str]],
140
+ model_args: "ModelArguments",
141
+ data_args: "DataArguments",
142
+ training_args: "Seq2SeqTrainingArguments",
143
+ stage: Literal["pt", "sft", "rm", "ppo", "kto"],
144
+ ) -> Optional[Union["Dataset", "IterableDataset"]]:
145
+ if dataset_names is None:
146
+ return None
147
+
148
+ datasets = []
149
+ for dataset_attr in get_dataset_list(dataset_names, data_args.dataset_dir):
150
+ if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
151
+ raise ValueError("The dataset is not applicable in the current training stage.")
152
+
153
+ datasets.append(_load_single_dataset(dataset_attr, model_args, data_args, training_args))
154
+
155
+ return merge_dataset(datasets, data_args, seed=training_args.seed)
156
+
157
+
158
+ def _get_preprocessed_dataset(
159
+ dataset: Optional[Union["Dataset", "IterableDataset"]],
160
+ data_args: "DataArguments",
161
+ training_args: "Seq2SeqTrainingArguments",
162
+ stage: Literal["pt", "sft", "rm", "ppo", "kto"],
163
+ template: "Template",
164
+ tokenizer: "PreTrainedTokenizer",
165
+ processor: Optional["ProcessorMixin"] = None,
166
+ is_eval: bool = False,
167
+ ) -> Optional[Union["Dataset", "IterableDataset"]]:
168
+ if dataset is None:
169
+ return None
170
+
171
+ preprocess_func, print_function = get_preprocess_and_print_func(
172
+ data_args, stage, template, tokenizer, processor, do_generate=(training_args.predict_with_generate and is_eval)
173
+ )
174
+ column_names = list(next(iter(dataset)).keys())
175
+ kwargs = {}
176
+ if not data_args.streaming:
177
+ kwargs = dict(
178
+ num_proc=data_args.preprocessing_num_workers,
179
+ load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
180
+ desc="Running tokenizer on dataset",
181
+ )
182
+
183
+ dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)
184
+
185
+ if training_args.should_log:
186
+ try:
187
+ print("eval example:" if is_eval else "training example:")
188
+ print_function(next(iter(dataset)))
189
+ except StopIteration:
190
+ if stage == "pt":
191
+ raise RuntimeError("Cannot find sufficient samples, consider increasing dataset size.")
192
+ else:
193
+ raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")
194
+
195
+ return dataset
196
+
197
+
198
+ def get_dataset(
199
+ model_args: "ModelArguments",
200
+ data_args: "DataArguments",
201
+ training_args: "Seq2SeqTrainingArguments",
202
+ stage: Literal["pt", "sft", "rm", "ppo", "kto"],
203
+ tokenizer: "PreTrainedTokenizer",
204
+ processor: Optional["ProcessorMixin"] = None,
205
+ ) -> "DatasetModule":
206
+ template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format)
207
+ if data_args.train_on_prompt and template.efficient_eos:
208
+ raise ValueError("Current template does not support `train_on_prompt`.")
209
+
210
+ # Load tokenized dataset
211
+ if data_args.tokenized_path is not None:
212
+ if has_tokenized_data(data_args.tokenized_path):
213
+ logger.warning("Loading dataset from disk will ignore other data arguments.")
214
+ dataset_dict: "DatasetDict" = load_from_disk(data_args.tokenized_path)
215
+ logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))
216
+
217
+ dataset_module: Dict[str, "Dataset"] = {}
218
+ if "train" in dataset_dict:
219
+ dataset_module["train_dataset"] = dataset_dict["train"]
220
+ if "validation" in dataset_dict:
221
+ dataset_module["eval_dataset"] = dataset_dict["validation"]
222
+
223
+ if data_args.streaming:
224
+ dataset_module = {k: v.to_iterable_dataset() for k, v in dataset_module.items()}
225
+
226
+ return dataset_module
227
+
228
+ if data_args.streaming:
229
+ raise ValueError("Turn off `streaming` when saving dataset to disk.")
230
+
231
+ # Load and preprocess dataset
232
+ with training_args.main_process_first(desc="load dataset"):
233
+ dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage)
234
+ eval_dataset = _get_merged_dataset(data_args.eval_dataset, model_args, data_args, training_args, stage)
235
+
236
+ with training_args.main_process_first(desc="pre-process dataset"):
237
+ dataset = _get_preprocessed_dataset(
238
+ dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=False
239
+ )
240
+ eval_dataset = _get_preprocessed_dataset(
241
+ eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
242
+ )
243
+
244
+ if data_args.val_size > 1e-6:
245
+ dataset_dict = split_dataset(dataset, data_args, seed=training_args.seed)
246
+ else:
247
+ dataset_dict = {}
248
+ if dataset is not None:
249
+ if data_args.streaming:
250
+ dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
251
+
252
+ dataset_dict["train"] = dataset
253
+
254
+ if eval_dataset is not None:
255
+ if data_args.streaming:
256
+ eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
257
+
258
+ dataset_dict["validation"] = eval_dataset
259
+
260
+ dataset_dict = DatasetDict(dataset_dict)
261
+
262
+ if data_args.tokenized_path is not None:
263
+ if training_args.should_save:
264
+ dataset_dict.save_to_disk(data_args.tokenized_path)
265
+ logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
266
+ logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path))
267
+
268
+ sys.exit(0)
269
+
270
+ dataset_module = {}
271
+ if "train" in dataset_dict:
272
+ dataset_module["train_dataset"] = dataset_dict["train"]
273
+ if "validation" in dataset_dict:
274
+ dataset_module["eval_dataset"] = dataset_dict["validation"]
275
+
276
+ return dataset_module
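The `num_samples` branch in `_load_single_dataset` samples without replacement first and then tops up with replacement when the requested count exceeds the dataset size; the same logic replayed standalone with numpy:

```python
import numpy as np

dataset_len, target_num = 5, 8
indexes = np.random.permutation(dataset_len)[:target_num]  # at most 5 unique indices
target_num -= len(indexes)
if target_num > 0:
    extra = np.random.choice(dataset_len, target_num)  # 3 more, drawn with replacement
    indexes = np.concatenate((indexes, extra), axis=0)

assert len(indexes) == 8
print(indexes)
```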
llama-factory/src/llamafactory/data/parser.py ADDED
@@ -0,0 +1,153 @@
1
+ # Copyright 2024 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import json
16
+ import os
17
+ from dataclasses import dataclass
18
+ from typing import Any, Dict, List, Literal, Optional, Sequence
19
+
20
+ from transformers.utils import cached_file
21
+
22
+ from ..extras.constants import DATA_CONFIG
23
+ from ..extras.misc import use_modelscope
24
+
25
+
26
+ @dataclass
27
+ class DatasetAttr:
28
+ r"""
29
+ Dataset attributes.
30
+ """
31
+
32
+ # basic configs
33
+ load_from: Literal["hf_hub", "ms_hub", "script", "file"]
34
+ dataset_name: str
35
+ formatting: Literal["alpaca", "sharegpt"] = "alpaca"
36
+ ranking: bool = False
37
+ # extra configs
38
+ subset: Optional[str] = None
39
+ split: str = "train"
40
+ folder: Optional[str] = None
41
+ num_samples: Optional[int] = None
42
+ # common columns
43
+ system: Optional[str] = None
44
+ tools: Optional[str] = None
45
+ images: Optional[str] = None
46
+ # rlhf columns
47
+ chosen: Optional[str] = None
48
+ rejected: Optional[str] = None
49
+ kto_tag: Optional[str] = None
50
+ # alpaca columns
51
+ prompt: Optional[str] = "instruction"
52
+ query: Optional[str] = "input"
53
+ response: Optional[str] = "output"
54
+ history: Optional[str] = None
55
+ # sharegpt columns
56
+ messages: Optional[str] = "conversations"
57
+ # sharegpt tags
58
+ role_tag: Optional[str] = "from"
59
+ content_tag: Optional[str] = "value"
60
+ user_tag: Optional[str] = "human"
61
+ assistant_tag: Optional[str] = "gpt"
62
+ observation_tag: Optional[str] = "observation"
63
+ function_tag: Optional[str] = "function_call"
64
+ system_tag: Optional[str] = "system"
65
+
66
+ def __repr__(self) -> str:
67
+ return self.dataset_name
68
+
69
+ def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None:
70
+ setattr(self, key, obj.get(key, default))
71
+
72
+
73
+ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> List["DatasetAttr"]:
74
+ r"""
75
+ Gets the attributes of the datasets.
76
+ """
77
+ if dataset_names is None:
78
+ dataset_names = []
79
+
80
+ if dataset_dir == "ONLINE":
81
+ dataset_info = None
82
+ else:
83
+ if dataset_dir.startswith("REMOTE:"):
84
+ config_path = cached_file(path_or_repo_id=dataset_dir[7:], filename=DATA_CONFIG, repo_type="dataset")
85
+ else:
86
+ config_path = os.path.join(dataset_dir, DATA_CONFIG)
87
+
88
+ try:
89
+ with open(config_path, "r") as f:
90
+ dataset_info = json.load(f)
91
+ except Exception as err:
92
+ if len(dataset_names) != 0:
93
+ raise ValueError("Cannot open {} due to {}.".format(config_path, str(err)))
94
+
95
+ dataset_info = None
96
+
97
+ dataset_list: List["DatasetAttr"] = []
98
+ for name in dataset_names:
99
+ if dataset_info is None: # dataset_dir is ONLINE
100
+ load_from = "ms_hub" if use_modelscope() else "hf_hub"
101
+ dataset_attr = DatasetAttr(load_from, dataset_name=name)
102
+ dataset_list.append(dataset_attr)
103
+ continue
104
+
105
+ if name not in dataset_info:
106
+ raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))
107
+
108
+ has_hf_url = "hf_hub_url" in dataset_info[name]
109
+ has_ms_url = "ms_hub_url" in dataset_info[name]
110
+
111
+ if has_hf_url or has_ms_url:
112
+ if (use_modelscope() and has_ms_url) or (not has_hf_url):
113
+ dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"])
114
+ else:
115
+ dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
116
+ elif "script_url" in dataset_info[name]:
117
+ dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
118
+ else:
119
+ dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])
120
+
121
+ dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
122
+ dataset_attr.set_attr("ranking", dataset_info[name], default=False)
123
+ dataset_attr.set_attr("subset", dataset_info[name])
124
+ dataset_attr.set_attr("split", dataset_info[name], default="train")
125
+ dataset_attr.set_attr("folder", dataset_info[name])
126
+ dataset_attr.set_attr("num_samples", dataset_info[name])
127
+
128
+ if "columns" in dataset_info[name]:
129
+ column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"]
130
+ if dataset_attr.formatting == "alpaca":
131
+ column_names.extend(["prompt", "query", "response", "history"])
132
+ else:
133
+ column_names.extend(["messages"])
134
+
135
+ for column_name in column_names:
136
+ dataset_attr.set_attr(column_name, dataset_info[name]["columns"])
137
+
138
+ if dataset_attr.formatting == "sharegpt" and "tags" in dataset_info[name]:
139
+ tag_names = (
140
+ "role_tag",
141
+ "content_tag",
142
+ "user_tag",
143
+ "assistant_tag",
144
+ "observation_tag",
145
+ "function_tag",
146
+ "system_tag",
147
+ )
148
+ for tag in tag_names:
149
+ dataset_attr.set_attr(tag, dataset_info[name]["tags"])
150
+
151
+ dataset_list.append(dataset_attr)
152
+
153
+ return dataset_list
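For reference, a `dataset_info.json` entry of the shape `get_dataset_list` parses above; the dataset and file names are made up, and the fields shown are not an exhaustive schema:

```python
import json

dataset_info = {
    "my_sharegpt_data": {  # hypothetical dataset name
        "file_name": "my_data.json",
        "formatting": "sharegpt",
        "columns": {"messages": "conversations", "system": "system"},
        "tags": {"role_tag": "from", "content_tag": "value", "user_tag": "human", "assistant_tag": "gpt"},
    }
}
print(json.dumps(dataset_info, indent=2))
# get_dataset_list(["my_sharegpt_data"], dataset_dir) would turn this into
# DatasetAttr(load_from="file", dataset_name="my_data.json", formatting="sharegpt", ...)
```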
llama-factory/src/llamafactory/data/preprocess.py ADDED
@@ -0,0 +1,110 @@
1
+ # Copyright 2024 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from functools import partial
16
+ from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple
17
+
18
+ from .processors.feedback import preprocess_feedback_dataset
19
+ from .processors.pairwise import preprocess_pairwise_dataset, print_pairwise_dataset_example
20
+ from .processors.pretrain import preprocess_pretrain_dataset
21
+ from .processors.supervised import (
22
+ preprocess_packed_supervised_dataset,
23
+ preprocess_supervised_dataset,
24
+ print_supervised_dataset_example,
25
+ )
26
+ from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsupervised_dataset_example
27
+
28
+
29
+ if TYPE_CHECKING:
30
+ from transformers import PreTrainedTokenizer, ProcessorMixin
31
+
32
+ from ..hparams import DataArguments
33
+ from .template import Template
34
+
35
+
36
+ def get_preprocess_and_print_func(
37
+ data_args: "DataArguments",
38
+ stage: Literal["pt", "sft", "rm", "ppo", "kto"],
39
+ template: "Template",
40
+ tokenizer: "PreTrainedTokenizer",
41
+ processor: Optional["ProcessorMixin"],
42
+ do_generate: bool = False,
43
+ ) -> Tuple[Callable, Callable]:
44
+ if stage == "pt":
45
+ preprocess_func = partial(
46
+ preprocess_pretrain_dataset,
47
+ tokenizer=tokenizer,
48
+ data_args=data_args,
49
+ )
50
+ print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
51
+ elif stage == "sft" and not do_generate:
52
+ if data_args.packing:
53
+ if data_args.neat_packing:
54
+ from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence
55
+
56
+ def __init__(self, data, **kwargs):
57
+ return TypedSequence.__init__(
58
+ self,
59
+ data,
60
+ type=kwargs.pop("type", None),
61
+ try_type=kwargs.pop("try_type", None),
62
+ optimized_int_type=kwargs.pop("optimized_int_type", None),
63
+ )
64
+
65
+ OptimizedTypedSequence.__init__ = __init__
66
+ preprocess_func = partial(
67
+ preprocess_packed_supervised_dataset,
68
+ template=template,
69
+ tokenizer=tokenizer,
70
+ data_args=data_args,
71
+ )
72
+ else:
73
+ preprocess_func = partial(
74
+ preprocess_supervised_dataset,
75
+ template=template,
76
+ tokenizer=tokenizer,
77
+ processor=processor,
78
+ data_args=data_args,
79
+ )
80
+
81
+ print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
82
+ elif stage == "rm":
83
+ preprocess_func = partial(
84
+ preprocess_pairwise_dataset,
85
+ template=template,
86
+ tokenizer=tokenizer,
87
+ processor=processor,
88
+ data_args=data_args,
89
+ )
90
+ print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
91
+ elif stage == "kto":
92
+ preprocess_func = partial(
93
+ preprocess_feedback_dataset,
94
+ template=template,
95
+ tokenizer=tokenizer,
96
+ processor=processor,
97
+ data_args=data_args,
98
+ )
99
+ print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
100
+ else:
101
+ preprocess_func = partial(
102
+ preprocess_unsupervised_dataset,
103
+ template=template,
104
+ tokenizer=tokenizer,
105
+ processor=processor,
106
+ data_args=data_args,
107
+ )
108
+ print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
109
+
110
+ return preprocess_func, print_function
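A hedged reading of the `neat_packing` monkeypatch above: the `datasets` library's `OptimizedTypedSequence` normally downcasts well-known columns (e.g. storing `attention_mask` as int8), which is safe for 0/1 masks but not for neat packing, where the mask column carries sequence indices. The override drops that per-column downcast. A toy illustration of the overflow risk:

```python
import numpy as np

packed_mask = list(range(1, 200))  # 199 packed sequences -> indices up to 199
assert max(packed_mask) > np.iinfo(np.int8).max  # int8 tops out at 127
print(np.array(packed_mask, dtype=np.int16).max())  # a wider dtype is needed
```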
llama-factory/src/llamafactory/data/processors/__init__.py ADDED
File without changes
llama-factory/src/llamafactory/data/processors/feedback.py ADDED
@@ -0,0 +1,143 @@
1
+ # Copyright 2024 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
+
+ from ...extras.constants import IGNORE_INDEX
+ from ...extras.logging import get_logger
+ from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
+
+
+ if TYPE_CHECKING:
+     from transformers import PreTrainedTokenizer, ProcessorMixin
+
+     from ...hparams import DataArguments
+     from ..template import Template
+
+
+ logger = get_logger(__name__)
+
+
+ def _encode_feedback_example(
+     prompt: Sequence[Dict[str, str]],
+     response: Sequence[Dict[str, str]],
+     kl_response: Sequence[Dict[str, str]],
+     system: Optional[str],
+     tools: Optional[str],
+     template: "Template",
+     tokenizer: "PreTrainedTokenizer",
+     processor: Optional["ProcessorMixin"],
+     data_args: "DataArguments",
+ ) -> Tuple[List[int], List[int], List[int], List[int], bool]:
+     if processor is not None and not hasattr(processor, "image_seq_length"):  # llava-like models
+         prompt[0]["content"] = template.image_token + prompt[0]["content"]
+
+     if response[0]["content"]:  # desired example
+         kto_tag = True
+         messages = prompt + [response[0]]
+     else:  # undesired example
+         kto_tag = False
+         messages = prompt + [response[1]]
+
+     if kl_response[0]["content"]:
+         kl_messages = prompt + [kl_response[0]]
+     else:
+         kl_messages = prompt + [kl_response[1]]
+
+     prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools)
+     kl_prompt_ids, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools)
+
+     if template.efficient_eos:
+         response_ids += [tokenizer.eos_token_id]
+         kl_response_ids += [tokenizer.eos_token_id]
+
+     if processor is not None and hasattr(processor, "image_seq_length"):  # paligemma models
+         image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
+         prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
+         kl_prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + kl_prompt_ids
+
+     source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), data_args.cutoff_len)
+     prompt_ids = prompt_ids[:source_len]
+     response_ids = response_ids[:target_len]
+     kl_source_len, kl_target_len = infer_seqlen(len(kl_prompt_ids), len(kl_response_ids), data_args.cutoff_len)
+     kl_prompt_ids = kl_prompt_ids[:kl_source_len]
+     kl_response_ids = kl_response_ids[:kl_target_len]
+
+     input_ids = prompt_ids + response_ids
+     labels = [IGNORE_INDEX] * source_len + response_ids
+     kl_input_ids = kl_prompt_ids + kl_response_ids
+     kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids
+
+     return input_ids, labels, kl_input_ids, kl_labels, kto_tag
+
+
+ def preprocess_feedback_dataset(
+     examples: Dict[str, List[Any]],
+     template: "Template",
+     tokenizer: "PreTrainedTokenizer",
+     processor: Optional["ProcessorMixin"],
+     data_args: "DataArguments",
+ ) -> Dict[str, List[List[int]]]:
+     # create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
+     kl_response = examples["response"][::-1]
+     model_inputs = {
+         "input_ids": [],
+         "attention_mask": [],
+         "labels": [],
+         "kl_input_ids": [],
+         "kl_attention_mask": [],
+         "kl_labels": [],
+         "kto_tags": [],
+     }
+     if processor is not None:
+         model_inputs["pixel_values"] = []
+         if hasattr(processor, "image_seq_length"):  # paligemma models
+             model_inputs["token_type_ids"] = []
+             model_inputs["kl_token_type_ids"] = []
+
+     for i in range(len(examples["prompt"])):
+         if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
+             logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
+             continue
+
+         input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
+             prompt=examples["prompt"][i],
+             response=examples["response"][i],
+             kl_response=kl_response[i],
+             system=examples["system"][i],
+             tools=examples["tools"][i],
+             template=template,
+             tokenizer=tokenizer,
+             processor=processor,
+             data_args=data_args,
+         )
+         model_inputs["input_ids"].append(input_ids)
+         model_inputs["attention_mask"].append([1] * len(input_ids))
+         model_inputs["labels"].append(labels)
+         model_inputs["kl_input_ids"].append(kl_input_ids)
+         model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
+         model_inputs["kl_labels"].append(kl_labels)
+         model_inputs["kto_tags"].append(kto_tag)
+         if processor is not None:
+             model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
+             if hasattr(processor, "image_seq_length"):  # paligemma models
+                 model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
+                 model_inputs["kl_token_type_ids"].append(get_paligemma_token_type_ids(len(kl_input_ids), processor))
+
+     desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
+     undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
+     if desirable_num == 0 or undesirable_num == 0:
+         logger.warning("Your dataset only has one preference type.")
+
+     return model_inputs
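
A minimal sketch of the KL-pair trick above (illustration only, not part of this commit): reversing the batch's `response` column pairs every prompt with an unrelated completion, which is what the KTO-style feedback processor uses to estimate its reference KL term. The toy batch below is hypothetical.

```python
# Toy batch (hypothetical data) showing how preprocess_feedback_dataset pairs
# each prompt with an unrelated response to build the KL estimate.
examples = {
    "prompt": [[{"role": "user", "content": "Q1"}], [{"role": "user", "content": "Q2"}]],
    "response": [
        [{"role": "assistant", "content": "good A1"}, {"role": "assistant", "content": ""}],
        [{"role": "assistant", "content": ""}, {"role": "assistant", "content": "bad A2"}],
    ],
}
kl_response = examples["response"][::-1]  # the same flip as in the function above
for prompt, response, kl in zip(examples["prompt"], examples["response"], kl_response):
    kto_tag = bool(response[0]["content"])  # desired iff the first slot is non-empty
    print(prompt[0]["content"], "->", "desired" if kto_tag else "undesired", "| KL partner:", kl)
```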
llama-factory/src/llamafactory/data/processors/pairwise.py ADDED
@@ -0,0 +1,139 @@
+ # Copyright 2024 the LlamaFactory team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
+
+ from ...extras.constants import IGNORE_INDEX
+ from ...extras.logging import get_logger
+ from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
+
+
+ if TYPE_CHECKING:
+     from transformers import PreTrainedTokenizer, ProcessorMixin
+
+     from ...hparams import DataArguments
+     from ..template import Template
+
+
+ logger = get_logger(__name__)
+
+
+ def _encode_pairwise_example(
+     prompt: Sequence[Dict[str, str]],
+     response: Sequence[Dict[str, str]],
+     system: Optional[str],
+     tools: Optional[str],
+     template: "Template",
+     tokenizer: "PreTrainedTokenizer",
+     processor: Optional["ProcessorMixin"],
+     data_args: "DataArguments",
+ ) -> Tuple[List[int], List[int], List[int], List[int]]:
+     if processor is not None and not hasattr(processor, "image_seq_length"):  # llava-like models
+         prompt[0]["content"] = template.image_token + prompt[0]["content"]
+
+     chosen_messages = prompt + [response[0]]
+     rejected_messages = prompt + [response[1]]
+     prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools)
+     _, rejected_ids = template.encode_oneturn(tokenizer, rejected_messages, system, tools)
+
+     if template.efficient_eos:
+         chosen_ids += [tokenizer.eos_token_id]
+         rejected_ids += [tokenizer.eos_token_id]
+
+     if processor is not None and hasattr(processor, "image_seq_length"):  # paligemma models
+         image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
+         prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
+
+     source_len, target_len = infer_seqlen(
+         len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), data_args.cutoff_len
+     )  # consider that the response is more important
+     prompt_ids = prompt_ids[:source_len]
+     chosen_ids = chosen_ids[:target_len]
+     rejected_ids = rejected_ids[:target_len]
+
+     chosen_input_ids = prompt_ids + chosen_ids
+     chosen_labels = [IGNORE_INDEX] * source_len + chosen_ids
+     rejected_input_ids = prompt_ids + rejected_ids
+     rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids
+
+     return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels
+
+
+ def preprocess_pairwise_dataset(
+     examples: Dict[str, List[Any]],
+     template: "Template",
+     tokenizer: "PreTrainedTokenizer",
+     processor: Optional["ProcessorMixin"],
+     data_args: "DataArguments",
+ ) -> Dict[str, List[List[int]]]:
+     # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
+     model_inputs = {
+         "chosen_input_ids": [],
+         "chosen_attention_mask": [],
+         "chosen_labels": [],
+         "rejected_input_ids": [],
+         "rejected_attention_mask": [],
+         "rejected_labels": [],
+     }
+     if processor is not None:
+         model_inputs["pixel_values"] = []
+         if hasattr(processor, "image_seq_length"):  # paligemma models
+             model_inputs["chosen_token_type_ids"] = []
+             model_inputs["rejected_token_type_ids"] = []
+
+     for i in range(len(examples["prompt"])):
+         if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
+             logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
+             continue
+
+         chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(
+             prompt=examples["prompt"][i],
+             response=examples["response"][i],
+             system=examples["system"][i],
+             tools=examples["tools"][i],
+             template=template,
+             tokenizer=tokenizer,
+             processor=processor,
+             data_args=data_args,
+         )
+         model_inputs["chosen_input_ids"].append(chosen_input_ids)
+         model_inputs["chosen_attention_mask"].append([1] * len(chosen_input_ids))
+         model_inputs["chosen_labels"].append(chosen_labels)
+         model_inputs["rejected_input_ids"].append(rejected_input_ids)
+         model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids))
+         model_inputs["rejected_labels"].append(rejected_labels)
+         if processor is not None:
+             model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
+             if hasattr(processor, "image_seq_length"):  # paligemma models
+                 model_inputs["chosen_token_type_ids"].append(
+                     get_paligemma_token_type_ids(len(chosen_input_ids), processor)
+                 )
+                 model_inputs["rejected_token_type_ids"].append(
+                     get_paligemma_token_type_ids(len(rejected_input_ids), processor)
+                 )
+
+     return model_inputs
+
+
+ def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
+     valid_chosen_labels = list(filter(lambda x: x != IGNORE_INDEX, example["chosen_labels"]))
+     valid_rejected_labels = list(filter(lambda x: x != IGNORE_INDEX, example["rejected_labels"]))
+     print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
+     print("chosen_inputs:\n{}".format(tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False)))
+     print("chosen_label_ids:\n{}".format(example["chosen_labels"]))
+     print("chosen_labels:\n{}".format(tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)))
+     print("rejected_input_ids:\n{}".format(example["rejected_input_ids"]))
+     print("rejected_inputs:\n{}".format(tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False)))
+     print("rejected_label_ids:\n{}".format(example["rejected_labels"]))
+     print("rejected_labels:\n{}".format(tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)))
llama-factory/src/llamafactory/data/processors/pretrain.py ADDED
@@ -0,0 +1,54 @@
+ # Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+ #
+ # This code is inspired by HuggingFace's transformers library.
+ # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from itertools import chain
+ from typing import TYPE_CHECKING, Any, Dict, List
+
+
+ if TYPE_CHECKING:
+     from transformers import PreTrainedTokenizer
+
+     from ...hparams import DataArguments
+
+
+ def preprocess_pretrain_dataset(
+     examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
+ ) -> Dict[str, List[List[int]]]:
+     # build grouped texts with format `X1 X2 X3 ...` if packing is enabled
+     eos_token = "<|end_of_text|>" if data_args.template == "llama3" else tokenizer.eos_token
+     text_examples = [messages[0]["content"] + eos_token for messages in examples["prompt"]]
+
+     if not data_args.packing:
+         if data_args.template == "gemma":
+             text_examples = [tokenizer.bos_token + example for example in text_examples]
+
+         result = tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len, truncation=True)
+     else:
+         tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
+         concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
+         total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
+         block_size = data_args.cutoff_len
+         total_length = (total_length // block_size) * block_size
+         result = {
+             k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
+             for k, t in concatenated_examples.items()
+         }
+         if data_args.template == "gemma":
+             for i in range(len(result["input_ids"])):
+                 result["input_ids"][i][0] = tokenizer.bos_token_id
+
+     return result
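
A toy illustration of the packing branch above (not part of this commit): all tokenized examples are concatenated, then sliced into fixed-size blocks, with the incomplete tail dropped.

```python
from itertools import chain

# Hypothetical tokenizer output: three examples of lengths 3, 2 and 4.
tokenized = {"input_ids": [[1, 2, 3], [4, 5], [6, 7, 8, 9]]}
concatenated = {k: list(chain(*v)) for k, v in tokenized.items()}
block_size = 4
total_length = (len(concatenated["input_ids"]) // block_size) * block_size
blocks = {
    k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
    for k, t in concatenated.items()
}
print(blocks)  # {'input_ids': [[1, 2, 3, 4], [5, 6, 7, 8]]} -- token 9 is dropped
```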
llama-factory/src/llamafactory/data/processors/processor_utils.py ADDED
@@ -0,0 +1,95 @@
+ # Copyright 2024 the LlamaFactory team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import bisect
+ from typing import TYPE_CHECKING, List, Sequence, Tuple
+
+ from ...extras.packages import is_pillow_available
+
+
+ if is_pillow_available():
+     from PIL import Image
+
+
+ if TYPE_CHECKING:
+     from numpy.typing import NDArray
+     from PIL.Image import Image as ImageObject
+     from transformers import ProcessorMixin
+     from transformers.image_processing_utils import BaseImageProcessor
+
+
+ def search_for_fit(numbers: Sequence[int], capacity: int) -> int:
+     r"""
+     Finds the index of the largest number that fits into the knapsack with the given capacity.
+     """
+     index = bisect.bisect(numbers, capacity)
+     return -1 if index == 0 else (index - 1)
+
+
+ def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
+     r"""
+     An efficient greedy algorithm with binary search for the knapsack problem.
+     """
+     numbers.sort()  # sort numbers in ascending order for binary search
+     knapsacks = []
+
+     while numbers:
+         current_knapsack = []
+         remaining_capacity = capacity
+
+         while True:
+             index = search_for_fit(numbers, remaining_capacity)
+             if index == -1:
+                 break  # no more numbers fit in this knapsack
+
+             remaining_capacity -= numbers[index]  # update the remaining capacity
+             current_knapsack.append(numbers.pop(index))  # add the number to knapsack
+
+         knapsacks.append(current_knapsack)
+
+     return knapsacks
+
+
+ def get_pixel_values(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray":
+     r"""
+     Processes visual inputs. (currently only supports a single image)
+     """
+     image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+     image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255))
+     return image_processor(image, return_tensors="pt")["pixel_values"][0]  # shape (C, H, W)
+
+
+ def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[int]:
+     r"""
+     Gets paligemma token type ids for computing loss.
+     """
+     image_seq_length = getattr(processor, "image_seq_length")
+     return [0] * image_seq_length + [1] * (input_len - image_seq_length)
+
+
+ def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
+     r"""
+     Computes the real sequence length after truncation by the cutoff_len.
+     """
+     if target_len * 2 < cutoff_len:  # truncate source
+         max_target_len = cutoff_len
+     elif source_len * 2 < cutoff_len:  # truncate target
+         max_target_len = cutoff_len - source_len
+     else:  # truncate both
+         max_target_len = int(cutoff_len * (target_len / (source_len + target_len)))
+
+     new_target_len = min(max_target_len, target_len)
+     max_source_len = max(cutoff_len - new_target_len, 0)
+     new_source_len = min(max_source_len, source_len)
+     return new_source_len, new_target_len
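
A usage sketch of `greedy_knapsack` (illustration only, not part of this commit, and assuming this commit's package is installed so the import path resolves): it bins example lengths into as few cutoff-sized sequences as a greedy best-fit allows. Note that it mutates the input list, sorting and draining it.

```python
from llamafactory.data.processors.processor_utils import greedy_knapsack

lengths = [7, 2, 5, 3, 1]  # hypothetical example lengths
print(greedy_knapsack(lengths, capacity=8))  # [[7, 1], [5, 3], [2]]
```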
llama-factory/src/llamafactory/data/processors/supervised.py ADDED
@@ -0,0 +1,202 @@
+ # Copyright 2024 the LlamaFactory team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from collections import defaultdict
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
+
+ from ...extras.constants import IGNORE_INDEX
+ from ...extras.logging import get_logger
+ from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, greedy_knapsack, infer_seqlen
+
+
+ if TYPE_CHECKING:
+     from transformers import PreTrainedTokenizer, ProcessorMixin
+
+     from ...hparams import DataArguments
+     from ..template import Template
+
+
+ logger = get_logger(__name__)
+
+
+ def _encode_supervised_example(
+     prompt: Sequence[Dict[str, str]],
+     response: Sequence[Dict[str, str]],
+     system: Optional[str],
+     tools: Optional[str],
+     template: "Template",
+     tokenizer: "PreTrainedTokenizer",
+     processor: Optional["ProcessorMixin"],
+     data_args: "DataArguments",
+ ) -> Tuple[List[int], List[int]]:
+     if processor is not None and not hasattr(processor, "image_seq_length"):  # llava-like models
+         prompt[0]["content"] = template.image_token + prompt[0]["content"]
+
+     messages = prompt + response
+     input_ids, labels = [], []
+
+     if processor is not None and hasattr(processor, "image_seq_length"):  # paligemma models
+         image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
+         input_ids += [image_token_id] * getattr(processor, "image_seq_length")
+         labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length")
+
+     encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
+     total_length = 1 if template.efficient_eos else 0
+     for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
+         if total_length >= data_args.cutoff_len:
+             break
+
+         source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), data_args.cutoff_len - total_length)
+         source_ids = source_ids[:source_len]
+         target_ids = target_ids[:target_len]
+         total_length += source_len + target_len
+
+         if data_args.train_on_prompt:
+             source_label = source_ids
+         elif turn_idx != 0 and template.efficient_eos:
+             source_label = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1)
+         else:
+             source_label = [IGNORE_INDEX] * source_len
+
+         if data_args.mask_history and turn_idx != len(encoded_pairs) - 1:
+             target_label = [IGNORE_INDEX] * target_len
+         else:
+             target_label = target_ids
+
+         input_ids += source_ids + target_ids
+         labels += source_label + target_label
+
+     if template.efficient_eos:
+         input_ids += [tokenizer.eos_token_id]
+         labels += [tokenizer.eos_token_id]
+
+     return input_ids, labels
+
+
+ def preprocess_supervised_dataset(
+     examples: Dict[str, List[Any]],
+     template: "Template",
+     tokenizer: "PreTrainedTokenizer",
+     processor: Optional["ProcessorMixin"],
+     data_args: "DataArguments",
+ ) -> Dict[str, List[List[int]]]:
+     # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
+     # for multiturn examples, we only mask the prompt part in each prompt-response pair.
+     model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
+     if processor is not None:
+         model_inputs["pixel_values"] = []
+         if hasattr(processor, "image_seq_length"):  # paligemma models
+             model_inputs["token_type_ids"] = []
+
+     for i in range(len(examples["prompt"])):
+         if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
+             logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
+             continue
+
+         input_ids, labels = _encode_supervised_example(
+             prompt=examples["prompt"][i],
+             response=examples["response"][i],
+             system=examples["system"][i],
+             tools=examples["tools"][i],
+             template=template,
+             tokenizer=tokenizer,
+             processor=processor,
+             data_args=data_args,
+         )
+         model_inputs["input_ids"].append(input_ids)
+         model_inputs["attention_mask"].append([1] * len(input_ids))
+         model_inputs["labels"].append(labels)
+         if processor is not None:
+             model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
+             if hasattr(processor, "image_seq_length"):  # paligemma models
+                 model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
+
+     return model_inputs
+
+
+ def preprocess_packed_supervised_dataset(
+     examples: Dict[str, List[Any]],
+     template: "Template",
+     tokenizer: "PreTrainedTokenizer",
+     data_args: "DataArguments",
+ ) -> Dict[str, List[List[int]]]:
+     # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
+     # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
+     valid_num = 0
+     batch_input_ids, batch_labels = [], []
+     lengths = []
+     length2indexes = defaultdict(list)
+     for i in range(len(examples["prompt"])):
+         if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
+             logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
+             continue
+
+         input_ids, labels = _encode_supervised_example(
+             prompt=examples["prompt"][i],
+             response=examples["response"][i],
+             system=examples["system"][i],
+             tools=examples["tools"][i],
+             template=template,
+             tokenizer=tokenizer,
+             processor=None,
+             data_args=data_args,
+         )
+         length = len(input_ids)
+         if length > data_args.cutoff_len:
+             logger.warning("Dropped lengthy example with length {} > {}.".format(length, data_args.cutoff_len))
+         else:
+             lengths.append(length)
+             length2indexes[length].append(valid_num)
+             batch_input_ids.append(input_ids)
+             batch_labels.append(labels)
+             valid_num += 1
+
+     model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
+     knapsacks = greedy_knapsack(lengths, data_args.cutoff_len)
+     for knapsack in knapsacks:
+         packed_input_ids, packed_attention_masks, packed_labels = [], [], []
+         for i, length in enumerate(knapsack):
+             index = length2indexes[length].pop()
+             packed_input_ids += batch_input_ids[index]
+             packed_labels += batch_labels[index]
+             if data_args.neat_packing:
+                 packed_attention_masks += [i + 1] * len(batch_input_ids[index])  # start from 1
+             else:
+                 packed_attention_masks += [1] * len(batch_input_ids[index])
+
+         if len(packed_input_ids) < data_args.cutoff_len:
+             pad_length = data_args.cutoff_len - len(packed_input_ids)
+             packed_input_ids += [tokenizer.pad_token_id] * pad_length
+             packed_labels += [IGNORE_INDEX] * pad_length
+             if data_args.neat_packing:
+                 packed_attention_masks += [0] * pad_length
+             else:
+                 packed_attention_masks += [1] * pad_length  # more efficient flash_attn
+
+         if len(packed_input_ids) != data_args.cutoff_len:
+             raise ValueError("The length of packed example should be identical to the cutoff length.")
+
+         model_inputs["input_ids"].append(packed_input_ids)
+         model_inputs["attention_mask"].append(packed_attention_masks)
+         model_inputs["labels"].append(packed_labels)
+
+     return model_inputs
+
+
+ def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
+     valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
+     print("input_ids:\n{}".format(example["input_ids"]))
+     print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
+     print("label_ids:\n{}".format(example["labels"]))
+     print("labels:\n{}".format(tokenizer.decode(valid_labels, skip_special_tokens=False)))
llama-factory/src/llamafactory/data/processors/unsupervised.py ADDED
@@ -0,0 +1,106 @@
+ # Copyright 2024 the LlamaFactory team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
+
+ from ...extras.logging import get_logger
+ from ..data_utils import Role
+ from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
+
+
+ if TYPE_CHECKING:
+     from transformers import PreTrainedTokenizer, ProcessorMixin
+
+     from ...hparams import DataArguments
+     from ..template import Template
+
+
+ logger = get_logger(__name__)
+
+
+ def _encode_unsupervised_example(
+     prompt: Sequence[Dict[str, str]],
+     response: Sequence[Dict[str, str]],
+     system: Optional[str],
+     tools: Optional[str],
+     template: "Template",
+     tokenizer: "PreTrainedTokenizer",
+     processor: Optional["ProcessorMixin"],
+     data_args: "DataArguments",
+ ) -> Tuple[List[int], List[int]]:
+     if processor is not None and not hasattr(processor, "image_seq_length"):  # llava-like models
+         prompt[0]["content"] = template.image_token + prompt[0]["content"]
+
+     if len(response) == 1:
+         messages = prompt + response
+     else:
+         messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}]
+
+     input_ids, labels = template.encode_oneturn(tokenizer, messages, system, tools)
+     if template.efficient_eos:
+         labels += [tokenizer.eos_token_id]
+
+     if processor is not None and hasattr(processor, "image_seq_length"):  # paligemma models
+         image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
+         input_ids = [image_token_id] * getattr(processor, "image_seq_length") + input_ids
+
+     source_len, target_len = infer_seqlen(len(input_ids), len(labels), data_args.cutoff_len)
+     input_ids = input_ids[:source_len]
+     labels = labels[:target_len]
+     return input_ids, labels
+
+
+ def preprocess_unsupervised_dataset(
+     examples: Dict[str, List[Any]],
+     template: "Template",
+     tokenizer: "PreTrainedTokenizer",
+     processor: Optional["ProcessorMixin"],
+     data_args: "DataArguments",
+ ) -> Dict[str, List[List[int]]]:
+     # build inputs with format `<bos> X` and labels with format `Y <eos>`
+     model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
+     if processor is not None:
+         model_inputs["pixel_values"] = []
+         if hasattr(processor, "image_seq_length"):  # paligemma models
+             model_inputs["token_type_ids"] = []
+
+     for i in range(len(examples["prompt"])):
+         if len(examples["prompt"][i]) % 2 != 1:
+             logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
+             continue
+
+         input_ids, labels = _encode_unsupervised_example(
+             prompt=examples["prompt"][i],
+             response=examples["response"][i],
+             system=examples["system"][i],
+             tools=examples["tools"][i],
+             template=template,
+             tokenizer=tokenizer,
+             processor=processor,
+             data_args=data_args,
+         )
+         model_inputs["input_ids"].append(input_ids)
+         model_inputs["attention_mask"].append([1] * len(input_ids))
+         model_inputs["labels"].append(labels)
+         if processor is not None:
+             model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
+             if hasattr(processor, "image_seq_length"):  # paligemma models
+                 model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
+
+     return model_inputs
+
+
+ def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
+     print("input_ids:\n{}".format(example["input_ids"]))
+     print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
llama-factory/src/llamafactory/data/template.py ADDED
@@ -0,0 +1,905 @@
+ # Copyright 2024 the LlamaFactory team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from dataclasses import dataclass
+ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
+
+ from ..extras.logging import get_logger
+ from .data_utils import Role
+ from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
+
+
+ if TYPE_CHECKING:
+     from transformers import PreTrainedTokenizer
+
+     from .formatter import SLOTS, Formatter
+
+
+ logger = get_logger(__name__)
+
+
+ @dataclass
+ class Template:
+     format_user: "Formatter"
+     format_assistant: "Formatter"
+     format_system: "Formatter"
+     format_function: "Formatter"
+     format_observation: "Formatter"
+     format_tools: "Formatter"
+     format_separator: "Formatter"
+     format_prefix: "Formatter"
+     default_system: str
+     stop_words: List[str]
+     image_token: str
+     efficient_eos: bool
+     replace_eos: bool
+
+     def encode_oneturn(
+         self,
+         tokenizer: "PreTrainedTokenizer",
+         messages: Sequence[Dict[str, str]],
+         system: Optional[str] = None,
+         tools: Optional[str] = None,
+     ) -> Tuple[List[int], List[int]]:
+         r"""
+         Returns a single pair of token ids representing prompt and response respectively.
+         """
+         encoded_messages = self._encode(tokenizer, messages, system, tools)
+         prompt_ids = []
+         for encoded_ids in encoded_messages[:-1]:
+             prompt_ids += encoded_ids
+
+         answer_ids = encoded_messages[-1]
+         return prompt_ids, answer_ids
+
+     def encode_multiturn(
+         self,
+         tokenizer: "PreTrainedTokenizer",
+         messages: Sequence[Dict[str, str]],
+         system: Optional[str] = None,
+         tools: Optional[str] = None,
+     ) -> List[Tuple[List[int], List[int]]]:
+         r"""
+         Returns multiple pairs of token ids representing prompts and responses respectively.
+         """
+         encoded_messages = self._encode(tokenizer, messages, system, tools)
+         return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
+
+     def extract_tool(self, content: str) -> Union[str, List[Tuple[str, str]]]:
+         r"""
+         Extracts tool message.
+         """
+         return self.format_tools.extract(content)
+
+     def _encode(
+         self,
+         tokenizer: "PreTrainedTokenizer",
+         messages: Sequence[Dict[str, str]],
+         system: Optional[str],
+         tools: Optional[str],
+     ) -> List[List[int]]:
+         r"""
+         Encodes formatted inputs to pairs of token ids.
+         Turn 0: prefix + system + query        resp
+         Turn t: sep + query                    resp
+         """
+         system = system or self.default_system
+         encoded_messages = []
+         for i, message in enumerate(messages):
+             elements = []
+
+             if i == 0:
+                 elements += self.format_prefix.apply()
+                 if system or tools:
+                     tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
+                     elements += self.format_system.apply(content=(system + tool_text))
+
+             if i > 0 and i % 2 == 0:
+                 elements += self.format_separator.apply()
+
+             if message["role"] == Role.USER.value:
+                 elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
+             elif message["role"] == Role.ASSISTANT.value:
+                 elements += self.format_assistant.apply(content=message["content"])
+             elif message["role"] == Role.OBSERVATION.value:
+                 elements += self.format_observation.apply(content=message["content"])
+             elif message["role"] == Role.FUNCTION.value:
+                 elements += self.format_function.apply(content=message["content"])
+             else:
+                 raise NotImplementedError("Unexpected role: {}".format(message["role"]))
+
+             encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
+
+         return encoded_messages
+
+     def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> List[int]:
+         r"""
+         Converts elements to token ids.
+         """
+         token_ids = []
+         for elem in elements:
+             if isinstance(elem, str):
+                 if len(elem) != 0:
+                     token_ids += tokenizer.encode(elem, add_special_tokens=False)
+             elif isinstance(elem, dict):
+                 token_ids += [tokenizer.convert_tokens_to_ids(elem.get("token"))]
+             elif isinstance(elem, set):
+                 if "bos_token" in elem and tokenizer.bos_token_id is not None:
+                     token_ids += [tokenizer.bos_token_id]
+                 elif "eos_token" in elem and tokenizer.eos_token_id is not None:
+                     token_ids += [tokenizer.eos_token_id]
+             else:
+                 raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem)))
+
+         return token_ids
+
+
+ @dataclass
+ class Llama2Template(Template):
+     def _encode(
+         self,
+         tokenizer: "PreTrainedTokenizer",
+         messages: Sequence[Dict[str, str]],
+         system: str,
+         tools: str,
+     ) -> List[List[int]]:
+         r"""
+         Encodes formatted inputs to pairs of token ids.
+         Turn 0: prefix + system + query        resp
+         Turn t: sep + query                    resp
+         """
+         system = system or self.default_system
+         encoded_messages = []
+         for i, message in enumerate(messages):
+             elements = []
+
+             system_text = ""
+             if i == 0:
+                 elements += self.format_prefix.apply()
+                 if system or tools:
+                     tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
+                     system_text = self.format_system.apply(content=(system + tool_text))[0]
+
+             if i > 0 and i % 2 == 0:
+                 elements += self.format_separator.apply()
+
+             if message["role"] == Role.USER.value:
+                 elements += self.format_user.apply(content=system_text + message["content"])
+             elif message["role"] == Role.ASSISTANT.value:
+                 elements += self.format_assistant.apply(content=message["content"])
+             elif message["role"] == Role.OBSERVATION.value:
+                 elements += self.format_observation.apply(content=message["content"])
+             elif message["role"] == Role.FUNCTION.value:
+                 elements += self.format_function.apply(content=message["content"])
+             else:
+                 raise NotImplementedError("Unexpected role: {}".format(message["role"]))
+
+             encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
+
+         return encoded_messages
+
+
+ TEMPLATES: Dict[str, Template] = {}
+
+
+ def _register_template(
+     name: str,
+     format_user: Optional["Formatter"] = None,
+     format_assistant: Optional["Formatter"] = None,
+     format_system: Optional["Formatter"] = None,
+     format_function: Optional["Formatter"] = None,
+     format_observation: Optional["Formatter"] = None,
+     format_tools: Optional["Formatter"] = None,
+     format_separator: Optional["Formatter"] = None,
+     format_prefix: Optional["Formatter"] = None,
+     default_system: str = "",
+     stop_words: Sequence[str] = [],
+     image_token: str = "<image>",
+     efficient_eos: bool = False,
+     replace_eos: bool = False,
+ ) -> None:
+     r"""
+     Registers a chat template.
+
+     To add the following chat template:
+     ```
+     [HUMAN]:
+     user prompt here
+     [AI]:
+     model response here
+
+     [HUMAN]:
+     user prompt here
+     [AI]:
+     model response here
+     ```
+
+     The corresponding code should be:
+     ```
+     _register_template(
+         name="custom",
+         format_user=StringFormatter(slots=["[HUMAN]:\n{{content}}\n[AI]:\n"]),
+         format_separator=EmptyFormatter(slots=["\n\n"]),
+         efficient_eos=True,
+     )
+     ```
+     """
+     eos_slots = [] if efficient_eos else [{"eos_token"}]
+     template_class = Llama2Template if name.startswith("llama2") else Template
+     default_user_formatter = StringFormatter(slots=["{{content}}"])
+     default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
+     default_function_formatter = FunctionFormatter(slots=eos_slots, tool_format="default")
+     default_tool_formatter = ToolFormatter(tool_format="default")
+     default_separator_formatter = EmptyFormatter()
+     default_prefix_formatter = EmptyFormatter()
+     TEMPLATES[name] = template_class(
+         format_user=format_user or default_user_formatter,
+         format_assistant=format_assistant or default_assistant_formatter,
+         format_system=format_system or default_user_formatter,
+         format_function=format_function or default_function_formatter,
+         format_observation=format_observation or format_user or default_user_formatter,
+         format_tools=format_tools or default_tool_formatter,
+         format_separator=format_separator or default_separator_formatter,
+         format_prefix=format_prefix or default_prefix_formatter,
+         default_system=default_system,
+         stop_words=stop_words,
+         image_token=image_token,
+         efficient_eos=efficient_eos,
+         replace_eos=replace_eos,
+     )
+
+
+ def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
+     is_added = tokenizer.eos_token_id is None
+     num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
+
+     if is_added:
+         logger.info("Add eos token: {}".format(tokenizer.eos_token))
+     else:
+         logger.info("Replace eos token: {}".format(tokenizer.eos_token))
+
+     if num_added_tokens > 0:
+         logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
+
+
+ def _jinja_escape(content: str) -> str:
+     return content.replace("'", r"\'")
+
+
+ def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
+     slot_items = []
+     for slot in slots:
+         if isinstance(slot, str):
+             slot_pieces = slot.split("{{content}}")
+             if slot_pieces[0]:
+                 slot_items.append("'" + _jinja_escape(slot_pieces[0]) + "'")
+             if len(slot_pieces) > 1:
+                 slot_items.append(placeholder)
+                 if slot_pieces[1]:
+                     slot_items.append("'" + _jinja_escape(slot_pieces[1]) + "'")
+         elif isinstance(slot, set):  # do not use {{ eos_token }} since it may be replaced
+             if "bos_token" in slot and tokenizer.bos_token_id is not None:
+                 slot_items.append("'" + tokenizer.bos_token + "'")
+             elif "eos_token" in slot and tokenizer.eos_token_id is not None:
+                 slot_items.append("'" + tokenizer.eos_token + "'")
+         elif isinstance(slot, dict):
+             raise ValueError("Dict is not supported.")
+
+     return " + ".join(slot_items)
+
+
+ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") -> str:
+     jinja_template = ""
+
+     prefix = _convert_slots_to_jinja(template.format_prefix.apply(), tokenizer)
+     if prefix:
+         jinja_template += "{{ " + prefix + " }}"
+
+     if template.default_system:
+         jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}"
+
+     jinja_template += (
+         "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}"
+     )
+
+     system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message")
+     if not isinstance(template, Llama2Template):
+         jinja_template += "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}"
+
+     jinja_template += "{% for message in messages %}"
+     jinja_template += "{% set content = message['content'] %}"
+     if isinstance(template, Llama2Template):
+         jinja_template += "{% if loop.index0 == 0 and system_message is defined %}"
+         jinja_template += "{% set content = " + system_message + " + message['content'] %}"
+         jinja_template += "{% endif %}"
+
+     jinja_template += "{% if message['role'] == 'user' %}"
+     user_message = _convert_slots_to_jinja(template.format_user.apply(), tokenizer)
+     jinja_template += "{{ " + user_message + " }}"
+
+     jinja_template += "{% elif message['role'] == 'assistant' %}"
+     assistant_message = _convert_slots_to_jinja(
+         template.format_assistant.apply() + template.format_separator.apply(), tokenizer
+     )
+     jinja_template += "{{ " + assistant_message + " }}"
+     jinja_template += "{% endif %}"
+     jinja_template += "{% endfor %}"
+     return jinja_template
+
+
+ def get_template_and_fix_tokenizer(
+     tokenizer: "PreTrainedTokenizer",
+     name: Optional[str] = None,
+     tool_format: Optional[str] = None,
+ ) -> Template:
+     if name is None:
+         template = TEMPLATES["empty"]  # placeholder
+     else:
+         template = TEMPLATES.get(name, None)
+         if template is None:
+             raise ValueError("Template {} does not exist.".format(name))
+
+     if tool_format is not None:
+         logger.info("Using tool format: {}.".format(tool_format))
+         eos_slots = [] if template.efficient_eos else [{"eos_token"}]
+         template.format_tools = ToolFormatter(tool_format=tool_format)
+         template.format_function = FunctionFormatter(slots=eos_slots, tool_format=tool_format)
+
+     stop_words = template.stop_words
+     if template.replace_eos:
+         if not stop_words:
+             raise ValueError("Stop words are required to replace the EOS token.")
+
+         _add_or_replace_eos_token(tokenizer, eos_token=stop_words[0])
+         stop_words = stop_words[1:]
+
+     if tokenizer.eos_token_id is None:
+         _add_or_replace_eos_token(tokenizer, eos_token="<|endoftext|>")
+
+     if tokenizer.pad_token_id is None:
+         tokenizer.pad_token = tokenizer.eos_token
+         logger.info("Add pad token: {}".format(tokenizer.pad_token))
+
+     if stop_words:
+         num_added_tokens = tokenizer.add_special_tokens(
+             dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
+         )
+         logger.info("Add {} to stop words.".format(",".join(stop_words)))
+         if num_added_tokens > 0:
+             logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
+
+     try:
+         tokenizer.chat_template = _get_jinja_template(template, tokenizer)
+     except ValueError:
+         logger.info("Cannot add this chat template to tokenizer.")
+
+     return template
+
+
+ _register_template(
+     name="alpaca",
+     format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]),
+     format_separator=EmptyFormatter(slots=["\n\n"]),
+     default_system=(
+         "Below is an instruction that describes a task. "
+         "Write a response that appropriately completes the request.\n\n"
+     ),
+ )
+
+
+ _register_template(
+     name="aquila",
+     format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]),
+     format_separator=EmptyFormatter(slots=["###"]),
+     default_system=(
+         "A chat between a curious human and an artificial intelligence assistant. "
+         "The assistant gives helpful, detailed, and polite answers to the human's questions."
+     ),
+     stop_words=["</s>"],
+     efficient_eos=True,
+ )
+
+
+ _register_template(
+     name="atom",
+     format_user=StringFormatter(
+         slots=[{"bos_token"}, "Human: {{content}}\n", {"eos_token"}, {"bos_token"}, "Assistant:"]
+     ),
+     format_assistant=StringFormatter(slots=["{{content}}\n", {"eos_token"}]),
+ )
+
+
+ _register_template(
+     name="baichuan",
+     format_user=StringFormatter(slots=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]),
+     efficient_eos=True,
+ )
+
+
+ _register_template(
+     name="baichuan2",
+     format_user=StringFormatter(slots=["<reserved_106>{{content}}<reserved_107>"]),
+     efficient_eos=True,
+ )
+
+
+ _register_template(
+     name="belle",
+     format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
+     format_separator=EmptyFormatter(slots=["\n\n"]),
+     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+ )
+
+
+ _register_template(
+     name="bluelm",
+     format_user=StringFormatter(slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]),
+ )
+
+
+ _register_template(
+     name="breeze",
+     format_user=StringFormatter(slots=["[INST] {{content}} [/INST] "]),
+     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+     efficient_eos=True,
+ )
+
+
+ _register_template(
+     name="chatglm2",
+     format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
+     format_separator=EmptyFormatter(slots=["\n\n"]),
+     format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
+     efficient_eos=True,
+ )
+
+
+ _register_template(
+     name="chatglm3",
+     format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
+     format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
+     format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]),
+     format_function=FunctionFormatter(slots=[], tool_format="glm4"),
+     format_observation=StringFormatter(
+         slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
+     ),
+     format_tools=ToolFormatter(tool_format="glm4"),
+     format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
+     stop_words=["<|user|>", "<|observation|>"],
+     efficient_eos=True,
+ )
+
+
+ _register_template(
+     name="chatml",
+     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+     format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     format_separator=EmptyFormatter(slots=["\n"]),
+     stop_words=["<|im_end|>", "<|im_start|>"],
+     replace_eos=True,
+ )
+
+
+ _register_template(
+     name="chatml_de",
+     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+     format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     format_separator=EmptyFormatter(slots=["\n"]),
+     default_system="Du bist ein freundlicher und hilfsbereiter KI-Assistent.",
+     stop_words=["<|im_end|>", "<|im_start|>"],
+     replace_eos=True,
+ )
+
+
+ _register_template(
+     name="codegeex2",
+     format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
+ )
+
+
+ _register_template(
+     name="codegeex4",
+     format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
+     format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
+     format_function=FunctionFormatter(slots=[], tool_format="glm4"),
+     format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>\n"]),
+     format_tools=ToolFormatter(tool_format="glm4"),
+     format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
+     default_system=(
+         "你是一位智能编程助手,你叫CodeGeeX。你会为用户回答关于编程、代码、计算机方面的任何问题,"
+         "并提供格式规范、可以执行、准确安全的代码,并在必要时提供详细的解释。"
+     ),
+     stop_words=["<|user|>", "<|observation|>"],
+     efficient_eos=True,
+ )
+
+
+ _register_template(
+     name="cohere",
+     format_user=StringFormatter(
+         slots=[
+             (
+                 "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"
+                 "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
+             )
+         ]
+     ),
+     format_system=StringFormatter(slots=["<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"]),
+     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+ )
+
+
+ _register_template(
+     name="cpm",
+     format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]),
+     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+ )
+
+
+ _register_template(
+     name="dbrx",
+     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+     format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     format_separator=EmptyFormatter(slots=["\n"]),
+     default_system=(
+         "You are DBRX, created by Databricks. You were last updated in December 2023. "
+         "You answer questions based on information available up to that point.\n"
+         "YOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, but provide thorough "
+         "responses to more complex and open-ended questions.\nYou assist with various tasks, "
+         "from writing to coding (using markdown for code blocks — remember to use ``` with "
+         "code, JSON, and tables).\n(You do not have real-time data access or code execution "
+         "capabilities. You avoid stereotyping and provide balanced perspectives on "
+         "controversial topics. You do not provide song lyrics, poems, or news articles and "
+         "do not divulge details of your training data.)\nThis is your system prompt, "
+         "guiding your responses. Do not reference it, just respond to the user. If you find "
+         "yourself talking about this message, stop. You should be responding appropriately "
+         "and usually that means not mentioning this.\nYOU DO NOT MENTION ANY OF THIS INFORMATION "
+         "ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER'S QUERY."
+     ),
+     stop_words=["<|im_end|>"],
+     replace_eos=True,
+ )
+
+
+ _register_template(
+     name="deepseek",
+     format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
+     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+ )
+
+
+ _register_template(
+     name="deepseekcoder",
+     format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
+     format_assistant=StringFormatter(slots=["\n{{content}}\n"]),
+     format_separator=EmptyFormatter(slots=["\n"]),
+     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+     default_system=(
+         "You are an AI programming assistant, utilizing the Deepseek Coder model, "
+         "developed by Deepseek Company, and you only answer questions related to computer science. "
+         "For politically sensitive questions, security and privacy issues, "
+         "and other non-computer science questions, you will refuse to answer\n"
+     ),
+ )
+
+
+ _register_template(
+     name="default",
+     format_user=StringFormatter(slots=["Human: {{content}}\nAssistant:"]),
+     format_system=StringFormatter(slots=["{{content}}\n"]),
+     format_separator=EmptyFormatter(slots=["\n"]),
+ )
+
+
+ _register_template(
+     name="empty",
+     efficient_eos=True,
+ )
+
+
+ _register_template(
+     name="falcon",
+     format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
+     format_separator=EmptyFormatter(slots=["\n"]),
+     efficient_eos=True,
+ )
+
+
+ _register_template(
+     name="fewshot",
+     format_separator=EmptyFormatter(slots=["\n\n"]),
+     efficient_eos=True,
+ )
+
+
+ _register_template(
+     name="gemma",
+     format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
+     format_observation=StringFormatter(
+         slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
+     ),
+     format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
+     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+     efficient_eos=True,
+ )
+
+
+ _register_template(
+     name="glm4",
+     format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
+     format_assistant=StringFormatter(slots=["\n{{content}}"]),
+     format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
+     format_function=FunctionFormatter(slots=[], tool_format="glm4"),
+     format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
+     format_tools=ToolFormatter(tool_format="glm4"),
+     format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
+     stop_words=["<|user|>", "<|observation|>"],
+     efficient_eos=True,
+ )
+
+
+ _register_template(
+     name="intern",
+     format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]),
+     format_system=StringFormatter(slots=["<|System|>:{{content}}\n"]),
+     format_separator=EmptyFormatter(slots=["<eoa>\n"]),
+     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+     stop_words=["<eoa>"],
+     efficient_eos=True,  # internlm tokenizer cannot set eos_token_id
+ )
+
+
+ _register_template(
+     name="intern2",
+     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+     format_separator=EmptyFormatter(slots=["<|im_end|>\n"]),
+     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+     stop_words=["<|im_end|>"],
+     efficient_eos=True,  # internlm2 tokenizer cannot set eos_token_id
+ )
+
+
+ _register_template(
+     name="llama2",
+     format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
+     format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
+ )
+
+
+ _register_template(
+     name="llama2_zh",
+     format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
+     format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
+     default_system="You are a helpful assistant. 你是一个乐于助人的助手。",
+ )
+
+
+ _register_template(
+     name="llama3",
+     format_user=StringFormatter(
+         slots=[
+             (
+                 "<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
+                 "<|start_header_id|>assistant<|end_header_id|>\n\n"
+             )
+         ]
+     ),
+     format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
+     format_observation=StringFormatter(
+         slots=[
+             (
+                 "<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>"
+                 "<|start_header_id|>assistant<|end_header_id|>\n\n"
+             )
+         ]
+     ),
+     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+     stop_words=["<|eot_id|>"],
+     replace_eos=True,
+ )
+
+
+ _register_template(
+     name="mistral",
+     format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
+     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+ )
+
+
+ _register_template(
+     name="olmo",
+     format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
+     format_prefix=EmptyFormatter(slots=[{"eos_token"}]),
+ )
+
+
+ _register_template(
+     name="openchat",
+     format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]),
+     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+ )
+
+
+ _register_template(
+     name="openchat-3.6",
+     format_user=StringFormatter(
+         slots=[
+             (
+                 "<|start_header_id|>GPT4 Correct User<|end_header_id|>\n\n{{content}}<|eot_id|>"
+                 "<|start_header_id|>GPT4 Correct Assistant<|end_header_id|>\n\n"
+             )
+         ]
+     ),
+     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+     stop_words=["<|eot_id|>"],
+     replace_eos=True,
+ )
+
+
+ _register_template(
+     name="orion",
+     format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
+     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+ )
+
+
+ _register_template(
+     name="phi",
+     format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
+     format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
+     format_separator=EmptyFormatter(slots=["\n"]),
+     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+     stop_words=["<|end|>"],
+     replace_eos=True,
+ )
+
+
+ _register_template(
+     name="qwen",
+     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+     format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
777
+ format_separator=EmptyFormatter(slots=["\n"]),
778
+ default_system="You are a helpful assistant.",
779
+ stop_words=["<|im_end|>"],
780
+ replace_eos=True,
781
+ )
782
+
783
+
784
+ _register_template(
785
+ name="solar",
786
+ format_user=StringFormatter(slots=["### User:\n{{content}}\n\n### Assistant:\n"]),
787
+ format_system=StringFormatter(slots=["### System:\n{{content}}\n\n"]),
788
+ efficient_eos=True,
789
+ )
790
+
791
+
792
+ _register_template(
793
+ name="starchat",
794
+ format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>"]),
795
+ format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
796
+ format_separator=EmptyFormatter(slots=["\n"]),
797
+ stop_words=["<|end|>"],
798
+ replace_eos=True,
799
+ )
800
+
801
+
802
+ _register_template(
803
+ name="telechat",
804
+ format_user=StringFormatter(slots=["<_user>{{content}}<_bot>"]),
805
+ format_system=StringFormatter(slots=["<_system>{{content}}<_end>"]),
806
+ stop_words=["<_end>"],
807
+ replace_eos=True,
808
+ )
809
+
810
+
811
+ _register_template(
812
+ name="vicuna",
813
+ format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
814
+ default_system=(
815
+ "A chat between a curious user and an artificial intelligence assistant. "
816
+ "The assistant gives helpful, detailed, and polite answers to the user's questions."
817
+ ),
818
+ )
819
+
820
+
821
+ _register_template(
822
+ name="xuanyuan",
823
+ format_user=StringFormatter(slots=["Human: {{content}} Assistant:"]),
824
+ default_system=(
825
+ "以下是用户和人工智能助手之间的对话。用户以Human开头,人工智能助手以Assistant开头,"
826
+ "会对人类提出的问题给出有帮助、高质量、详细和礼貌的回答,并且总是拒绝参与与不道德、"
827
+ "不安全、有争议、政治敏感等相关的话题、问题和指示。\n"
828
+ ),
829
+ )
830
+
831
+
832
+ _register_template(
833
+ name="xverse",
834
+ format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: "]),
835
+ )
836
+
837
+
838
+ _register_template(
839
+ name="yayi",
840
+ format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]),
841
+ format_system=StringFormatter(slots=[{"token": "<|System|>"}, ":\n{{content}}\n\n"]),
842
+ format_separator=EmptyFormatter(slots=["\n\n"]),
843
+ default_system=(
844
+ "You are a helpful, respectful and honest assistant named YaYi "
845
+ "developed by Beijing Wenge Technology Co.,Ltd. "
846
+ "Always answer as helpfully as possible, while being safe. "
847
+ "Your answers should not include any harmful, unethical, "
848
+ "racist, sexist, toxic, dangerous, or illegal content. "
849
+ "Please ensure that your responses are socially unbiased and positive in nature.\n\n"
850
+ "If a question does not make any sense, or is not factually coherent, "
851
+ "explain why instead of answering something not correct. "
852
+ "If you don't know the answer to a question, please don't share false information."
853
+ ),
854
+ stop_words=["<|End|>"],
855
+ )
856
+
857
+
858
+ _register_template(
859
+ name="yi",
860
+ format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
861
+ format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
862
+ format_separator=EmptyFormatter(slots=["\n"]),
863
+ stop_words=["<|im_end|>"],
864
+ replace_eos=True,
865
+ )
866
+
867
+
868
+ _register_template(
869
+ name="yi_vl",
870
+ format_user=StringFormatter(slots=["### Human: {{content}}\n### Assistant:"]),
871
+ format_separator=EmptyFormatter(slots=["\n"]),
872
+ default_system=(
873
+ "This is a chat between an inquisitive human and an AI assistant. "
874
+ "Assume the role of the AI assistant. Read all the images carefully, "
875
+ "and respond to the human's questions with informative, helpful, detailed and polite answers. "
876
+ "这是一个好奇的人类和一个人工智能助手之间的对话。假设你扮演这个AI助手的角色。"
877
+ "仔细阅读所有的图像,并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。\n\n"
878
+ ),
879
+ stop_words=["###"],
880
+ efficient_eos=True,
881
+ )
882
+
883
+
884
+ _register_template(
885
+ name="yuan",
886
+ format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
887
+ format_separator=EmptyFormatter(slots=["\n"]),
888
+ stop_words=["<eod>"],
889
+ replace_eos=True,
890
+ )
891
+
892
+
893
+ _register_template(
894
+ name="zephyr",
895
+ format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>\n"]),
896
+ format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
897
+ default_system="You are Zephyr, a helpful assistant.",
898
+ )
899
+
900
+
901
+ _register_template(
902
+ name="ziya",
903
+ format_user=StringFormatter(slots=["<human>:{{content}}\n<bot>:"]),
904
+ format_separator=EmptyFormatter(slots=["\n"]),
905
+ )
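For orientation, the registrations above are looked up by name at runtime. A minimal sketch of that flow, modeled on the call pattern in evaluator.py later in this commit (the checkpoint id is only an illustrative choice, and the exact shape of encode_oneturn's return value is inferred from its call sites rather than documented here):

    from transformers import AutoTokenizer
    from llamafactory.data import get_template_and_fix_tokenizer

    # Any chat model whose special tokens match the chosen template works here;
    # the Qwen checkpoint below is only an example.
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B-Chat")
    template = get_template_and_fix_tokenizer(tokenizer, "qwen")  # looks up the "qwen" entry registered above
    messages = [
        {"role": "user", "content": "Hello!"},
        {"role": "assistant", "content": "Hi! How can I help?"},
    ]
    prompt_ids, response_ids = template.encode_oneturn(tokenizer=tokenizer, messages=messages)
    print(tokenizer.decode(prompt_ids))  # system + user turn rendered with the "qwen" slots above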
llama-factory/src/llamafactory/data/tool_utils.py ADDED
@@ -0,0 +1,140 @@
+ # Copyright 2024 the LlamaFactory team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import json
+ import re
+ from abc import ABC, abstractmethod
+ from dataclasses import dataclass
+ from typing import Any, Dict, List, Tuple, Union
+
+ from .data_utils import SLOTS
+
+
+ DEFAULT_TOOL_PROMPT = (
+     "You have access to the following tools:\n{tool_text}"
+     "Use the following format if using a tool:\n"
+     "```\n"
+     "Action: tool name (one of [{tool_names}])\n"
+     "Action Input: the input to the tool, in a JSON format representing the kwargs "
+     """(e.g. ```{{"input": "hello world", "num_beams": 5}}```)\n"""
+     "```\n"
+ )
+
+
+ GLM4_TOOL_PROMPT = (
+     "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
+     "你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具{tool_text}"
+ )
+
+
+ @dataclass
+ class ToolUtils(ABC):
+     @staticmethod
+     @abstractmethod
+     def get_function_slots() -> SLOTS: ...
+
+     @staticmethod
+     @abstractmethod
+     def tool_formatter(tools: List[Dict[str, Any]]) -> str: ...
+
+     @staticmethod
+     @abstractmethod
+     def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]: ...
+
+
+ class DefaultToolUtils(ToolUtils):
+     @staticmethod
+     def get_function_slots() -> SLOTS:
+         return ["Action: {{name}}\nAction Input: {{arguments}}\n"]
+
+     @staticmethod
+     def tool_formatter(tools: List[Dict[str, Any]]) -> str:
+         tool_text = ""
+         tool_names = []
+         for tool in tools:
+             param_text = ""
+             for name, param in tool["parameters"]["properties"].items():
+                 required, enum, items = "", "", ""
+                 if name in tool["parameters"].get("required", []):
+                     required = ", required"
+
+                 if param.get("enum", None):
+                     enum = ", should be one of [{}]".format(", ".join(param["enum"]))
+
+                 if param.get("items", None):
+                     items = ", where each item should be {}".format(param["items"].get("type", ""))
+
+                 param_text += " - {name} ({type}{required}): {desc}{enum}{items}\n".format(
+                     name=name,
+                     type=param.get("type", ""),
+                     required=required,
+                     desc=param.get("description", ""),
+                     enum=enum,
+                     items=items,
+                 )
+
+             tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format(
+                 name=tool["name"], desc=tool.get("description", ""), args=param_text
+             )
+             tool_names.append(tool["name"])
+
+         return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names))
+
+     @staticmethod
+     def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
+         regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
+         action_match: List[Tuple[str, str]] = re.findall(regex, content)
+         if not action_match:
+             return content
+
+         results = []
+         for match in action_match:
+             tool_name = match[0].strip()
+             tool_input = match[1].strip().strip('"').strip("```")
+             try:
+                 arguments = json.loads(tool_input)
+                 results.append((tool_name, json.dumps(arguments, ensure_ascii=False)))
+             except json.JSONDecodeError:
+                 return content
+
+         return results
+
+
+ class GLM4ToolUtils(ToolUtils):
+     @staticmethod
+     def get_function_slots() -> SLOTS:
+         return ["{{name}}\n{{arguments}}"]
+
+     @staticmethod
+     def tool_formatter(tools: List[Dict[str, Any]]) -> str:
+         tool_text = ""
+         for tool in tools:
+             tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
+                 name=tool["name"], body=json.dumps(tool, indent=4, ensure_ascii=False)
+             )
+
+         return GLM4_TOOL_PROMPT.format(tool_text=tool_text)
+
+     @staticmethod
+     def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
+         if "\n" not in content:
+             return content
+
+         tool_name, tool_input = content.split("\n", maxsplit=1)
+         try:
+             arguments = json.loads(tool_input)
+         except json.JSONDecodeError:
+             return content
+
+         return [(tool_name, json.dumps(arguments, ensure_ascii=False))]
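As a quick illustration of the round trip these utilities implement, here is a sketch exercising only tool_formatter and tool_extractor from this file (the weather tool spec is a made-up example):

    tools = [
        {
            "name": "get_weather",  # hypothetical tool, purely for illustration
            "description": "Look up the current weather",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string", "description": "city name"}},
                "required": ["city"],
            },
        }
    ]
    system_prompt = DefaultToolUtils.tool_formatter(tools)  # renders DEFAULT_TOOL_PROMPT with the tool docs
    model_reply = 'Action: get_weather\nAction Input: {"city": "Berlin"}'
    print(DefaultToolUtils.tool_extractor(model_reply))
    # [('get_weather', '{"city": "Berlin"}')]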
llama-factory/src/llamafactory/eval/__init__.py ADDED
File without changes
llama-factory/src/llamafactory/eval/evaluator.py ADDED
@@ -0,0 +1,154 @@
+ # Copyright 2024 the LlamaFactory team.
+ #
+ # This code is inspired by Dan Hendrycks' test library.
+ # https://github.com/hendrycks/test/blob/master/evaluate_flan.py
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+ # MIT License
+ #
+ # Copyright (c) 2020 Dan Hendrycks
+ #
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
+ # of this software and associated documentation files (the "Software"), to deal
+ # in the Software without restriction, including without limitation the rights
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ # copies of the Software, and to permit persons to whom the Software is
+ # furnished to do so, subject to the following conditions:
+ #
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+ #
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ # SOFTWARE.
+
+ import json
+ import os
+ from typing import Any, Dict, List, Optional
+
+ import numpy as np
+ import torch
+ from datasets import load_dataset
+ from tqdm import tqdm, trange
+ from transformers.utils import cached_file
+
+ from ..data import get_template_and_fix_tokenizer
+ from ..extras.constants import CHOICES, SUBJECTS
+ from ..hparams import get_eval_args
+ from ..model import load_model, load_tokenizer
+ from .template import get_eval_template
+
+
+ class Evaluator:
+     def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
+         self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
+         self.tokenizer = load_tokenizer(self.model_args)["tokenizer"]
+         self.tokenizer.padding_side = "right"  # avoid overflow issue in batched inference for llama2
+         self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args.template)
+         self.model = load_model(self.tokenizer, self.model_args, finetuning_args)
+         self.eval_template = get_eval_template(self.eval_args.lang)
+         self.choice_inputs = [self.tokenizer.encode(ch, add_special_tokens=False)[-1] for ch in CHOICES]
+
+     @torch.inference_mode()
+     def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]:
+         logits = self.model(**batch_input).logits
+         lengths = torch.sum(batch_input["attention_mask"], dim=-1)
+         word_probs = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0)
+         choice_probs = torch.nn.functional.softmax(word_probs[:, self.choice_inputs], dim=-1).detach()
+         return [chr(ord("A") + offset.item()) for offset in torch.argmax(choice_probs, dim=-1)]
+
+     def eval(self) -> None:
+         eval_task = self.eval_args.task.split("_")[0]
+         eval_split = self.eval_args.task.split("_")[1]
+
+         mapping = cached_file(
+             path_or_repo_id=os.path.join(self.eval_args.task_dir, eval_task),
+             filename="mapping.json",
+             cache_dir=self.model_args.cache_dir,
+             token=self.model_args.hf_hub_token,
+         )
+
+         with open(mapping, "r", encoding="utf-8") as f:
+             categories: Dict[str, Dict[str, str]] = json.load(f)
+
+         category_corrects = {subj: np.array([], dtype="bool") for subj in SUBJECTS}
+         pbar = tqdm(categories.keys(), desc="Processing subjects", position=0)
+         results = {}
+         for subject in pbar:
+             dataset = load_dataset(
+                 path=os.path.join(self.eval_args.task_dir, eval_task),
+                 name=subject,
+                 cache_dir=self.model_args.cache_dir,
+                 download_mode=self.eval_args.download_mode,
+                 token=self.model_args.hf_hub_token,
+                 trust_remote_code=True,
+             )
+             pbar.set_postfix_str(categories[subject]["name"])
+             inputs, outputs, labels = [], [], []
+             for i in trange(len(dataset[eval_split]), desc="Formatting batches", position=1, leave=False):
+                 support_set = (
+                     dataset["train"].shuffle().select(range(min(self.eval_args.n_shot, len(dataset["train"]))))
+                 )
+                 messages = self.eval_template.format_example(
+                     target_data=dataset[eval_split][i],
+                     support_set=support_set,
+                     subject_name=categories[subject]["name"],
+                 )
+
+                 input_ids, _ = self.template.encode_oneturn(tokenizer=self.tokenizer, messages=messages)
+                 inputs.append({"input_ids": input_ids, "attention_mask": [1] * len(input_ids)})
+                 labels.append(messages[-1]["content"])
+
+             for i in trange(
+                 0, len(inputs), self.eval_args.batch_size, desc="Predicting batches", position=1, leave=False
+             ):
+                 batch_input = self.tokenizer.pad(
+                     inputs[i : i + self.eval_args.batch_size], return_attention_mask=True, return_tensors="pt"
+                 ).to(self.model.device)
+                 preds = self.batch_inference(batch_input)
+                 outputs += preds
+
+             corrects = np.array(outputs) == np.array(labels)
+             category_name = categories[subject]["category"]
+             category_corrects[category_name] = np.concatenate([category_corrects[category_name], corrects], axis=0)
+             category_corrects["Average"] = np.concatenate([category_corrects["Average"], corrects], axis=0)
+             results[subject] = {str(i): outputs[i] for i in range(len(outputs))}
+
+         pbar.close()
+         self._save_results(category_corrects, results)
+
+     def _save_results(self, category_corrects: Dict[str, np.ndarray], results: Dict[str, Dict[str, str]]) -> None:
+         score_info = "\n".join(
+             [
+                 "{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
+                 for category_name, category_correct in category_corrects.items()
+                 if len(category_correct)
+             ]
+         )
+         print(score_info)
+         if self.eval_args.save_dir is not None:
+             os.makedirs(self.eval_args.save_dir, exist_ok=False)
+             with open(os.path.join(self.eval_args.save_dir, "results.json"), "w", encoding="utf-8", newline="\n") as f:
+                 json.dump(results, f, indent=2)
+
+             with open(os.path.join(self.eval_args.save_dir, "results.log"), "w", encoding="utf-8", newline="\n") as f:
+                 f.write(score_info)
+
+
+ def run_eval() -> None:
+     Evaluator().eval()
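The scoring rule in batch_inference is worth spelling out: instead of generating text, it reads the logits at each sequence's last non-padded position and takes an argmax restricted to the candidate answer tokens. A self-contained sketch with dummy tensors (the token ids and shapes are placeholders, not real vocabulary entries):

    import torch

    logits = torch.randn(2, 10, 32000)                 # (batch, seq_len, vocab) - dummy model output
    attention_mask = torch.ones(2, 10, dtype=torch.long)
    choice_ids = [319, 350, 315, 360]                  # placeholder ids for the tokens "A".."D"

    lengths = attention_mask.sum(dim=-1)
    last_logits = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0)
    choice_probs = torch.nn.functional.softmax(last_logits[:, choice_ids], dim=-1)
    preds = [chr(ord("A") + i.item()) for i in choice_probs.argmax(dim=-1)]
    print(preds)                                       # e.g. ['C', 'A']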
llama-factory/src/llamafactory/eval/template.py ADDED
@@ -0,0 +1,81 @@
+ # Copyright 2024 the LlamaFactory team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from dataclasses import dataclass
+ from typing import Dict, List, Sequence, Tuple
+
+ from ..data import Role
+ from ..extras.constants import CHOICES
+
+
+ @dataclass
+ class EvalTemplate:
+     system: str
+     choice: str
+     answer: str
+
+     def _parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
+         r"""
+         input: a dict with keys {"question", "A", "B", "C", "D", "answer"}
+         output: a tuple of (prompt, response)
+         """
+         candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example]
+         return "".join([example["question"]] + candidates + [self.answer]), example["answer"]
+
+     def format_example(
+         self, target_data: Dict[str, str], support_set: Sequence[Dict[str, str]], subject_name: str
+     ) -> List[Dict[str, str]]:
+         r"""
+         Converts dataset examples to messages.
+         """
+         messages = []
+         for k in range(len(support_set)):
+             prompt, response = self._parse_example(support_set[k])
+             messages.append({"role": Role.USER.value, "content": prompt})
+             messages.append({"role": Role.ASSISTANT.value, "content": response})
+
+         prompt, response = self._parse_example(target_data)
+         messages.append({"role": Role.USER.value, "content": prompt})
+         messages.append({"role": Role.ASSISTANT.value, "content": response})
+         messages[0]["content"] = self.system.format(subject=subject_name) + messages[0]["content"]
+         return messages
+
+
+ eval_templates: Dict[str, "EvalTemplate"] = {}
+
+
+ def _register_eval_template(name: str, system: str, choice: str, answer: str) -> None:
+     eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer)
+
+
+ def get_eval_template(name: str) -> "EvalTemplate":
+     eval_template = eval_templates.get(name, None)
+     assert eval_template is not None, "Template {} does not exist.".format(name)
+     return eval_template
+
+
+ _register_eval_template(
+     name="en",
+     system="The following are multiple choice questions (with answers) about {subject}.\n\n",
+     choice="\n{choice}. {content}",
+     answer="\nAnswer:",
+ )
+
+
+ _register_eval_template(
+     name="zh",
+     system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
+     choice="\n{choice}. {content}",
+     answer="\n答案:",
+ )
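Concretely, for a zero-shot English item the "en" template above produces the following messages (the record is a toy example in the MMLU-style schema the docstring describes):

    item = {"question": "2 + 2 =", "A": "3", "B": "4", "C": "5", "D": "6", "answer": "B"}
    messages = get_eval_template("en").format_example(
        target_data=item, support_set=[], subject_name="elementary mathematics"
    )
    # messages[0]["content"]:
    #   "The following are multiple choice questions (with answers) about elementary mathematics.\n\n"
    #   "2 + 2 =\nA. 3\nB. 4\nC. 5\nD. 6\nAnswer:"
    # messages[1]["content"]: "B"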
llama-factory/src/llamafactory/extras/__init__.py ADDED
File without changes
llama-factory/src/llamafactory/extras/constants.py ADDED
@@ -0,0 +1,1590 @@
+ # Copyright 2024 the LlamaFactory team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from collections import OrderedDict, defaultdict
+ from enum import Enum
+ from typing import Dict, Optional
+
+ from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME
+ from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME
+ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
+
+
+ CHECKPOINT_NAMES = {
+     SAFE_ADAPTER_WEIGHTS_NAME,
+     ADAPTER_WEIGHTS_NAME,
+     SAFE_WEIGHTS_INDEX_NAME,
+     SAFE_WEIGHTS_NAME,
+     WEIGHTS_INDEX_NAME,
+     WEIGHTS_NAME,
+ }
+
+ CHOICES = ["A", "B", "C", "D"]
+
+ DATA_CONFIG = "dataset_info.json"
+
+ DEFAULT_TEMPLATE = defaultdict(str)
+
+ FILEEXT2TYPE = {
+     "arrow": "arrow",
+     "csv": "csv",
+     "json": "json",
+     "jsonl": "json",
+     "parquet": "parquet",
+     "txt": "text",
+ }
+
+ IGNORE_INDEX = -100
+
+ LAYERNORM_NAMES = {"norm", "ln"}
+
+ LLAMABOARD_CONFIG = "llamaboard_config.yaml"
+
+ METHODS = ["full", "freeze", "lora"]
+
+ MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"}
+
+ PEFT_METHODS = {"lora"}
+
+ RUNNING_LOG = "running_log.txt"
+
+ SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
+
+ SUPPORTED_MODELS = OrderedDict()
+
+ TRAINER_LOG = "trainer_log.jsonl"
+
+ TRAINING_ARGS = "training_args.yaml"
+
+ TRAINING_STAGES = {
+     "Supervised Fine-Tuning": "sft",
+     "Reward Modeling": "rm",
+     "PPO": "ppo",
+     "DPO": "dpo",
+     "KTO": "kto",
+     "Pre-Training": "pt",
+ }
+
+ STAGES_USE_PAIR_DATA = {"rm", "dpo"}
+
+ SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN = {
+     "cohere",
+     "falcon",
+     "gemma",
+     "gemma2",
+     "llama",
+     "mistral",
+     "phi",
+     "phi3",
+     "qwen2",
+     "starcoder2",
+ }
+
+ SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
+
+ V_HEAD_WEIGHTS_NAME = "value_head.bin"
+
+ V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors"
+
+ VISION_MODELS = set()
+
+
+ class DownloadSource(str, Enum):
+     DEFAULT = "hf"
+     MODELSCOPE = "ms"
+
+
+ def register_model_group(
+     models: Dict[str, Dict[DownloadSource, str]],
+     template: Optional[str] = None,
+     vision: bool = False,
+ ) -> None:
+     prefix = None
+     for name, path in models.items():
+         if prefix is None:
+             prefix = name.split("-")[0]
+         else:
+             assert prefix == name.split("-")[0], "prefix should be identical."
+         SUPPORTED_MODELS[name] = path
+     if template is not None:
+         DEFAULT_TEMPLATE[prefix] = template
+     if vision:
+         VISION_MODELS.add(prefix)
+
+
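Everything below this point populates the module-level tables defined above. A small sketch of what one call does (the "Foo" names and repo ids are hypothetical, used only to show the prefix logic):

    register_model_group(
        models={
            "Foo-7B": {DownloadSource.DEFAULT: "example-org/foo-7b"},            # hypothetical repo ids
            "Foo-7B-Chat": {DownloadSource.DEFAULT: "example-org/foo-7b-chat"},
        },
        template="default",
    )
    assert SUPPORTED_MODELS["Foo-7B"][DownloadSource.DEFAULT] == "example-org/foo-7b"
    assert DEFAULT_TEMPLATE["Foo"] == "default"  # prefix = text before the first "-", shared by the whole group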
+ register_model_group(
+     models={
+         "Aya-23-8B-Chat": {
+             DownloadSource.DEFAULT: "CohereForAI/aya-23-8B",
+         },
+         "Aya-23-35B-Chat": {
+             DownloadSource.DEFAULT: "CohereForAI/aya-23-35B",
+         },
+     },
+     template="cohere",
+ )
+
+
+ register_model_group(
+     models={
+         "Baichuan-7B-Base": {
+             DownloadSource.DEFAULT: "baichuan-inc/Baichuan-7B",
+             DownloadSource.MODELSCOPE: "baichuan-inc/baichuan-7B",
+         },
+         "Baichuan-13B-Base": {
+             DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Base",
+             DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Base",
+         },
+         "Baichuan-13B-Chat": {
+             DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Chat",
+             DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Chat",
+         },
+     },
+     template="baichuan",
+ )
+
+
+ register_model_group(
+     models={
+         "Baichuan2-7B-Base": {
+             DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Base",
+             DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Base",
+         },
+         "Baichuan2-13B-Base": {
+             DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Base",
+             DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Base",
+         },
+         "Baichuan2-7B-Chat": {
+             DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Chat",
+             DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Chat",
+         },
+         "Baichuan2-13B-Chat": {
+             DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Chat",
+             DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat",
+         },
+     },
+     template="baichuan2",
+ )
+
+
+ register_model_group(
+     models={
+         "BLOOM-560M": {
+             DownloadSource.DEFAULT: "bigscience/bloom-560m",
+             DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-560m",
+         },
+         "BLOOM-3B": {
+             DownloadSource.DEFAULT: "bigscience/bloom-3b",
+             DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-3b",
+         },
+         "BLOOM-7B1": {
+             DownloadSource.DEFAULT: "bigscience/bloom-7b1",
+             DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-7b1",
+         },
+     },
+ )
+
+
+ register_model_group(
+     models={
+         "BLOOMZ-560M": {
+             DownloadSource.DEFAULT: "bigscience/bloomz-560m",
+             DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-560m",
+         },
+         "BLOOMZ-3B": {
+             DownloadSource.DEFAULT: "bigscience/bloomz-3b",
+             DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-3b",
+         },
+         "BLOOMZ-7B1-mt": {
+             DownloadSource.DEFAULT: "bigscience/bloomz-7b1-mt",
+             DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-7b1-mt",
+         },
+     },
+ )
+
+
+ register_model_group(
+     models={
+         "BlueLM-7B-Base": {
+             DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Base",
+             DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Base",
+         },
+         "BlueLM-7B-Chat": {
+             DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Chat",
+             DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Chat",
+         },
+     },
+     template="bluelm",
+ )
+
+
+ register_model_group(
+     models={
+         "Breeze-7B": {
+             DownloadSource.DEFAULT: "MediaTek-Research/Breeze-7B-Base-v1_0",
+         },
+         "Breeze-7B-Chat": {
+             DownloadSource.DEFAULT: "MediaTek-Research/Breeze-7B-Instruct-v1_0",
+         },
+     },
+     template="breeze",
+ )
+
+
+ register_model_group(
+     models={
+         "ChatGLM2-6B-Chat": {
+             DownloadSource.DEFAULT: "THUDM/chatglm2-6b",
+             DownloadSource.MODELSCOPE: "ZhipuAI/chatglm2-6b",
+         }
+     },
+     template="chatglm2",
+ )
+
+
+ register_model_group(
+     models={
+         "ChatGLM3-6B-Base": {
+             DownloadSource.DEFAULT: "THUDM/chatglm3-6b-base",
+             DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b-base",
+         },
+         "ChatGLM3-6B-Chat": {
+             DownloadSource.DEFAULT: "THUDM/chatglm3-6b",
+             DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b",
+         },
+     },
+     template="chatglm3",
+ )
+
+
+ register_model_group(
+     models={
+         "ChineseLLaMA2-1.3B": {
+             DownloadSource.DEFAULT: "hfl/chinese-llama-2-1.3b",
+             DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-1.3b",
+         },
+         "ChineseLLaMA2-7B": {
+             DownloadSource.DEFAULT: "hfl/chinese-llama-2-7b",
+             DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-7b",
+         },
+         "ChineseLLaMA2-13B": {
+             DownloadSource.DEFAULT: "hfl/chinese-llama-2-13b",
+             DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-13b",
+         },
+         "ChineseLLaMA2-1.3B-Chat": {
+             DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-1.3b",
+             DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-1.3b",
+         },
+         "ChineseLLaMA2-7B-Chat": {
+             DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-7b",
+             DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-7b",
+         },
+         "ChineseLLaMA2-13B-Chat": {
+             DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-13b",
+             DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-13b",
+         },
+     },
+     template="llama2_zh",
+ )
+
+
+ register_model_group(
+     models={
+         "CodeGeeX4-9B-Chat": {
+             DownloadSource.DEFAULT: "THUDM/codegeex4-all-9b",
+             DownloadSource.MODELSCOPE: "ZhipuAI/codegeex4-all-9b",
+         },
+     },
+     template="codegeex4",
+ )
+
+
+ register_model_group(
+     models={
+         "CodeGemma-7B": {
+             DownloadSource.DEFAULT: "google/codegemma-7b",
+         },
+         "CodeGemma-7B-Chat": {
+             DownloadSource.DEFAULT: "google/codegemma-7b-it",
+             DownloadSource.MODELSCOPE: "AI-ModelScope/codegemma-7b-it",
+         },
+         "CodeGemma-1.1-2B": {
+             DownloadSource.DEFAULT: "google/codegemma-1.1-2b",
+         },
+         "CodeGemma-1.1-7B-Chat": {
+             DownloadSource.DEFAULT: "google/codegemma-1.1-7b-it",
+         },
+     },
+     template="gemma",
+ )
+
+
+ register_model_group(
+     models={
+         "Codestral-22B-v0.1-Chat": {
+             DownloadSource.DEFAULT: "mistralai/Codestral-22B-v0.1",
+         },
+     },
+     template="mistral",
+ )
+
+
+ register_model_group(
+     models={
+         "CommandR-35B-Chat": {
+             DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-v01",
+             DownloadSource.MODELSCOPE: "AI-ModelScope/c4ai-command-r-v01",
+         },
+         "CommandR-Plus-104B-Chat": {
+             DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-plus",
+             DownloadSource.MODELSCOPE: "AI-ModelScope/c4ai-command-r-plus",
+         },
+         "CommandR-35B-4bit-Chat": {
+             DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-v01-4bit",
+             DownloadSource.MODELSCOPE: "mirror013/c4ai-command-r-v01-4bit",
+         },
+         "CommandR-Plus-104B-4bit-Chat": {
+             DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-plus-4bit",
+         },
+     },
+     template="cohere",
+ )
+
+
+ register_model_group(
+     models={
+         "DBRX-132B-Base": {
+             DownloadSource.DEFAULT: "databricks/dbrx-base",
+             DownloadSource.MODELSCOPE: "AI-ModelScope/dbrx-base",
+         },
+         "DBRX-132B-Chat": {
+             DownloadSource.DEFAULT: "databricks/dbrx-instruct",
+             DownloadSource.MODELSCOPE: "AI-ModelScope/dbrx-instruct",
+         },
+     },
+     template="dbrx",
+ )
+
+
+ register_model_group(
+     models={
+         "DeepSeek-LLM-7B-Base": {
+             DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-base",
+             DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-base",
+         },
+         "DeepSeek-LLM-67B-Base": {
+             DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-base",
+             DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-base",
+         },
+         "DeepSeek-LLM-7B-Chat": {
+             DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-chat",
+             DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-chat",
+         },
+         "DeepSeek-LLM-67B-Chat": {
+             DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-chat",
+             DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-chat",
+         },
+         "DeepSeek-Math-7B-Base": {
+             DownloadSource.DEFAULT: "deepseek-ai/deepseek-math-7b-base",
+             DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-math-7b-base",
+         },
+         "DeepSeek-Math-7B-Chat": {
+             DownloadSource.DEFAULT: "deepseek-ai/deepseek-math-7b-instruct",
+             DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-math-7b-instruct",
+         },
+         "DeepSeek-MoE-16B-Base": {
+             DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-base",
+             DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-base",
+         },
+         "DeepSeek-MoE-16B-v2-Base": {
+             DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Lite",
+             DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Lite",
+         },
+         "DeepSeek-MoE-236B-Base": {
+             DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2",
+             DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2",
+         },
+         "DeepSeek-MoE-16B-Chat": {
+             DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-chat",
+             DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-chat",
+         },
+         "DeepSeek-MoE-16B-v2-Chat": {
+             DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Lite-Chat",
+             DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Lite-Chat",
+         },
+         "DeepSeek-MoE-236B-Chat": {
+             DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Chat",
+             DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Chat",
+         },
+         "DeepSeek-MoE-Coder-16B-Base": {
+             DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Lite-Base",
+         },
+         "DeepSeek-MoE-Coder-236B-Base": {
+             DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Base",
+         },
+         "DeepSeek-MoE-Coder-16B-Chat": {
+             DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
+         },
+         "DeepSeek-MoE-Coder-236B-Chat": {
+             DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Instruct",
+         },
+     },
+     template="deepseek",
+ )
+
+
+ register_model_group(
+     models={
+         "DeepSeekCoder-6.7B-Base": {
+             DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-base",
+             DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-base",
+         },
+         "DeepSeekCoder-7B-Base": {
+             DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-base-v1.5",
+         },
+         "DeepSeekCoder-33B-Base": {
+             DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-base",
+             DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-base",
+         },
+         "DeepSeekCoder-6.7B-Chat": {
+             DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-instruct",
+             DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-instruct",
+         },
+         "DeepSeekCoder-7B-Chat": {
+             DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-instruct-v1.5",
+         },
+         "DeepSeekCoder-33B-Chat": {
+             DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-instruct",
+             DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-instruct",
+         },
+     },
+     template="deepseekcoder",
+ )
+
+
+ register_model_group(
+     models={
+         "Falcon-7B": {
+             DownloadSource.DEFAULT: "tiiuae/falcon-7b",
+             DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b",
+         },
+         "Falcon-11B": {
+             DownloadSource.DEFAULT: "tiiuae/falcon-11B",
+         },
+         "Falcon-40B": {
+             DownloadSource.DEFAULT: "tiiuae/falcon-40b",
+             DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b",
+         },
+         "Falcon-180B": {
+             DownloadSource.DEFAULT: "tiiuae/falcon-180b",
+             DownloadSource.MODELSCOPE: "modelscope/falcon-180B",
+         },
+         "Falcon-7B-Chat": {
+             DownloadSource.DEFAULT: "tiiuae/falcon-7b-instruct",
+             DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b-instruct",
+         },
+         "Falcon-40B-Chat": {
+             DownloadSource.DEFAULT: "tiiuae/falcon-40b-instruct",
+             DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b-instruct",
+         },
+         "Falcon-180B-Chat": {
+             DownloadSource.DEFAULT: "tiiuae/falcon-180b-chat",
+             DownloadSource.MODELSCOPE: "modelscope/falcon-180B-chat",
+         },
+     },
+     template="falcon",
+ )
+
+
+ register_model_group(
+     models={
+         "Gemma-2B": {
+             DownloadSource.DEFAULT: "google/gemma-2b",
+             DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-2b",
+         },
+         "Gemma-7B": {
+             DownloadSource.DEFAULT: "google/gemma-7b",
+             DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-7b",
+         },
+         "Gemma-2B-Chat": {
+             DownloadSource.DEFAULT: "google/gemma-2b-it",
+             DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-2b-it",
+         },
+         "Gemma-7B-Chat": {
+             DownloadSource.DEFAULT: "google/gemma-7b-it",
+             DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-7b-it",
+         },
+         "Gemma-1.1-2B-Chat": {
+             DownloadSource.DEFAULT: "google/gemma-1.1-2b-it",
+         },
+         "Gemma-1.1-7B-Chat": {
+             DownloadSource.DEFAULT: "google/gemma-1.1-7b-it",
+         },
+         "Gemma-2-9B": {
+             DownloadSource.DEFAULT: "google/gemma-2-9b",
+             DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b",
+         },
+         "Gemma-2-27B": {
+             DownloadSource.DEFAULT: "google/gemma-2-27b",
+             DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-27b",
+         },
+         "Gemma-2-9B-Chat": {
+             DownloadSource.DEFAULT: "google/gemma-2-9b-it",
+             DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b-it",
+         },
+         "Gemma-2-27B-Chat": {
+             DownloadSource.DEFAULT: "google/gemma-2-27b-it",
+             DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-27b-it",
+         },
+     },
+     template="gemma",
+ )
+
+
+ register_model_group(
+     models={
+         "GLM-4-9B": {
+             DownloadSource.DEFAULT: "THUDM/glm-4-9b",
+             DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b",
+         },
+         "GLM-4-9B-Chat": {
+             DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat",
+             DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b-chat",
+         },
+         "GLM-4-9B-1M-Chat": {
+             DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat-1m",
+             DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b-chat-1m",
+         },
+     },
+     template="glm4",
+ )
+
+
+ register_model_group(
+     models={
+         "InternLM-7B": {
+             DownloadSource.DEFAULT: "internlm/internlm-7b",
+             DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-7b",
+         },
+         "InternLM-20B": {
+             DownloadSource.DEFAULT: "internlm/internlm-20b",
+             DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-20b",
+         },
+         "InternLM-7B-Chat": {
+             DownloadSource.DEFAULT: "internlm/internlm-chat-7b",
+             DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-7b",
+         },
+         "InternLM-20B-Chat": {
+             DownloadSource.DEFAULT: "internlm/internlm-chat-20b",
+             DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-20b",
+         },
+     },
+     template="intern",
+ )
+
+
+ register_model_group(
+     models={
+         "InternLM2-7B": {
+             DownloadSource.DEFAULT: "internlm/internlm2-7b",
+             DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-7b",
+         },
+         "InternLM2-20B": {
+             DownloadSource.DEFAULT: "internlm/internlm2-20b",
+             DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-20b",
+         },
+         "InternLM2-7B-Chat": {
+             DownloadSource.DEFAULT: "internlm/internlm2-chat-7b",
+             DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-7b",
+         },
+         "InternLM2-20B-Chat": {
+             DownloadSource.DEFAULT: "internlm/internlm2-chat-20b",
+             DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-20b",
+         },
+     },
+     template="intern2",
+ )
+
+
+ register_model_group(
+     models={
+         "InternLM2.5-7B": {
+             DownloadSource.DEFAULT: "internlm/internlm2_5-7b",
+             DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b",
+         },
+         "InternLM2.5-7B-Chat": {
+             DownloadSource.DEFAULT: "internlm/internlm2_5-7b-chat",
+             DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b-chat",
+         },
+         "InternLM2.5-7B-1M-Chat": {
+             DownloadSource.DEFAULT: "internlm/internlm2_5-7b-chat-1m",
+             DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b-chat-1m",
+         },
+     },
+     template="intern2",
+ )
+
+
+ register_model_group(
+     models={
+         "Jamba-v0.1": {
+             DownloadSource.DEFAULT: "ai21labs/Jamba-v0.1",
+             DownloadSource.MODELSCOPE: "AI-ModelScope/Jamba-v0.1",
+         }
+     },
+ )
+
+
+ register_model_group(
+     models={
+         "LingoWhale-8B": {
+             DownloadSource.DEFAULT: "deeplang-ai/LingoWhale-8B",
+             DownloadSource.MODELSCOPE: "DeepLang/LingoWhale-8B",
+         }
+     },
+ )
+
+
+ register_model_group(
+     models={
+         "LLaMA-7B": {
+             DownloadSource.DEFAULT: "huggyllama/llama-7b",
+             DownloadSource.MODELSCOPE: "skyline2006/llama-7b",
+         },
+         "LLaMA-13B": {
+             DownloadSource.DEFAULT: "huggyllama/llama-13b",
+             DownloadSource.MODELSCOPE: "skyline2006/llama-13b",
+         },
+         "LLaMA-30B": {
+             DownloadSource.DEFAULT: "huggyllama/llama-30b",
+             DownloadSource.MODELSCOPE: "skyline2006/llama-30b",
+         },
+         "LLaMA-65B": {
+             DownloadSource.DEFAULT: "huggyllama/llama-65b",
+             DownloadSource.MODELSCOPE: "skyline2006/llama-65b",
+         },
+     }
+ )
+
+
+ register_model_group(
+     models={
+         "LLaMA2-7B": {
+             DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-hf",
+             DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-ms",
+         },
+         "LLaMA2-13B": {
+             DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-hf",
+             DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-ms",
+         },
+         "LLaMA2-70B": {
+             DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-hf",
+             DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-ms",
+         },
+         "LLaMA2-7B-Chat": {
+             DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-chat-hf",
+             DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-chat-ms",
+         },
+         "LLaMA2-13B-Chat": {
+             DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-chat-hf",
+             DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-chat-ms",
+         },
+         "LLaMA2-70B-Chat": {
+             DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-chat-hf",
+             DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-chat-ms",
+         },
+     },
+     template="llama2",
+ )
+
+
+ register_model_group(
+     models={
+         "LLaMA3-8B": {
+             DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B",
+             DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B",
+         },
+         "LLaMA3-70B": {
+             DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B",
+             DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B",
+         },
+         "LLaMA3-8B-Chat": {
+             DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B-Instruct",
+             DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B-Instruct",
+         },
+         "LLaMA3-70B-Chat": {
+             DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B-Instruct",
+             DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B-Instruct",
+         },
+         "LLaMA3-8B-Chinese-Chat": {
+             DownloadSource.DEFAULT: "shenzhi-wang/Llama3-8B-Chinese-Chat",
+             DownloadSource.MODELSCOPE: "LLM-Research/Llama3-8B-Chinese-Chat",
+         },
+         "LLaMA3-70B-Chinese-Chat": {
+             DownloadSource.DEFAULT: "shenzhi-wang/Llama3-70B-Chinese-Chat",
+         },
+     },
+     template="llama3",
+ )
+
+
+ register_model_group(
+     models={
+         "LLaVA1.5-7B-Chat": {
+             DownloadSource.DEFAULT: "llava-hf/llava-1.5-7b-hf",
+         },
+         "LLaVA1.5-13B-Chat": {
+             DownloadSource.DEFAULT: "llava-hf/llava-1.5-13b-hf",
+         },
+     },
+     template="vicuna",
+     vision=True,
+ )
+
+
+ register_model_group(
+     models={
+         "MiniCPM-2B-SFT-Chat": {
+             DownloadSource.DEFAULT: "openbmb/MiniCPM-2B-sft-bf16",
+             DownloadSource.MODELSCOPE: "OpenBMB/miniCPM-bf16",
+         },
+         "MiniCPM-2B-DPO-Chat": {
+             DownloadSource.DEFAULT: "openbmb/MiniCPM-2B-dpo-bf16",
+             DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-2B-dpo-bf16",
+         },
+     },
+     template="cpm",
+ )
+
+
+ register_model_group(
+     models={
+         "Mistral-7B-v0.1": {
+             DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.1",
+             DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.1",
+         },
+         "Mistral-7B-v0.1-Chat": {
+             DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.1",
+             DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.1",
+         },
+         "Mistral-7B-v0.2": {
+             DownloadSource.DEFAULT: "alpindale/Mistral-7B-v0.2-hf",
+             DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.2-hf",
+         },
+         "Mistral-7B-v0.2-Chat": {
+             DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.2",
+             DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.2",
+         },
+         "Mistral-7B-v0.3": {
+             DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.3",
+         },
+         "Mistral-7B-v0.3-Chat": {
+             DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.3",
+         },
+     },
+     template="mistral",
+ )
+
+
+ register_model_group(
+     models={
+         "Mixtral-8x7B-v0.1": {
+             DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-v0.1",
+             DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-v0.1",
+         },
+         "Mixtral-8x7B-v0.1-Chat": {
+             DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-Instruct-v0.1",
+             DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-Instruct-v0.1",
+         },
+         "Mixtral-8x22B-v0.1": {
+             DownloadSource.DEFAULT: "mistralai/Mixtral-8x22B-v0.1",
+             DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x22B-v0.1",
+         },
+         "Mixtral-8x22B-v0.1-Chat": {
+             DownloadSource.DEFAULT: "mistralai/Mixtral-8x22B-Instruct-v0.1",
+             DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x22B-Instruct-v0.1",
+         },
+     },
+     template="mistral",
+ )
+
+
+ register_model_group(
+     models={
+         "OLMo-1B": {
+             DownloadSource.DEFAULT: "allenai/OLMo-1B-hf",
+         },
+         "OLMo-7B": {
+             DownloadSource.DEFAULT: "allenai/OLMo-7B-hf",
+         },
+         "OLMo-7B-Chat": {
+             DownloadSource.DEFAULT: "ssec-uw/OLMo-7B-Instruct-hf",
+         },
+         "OLMo-1.7-7B": {
+             DownloadSource.DEFAULT: "allenai/OLMo-1.7-7B-hf",
+         },
+     },
+ )
+
+
+ register_model_group(
+     models={
+         "OpenChat3.5-7B-Chat": {
+             DownloadSource.DEFAULT: "openchat/openchat-3.5-0106",
+             DownloadSource.MODELSCOPE: "xcwzxcwz/openchat-3.5-0106",
+         }
+     },
+     template="openchat",
+ )
+
+
+ register_model_group(
+     models={
+         "OpenChat3.6-8B-Chat": {
+             DownloadSource.DEFAULT: "openchat/openchat-3.6-8b-20240522",
+         }
+     },
+     template="openchat-3.6",
+ )
+
+
+ register_model_group(
+     models={
+         "Orion-14B-Base": {
+             DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Base",
+             DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Base",
+         },
+         "Orion-14B-Chat": {
+             DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat",
+             DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat",
+         },
+         "Orion-14B-Long-Chat": {
+             DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-LongChat",
+             DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-LongChat",
+         },
+         "Orion-14B-RAG-Chat": {
+             DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat-RAG",
+             DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat-RAG",
+         },
+         "Orion-14B-Plugin-Chat": {
+             DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat-Plugin",
+             DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat-Plugin",
+         },
+     },
+     template="orion",
+ )
+
+
+ register_model_group(
+     models={
+         "PaliGemma-3B-pt-224": {
+             DownloadSource.DEFAULT: "google/paligemma-3b-pt-224",
+             DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-224",
+         },
+         "PaliGemma-3B-pt-448": {
+             DownloadSource.DEFAULT: "google/paligemma-3b-pt-448",
+             DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-448",
+         },
+         "PaliGemma-3B-pt-896": {
+             DownloadSource.DEFAULT: "google/paligemma-3b-pt-896",
+             DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-896",
+         },
+         "PaliGemma-3B-mix-224": {
+             DownloadSource.DEFAULT: "google/paligemma-3b-mix-224",
+             DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-mix-224",
+         },
+         "PaliGemma-3B-mix-448": {
+             DownloadSource.DEFAULT: "google/paligemma-3b-mix-448",
+             DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-mix-448",
+         },
+     },
+     vision=True,
+ )
+
+
+ register_model_group(
+     models={
+         "Phi-1.5-1.3B": {
+             DownloadSource.DEFAULT: "microsoft/phi-1_5",
+             DownloadSource.MODELSCOPE: "allspace/PHI_1-5",
+         },
+         "Phi-2-2.7B": {
+             DownloadSource.DEFAULT: "microsoft/phi-2",
+             DownloadSource.MODELSCOPE: "AI-ModelScope/phi-2",
+         },
+     }
+ )
+
+
+ register_model_group(
+     models={
+         "Phi3-4B-4k-Chat": {
+             DownloadSource.DEFAULT: "microsoft/Phi-3-mini-4k-instruct",
+             DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-4k-instruct",
+         },
+         "Phi3-4B-128k-Chat": {
+             DownloadSource.DEFAULT: "microsoft/Phi-3-mini-128k-instruct",
+             DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-128k-instruct",
+         },
+         "Phi3-7B-8k-Chat": {
+             DownloadSource.DEFAULT: "microsoft/Phi-3-small-8k-instruct",
+             DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-8k-instruct",
+         },
+         "Phi3-7B-128k-Chat": {
+             DownloadSource.DEFAULT: "microsoft/Phi-3-small-128k-instruct",
+             DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-128k-instruct",
+         },
+         "Phi3-14B-4k-Chat": {
+             DownloadSource.DEFAULT: "microsoft/Phi-3-medium-4k-instruct",
+             DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-4k-instruct",
+         },
+         "Phi3-14B-128k-Chat": {
+             DownloadSource.DEFAULT: "microsoft/Phi-3-medium-128k-instruct",
+             DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-128k-instruct",
+         },
+     },
+     template="phi",
+ )
+
+
961
+ register_model_group(
962
+ models={
963
+ "Qwen-1.8B": {
964
+ DownloadSource.DEFAULT: "Qwen/Qwen-1_8B",
965
+ DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B",
966
+ },
967
+ "Qwen-7B": {
968
+ DownloadSource.DEFAULT: "Qwen/Qwen-7B",
969
+ DownloadSource.MODELSCOPE: "qwen/Qwen-7B",
970
+ },
971
+ "Qwen-14B": {
972
+ DownloadSource.DEFAULT: "Qwen/Qwen-14B",
973
+ DownloadSource.MODELSCOPE: "qwen/Qwen-14B",
974
+ },
975
+ "Qwen-72B": {
976
+ DownloadSource.DEFAULT: "Qwen/Qwen-72B",
977
+ DownloadSource.MODELSCOPE: "qwen/Qwen-72B",
978
+ },
979
+ "Qwen-1.8B-Chat": {
980
+ DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat",
981
+ DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat",
982
+ },
983
+ "Qwen-7B-Chat": {
984
+ DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat",
985
+ DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat",
986
+ },
987
+ "Qwen-14B-Chat": {
988
+ DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat",
989
+ DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat",
990
+ },
991
+ "Qwen-72B-Chat": {
992
+ DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat",
993
+ DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat",
994
+ },
995
+ "Qwen-1.8B-int8-Chat": {
996
+ DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int8",
997
+ DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int8",
998
+ },
999
+ "Qwen-1.8B-int4-Chat": {
1000
+ DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int4",
1001
+ DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int4",
1002
+ },
1003
+ "Qwen-7B-int8-Chat": {
1004
+ DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int8",
1005
+ DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int8",
1006
+ },
1007
+ "Qwen-7B-int4-Chat": {
1008
+ DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int4",
1009
+ DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int4",
1010
+ },
1011
+ "Qwen-14B-int8-Chat": {
1012
+ DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int8",
1013
+ DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int8",
1014
+ },
1015
+ "Qwen-14B-int4-Chat": {
1016
+ DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int4",
1017
+ DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int4",
1018
+ },
1019
+ "Qwen-72B-int8-Chat": {
1020
+ DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int8",
1021
+ DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int8",
1022
+ },
1023
+ "Qwen-72B-int4-Chat": {
1024
+ DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int4",
1025
+ DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int4",
1026
+ },
1027
+ },
1028
+ template="qwen",
1029
+ )
1030
+
1031
+
1032
+ register_model_group(
1033
+ models={
1034
+ "Qwen1.5-0.5B": {
1035
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B",
1036
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B",
1037
+ },
1038
+ "Qwen1.5-1.8B": {
1039
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B",
1040
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B",
1041
+ },
1042
+ "Qwen1.5-4B": {
1043
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B",
1044
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B",
1045
+ },
1046
+ "Qwen1.5-7B": {
1047
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B",
1048
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B",
1049
+ },
1050
+ "Qwen1.5-14B": {
1051
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B",
1052
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B",
1053
+ },
1054
+ "Qwen1.5-32B": {
1055
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B",
1056
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B",
1057
+ },
1058
+ "Qwen1.5-72B": {
1059
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B",
1060
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B",
1061
+ },
1062
+ "Qwen1.5-110B": {
1063
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B",
1064
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B",
1065
+ },
1066
+ "Qwen1.5-MoE-A2.7B": {
1067
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B",
1068
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B",
1069
+ },
1070
+ "Qwen1.5-Code-7B": {
1071
+ DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B",
1072
+ DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B",
1073
+ },
1074
+ "Qwen1.5-0.5B-Chat": {
1075
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat",
1076
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat",
1077
+ },
1078
+ "Qwen1.5-1.8B-Chat": {
1079
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat",
1080
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat",
1081
+ },
1082
+ "Qwen1.5-4B-Chat": {
1083
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat",
1084
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat",
1085
+ },
1086
+ "Qwen1.5-7B-Chat": {
1087
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat",
1088
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat",
1089
+ },
1090
+ "Qwen1.5-14B-Chat": {
1091
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat",
1092
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat",
1093
+ },
1094
+ "Qwen1.5-32B-Chat": {
1095
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B-Chat",
1096
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B-Chat",
1097
+ },
1098
+ "Qwen1.5-72B-Chat": {
1099
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat",
1100
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat",
1101
+ },
1102
+ "Qwen1.5-110B-Chat": {
1103
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat",
1104
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B-Chat",
1105
+ },
1106
+ "Qwen1.5-MoE-A2.7B-Chat": {
1107
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat",
1108
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat",
1109
+ },
1110
+ "Qwen1.5-Code-7B-Chat": {
1111
+ DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat",
1112
+ DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat",
1113
+ },
1114
+ "Qwen1.5-0.5B-int8-Chat": {
1115
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
1116
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
1117
+ },
1118
+ "Qwen1.5-0.5B-int4-Chat": {
1119
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-AWQ",
1120
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-AWQ",
1121
+ },
1122
+ "Qwen1.5-1.8B-int8-Chat": {
1123
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
1124
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
1125
+ },
1126
+ "Qwen1.5-1.8B-int4-Chat": {
1127
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-AWQ",
1128
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat-AWQ",
1129
+ },
1130
+ "Qwen1.5-4B-int8-Chat": {
1131
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
1132
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
1133
+ },
1134
+ "Qwen1.5-4B-int4-Chat": {
1135
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-AWQ",
1136
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat-AWQ",
1137
+ },
1138
+ "Qwen1.5-7B-int8-Chat": {
1139
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
1140
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
1141
+ },
1142
+ "Qwen1.5-7B-int4-Chat": {
1143
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-AWQ",
1144
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat-AWQ",
1145
+ },
1146
+ "Qwen1.5-14B-int8-Chat": {
1147
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
1148
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
1149
+ },
1150
+ "Qwen1.5-14B-int4-Chat": {
1151
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-AWQ",
1152
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-AWQ",
1153
+ },
1154
+ "Qwen1.5-32B-int4-Chat": {
1155
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B-Chat-AWQ",
1156
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B-Chat-AWQ",
1157
+ },
1158
+ "Qwen1.5-72B-int8-Chat": {
1159
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
1160
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
1161
+ },
1162
+ "Qwen1.5-72B-int4-Chat": {
1163
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-AWQ",
1164
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-AWQ",
1165
+ },
1166
+ "Qwen1.5-110B-int4-Chat": {
1167
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat-AWQ",
1168
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B-Chat-AWQ",
1169
+ },
1170
+ "Qwen1.5-MoE-A2.7B-int4-Chat": {
1171
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
1172
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
1173
+ },
1174
+ "Qwen1.5-Code-7B-int4-Chat": {
1175
+ DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat-AWQ",
1176
+ DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat-AWQ",
1177
+ },
1178
+ },
1179
+ template="qwen",
1180
+ )
1181
+
1182
+
1183
+ register_model_group(
1184
+ models={
1185
+ "Qwen2-0.5B": {
1186
+ DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B",
1187
+ DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B",
1188
+ },
1189
+ "Qwen2-1.5B": {
1190
+ DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B",
1191
+ DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B",
1192
+ },
1193
+ "Qwen2-7B": {
1194
+ DownloadSource.DEFAULT: "Qwen/Qwen2-7B",
1195
+ DownloadSource.MODELSCOPE: "qwen/Qwen2-7B",
1196
+ },
1197
+ "Qwen2-72B": {
1198
+ DownloadSource.DEFAULT: "Qwen/Qwen2-72B",
1199
+ DownloadSource.MODELSCOPE: "qwen/Qwen2-72B",
1200
+ },
1201
+ "Qwen2-MoE-57B": {
1202
+ DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B",
1203
+ DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B",
1204
+ },
1205
+ "Qwen2-0.5B-Chat": {
1206
+ DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct",
1207
+ DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct",
1208
+ },
1209
+ "Qwen2-1.5B-Chat": {
1210
+ DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct",
1211
+ DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct",
1212
+ },
1213
+ "Qwen2-7B-Chat": {
1214
+ DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct",
1215
+ DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct",
1216
+ },
1217
+ "Qwen2-72B-Chat": {
1218
+ DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct",
1219
+ DownloadSource.MODELSCOPE: "qwen/Qwen2-72B-Instruct",
1220
+ },
1221
+ "Qwen2-MoE-57B-Chat": {
1222
+ DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B-Instruct",
1223
+ DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B-Instruct",
1224
+ },
1225
+ "Qwen2-0.5B-int8-Chat": {
1226
+ DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct-GPTQ-Int8",
1227
+ DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct-GPTQ-Int8",
1228
+ },
1229
+ "Qwen2-0.5B-int4-Chat": {
1230
+ DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct-AWQ",
1231
+ DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct-AWQ",
1232
+ },
1233
+ "Qwen2-1.5B-int8-Chat": {
1234
+ DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct-GPTQ-Int8",
1235
+ DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct-GPTQ-Int8",
1236
+ },
1237
+ "Qwen2-1.5B-int4-Chat": {
1238
+ DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct-AWQ",
1239
+ DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct-AWQ",
1240
+ },
1241
+ "Qwen2-7B-int8-Chat": {
1242
+ DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct-GPTQ-Int8",
1243
+ DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct-GPTQ-Int8",
1244
+ },
1245
+ "Qwen2-7B-int4-Chat": {
1246
+ DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct-AWQ",
1247
+ DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct-AWQ",
1248
+ },
1249
+ "Qwen2-72B-int8-Chat": {
1250
+ DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct-GPTQ-Int8",
1251
+ DownloadSource.MODELSCOPE: "qwen/Qwen2-72B-Instruct-GPTQ-Int8",
1252
+ },
1253
+ "Qwen2-72B-int4-Chat": {
1254
+ DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct-AWQ",
1255
+ DownloadSource.MODELSCOPE: "qwen/Qwen2-72B-Instruct-AWQ",
1256
+ },
1257
+ "Qwen2-MoE-57B-int4-Chat": {
1258
+ DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4",
1259
+ DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4",
1260
+ },
1261
+ },
1262
+ template="qwen",
1263
+ )
1264
+
1265
+
1266
+ register_model_group(
1267
+ models={
1268
+ "SOLAR-10.7B": {
1269
+ DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-v1.0",
1270
+ },
1271
+ "SOLAR-10.7B-Chat": {
1272
+ DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-Instruct-v1.0",
1273
+ DownloadSource.MODELSCOPE: "AI-ModelScope/SOLAR-10.7B-Instruct-v1.0",
1274
+ },
1275
+ },
1276
+ template="solar",
1277
+ )
1278
+
1279
+
1280
+ register_model_group(
1281
+ models={
1282
+ "Skywork-13B-Base": {
1283
+ DownloadSource.DEFAULT: "Skywork/Skywork-13B-base",
1284
+ DownloadSource.MODELSCOPE: "skywork/Skywork-13B-base",
1285
+ }
1286
+ }
1287
+ )
1288
+
1289
+
1290
+ register_model_group(
1291
+ models={
1292
+ "StarCoder2-3B": {
1293
+ DownloadSource.DEFAULT: "bigcode/starcoder2-3b",
1294
+ DownloadSource.MODELSCOPE: "AI-ModelScope/starcoder2-3b",
1295
+ },
1296
+ "StarCoder2-7B": {
1297
+ DownloadSource.DEFAULT: "bigcode/starcoder2-7b",
1298
+ DownloadSource.MODELSCOPE: "AI-ModelScope/starcoder2-7b",
1299
+ },
1300
+ "StarCoder2-15B": {
1301
+ DownloadSource.DEFAULT: "bigcode/starcoder2-15b",
1302
+ DownloadSource.MODELSCOPE: "AI-ModelScope/starcoder2-15b",
1303
+ },
1304
+ }
1305
+ )
1306
+
1307
+
1308
+ register_model_group(
1309
+ models={
1310
+ "TeleChat-1B-Chat": {
1311
+ DownloadSource.DEFAULT: "Tele-AI/TeleChat-1B",
1312
+ DownloadSource.MODELSCOPE: "TeleAI/TeleChat-1B",
1313
+ },
1314
+ "TeleChat-7B-Chat": {
1315
+ DownloadSource.DEFAULT: "Tele-AI/telechat-7B",
1316
+ DownloadSource.MODELSCOPE: "TeleAI/telechat-7B",
1317
+ },
1318
+ "TeleChat-12B-Chat": {
1319
+ DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B",
1320
+ DownloadSource.MODELSCOPE: "TeleAI/TeleChat-12B",
1321
+ },
1322
+ "TeleChat-12B-v2-Chat": {
1323
+ DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B-v2",
1324
+ DownloadSource.MODELSCOPE: "TeleAI/TeleChat-12B-v2",
1325
+ },
1326
+ },
1327
+ template="telechat",
1328
+ )
1329
+
1330
+
1331
+ register_model_group(
1332
+ models={
1333
+ "Vicuna1.5-7B-Chat": {
1334
+ DownloadSource.DEFAULT: "lmsys/vicuna-7b-v1.5",
1335
+ DownloadSource.MODELSCOPE: "Xorbits/vicuna-7b-v1.5",
1336
+ },
1337
+ "Vicuna1.5-13B-Chat": {
1338
+ DownloadSource.DEFAULT: "lmsys/vicuna-13b-v1.5",
1339
+ DownloadSource.MODELSCOPE: "Xorbits/vicuna-13b-v1.5",
1340
+ },
1341
+ },
1342
+ template="vicuna",
1343
+ )
1344
+
1345
+
1346
+ register_model_group(
1347
+ models={
1348
+ "XuanYuan-6B": {
1349
+ DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B",
1350
+ DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B",
1351
+ },
1352
+ "XuanYuan-70B": {
1353
+ DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B",
1354
+ DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B",
1355
+ },
1356
+ "XuanYuan-2-70B": {
1357
+ DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B",
1358
+ DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B",
1359
+ },
1360
+ "XuanYuan-6B-Chat": {
1361
+ DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat",
1362
+ DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat",
1363
+ },
1364
+ "XuanYuan-70B-Chat": {
1365
+ DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat",
1366
+ DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat",
1367
+ },
1368
+ "XuanYuan-2-70B-Chat": {
1369
+ DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat",
1370
+ DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat",
1371
+ },
1372
+ "XuanYuan-6B-int8-Chat": {
1373
+ DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat-8bit",
1374
+ DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat-8bit",
1375
+ },
1376
+ "XuanYuan-6B-int4-Chat": {
1377
+ DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat-4bit",
1378
+ DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat-4bit",
1379
+ },
1380
+ "XuanYuan-70B-int8-Chat": {
1381
+ DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit",
1382
+ DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit",
1383
+ },
1384
+ "XuanYuan-70B-int4-Chat": {
1385
+ DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit",
1386
+ DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit",
1387
+ },
1388
+ "XuanYuan-2-70B-int8-Chat": {
1389
+ DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat-8bit",
1390
+ DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat-8bit",
1391
+ },
1392
+ "XuanYuan-2-70B-int4-Chat": {
1393
+ DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat-4bit",
1394
+ DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat-4bit",
1395
+ },
1396
+ },
1397
+ template="xuanyuan",
1398
+ )
1399
+
1400
+
1401
+ register_model_group(
1402
+ models={
1403
+ "XVERSE-7B": {
1404
+ DownloadSource.DEFAULT: "xverse/XVERSE-7B",
1405
+ DownloadSource.MODELSCOPE: "xverse/XVERSE-7B",
1406
+ },
1407
+ "XVERSE-13B": {
1408
+ DownloadSource.DEFAULT: "xverse/XVERSE-13B",
1409
+ DownloadSource.MODELSCOPE: "xverse/XVERSE-13B",
1410
+ },
1411
+ "XVERSE-65B": {
1412
+ DownloadSource.DEFAULT: "xverse/XVERSE-65B",
1413
+ DownloadSource.MODELSCOPE: "xverse/XVERSE-65B",
1414
+ },
1415
+ "XVERSE-65B-2": {
1416
+ DownloadSource.DEFAULT: "xverse/XVERSE-65B-2",
1417
+ DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-2",
1418
+ },
1419
+ "XVERSE-7B-Chat": {
1420
+ DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat",
1421
+ DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat",
1422
+ },
1423
+ "XVERSE-13B-Chat": {
1424
+ DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat",
1425
+ DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat",
1426
+ },
1427
+ "XVERSE-65B-Chat": {
1428
+ DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat",
1429
+ DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat",
1430
+ },
1431
+ "XVERSE-MoE-A4.2B": {
1432
+ DownloadSource.DEFAULT: "xverse/XVERSE-MoE-A4.2B",
1433
+ DownloadSource.MODELSCOPE: "xverse/XVERSE-MoE-A4.2B",
1434
+ },
1435
+ "XVERSE-7B-int8-Chat": {
1436
+ DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat-GPTQ-Int8",
1437
+ DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat-GPTQ-Int8",
1438
+ },
1439
+ "XVERSE-7B-int4-Chat": {
1440
+ DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat-GPTQ-Int4",
1441
+ DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat-GPTQ-Int4",
1442
+ },
1443
+ "XVERSE-13B-int8-Chat": {
1444
+ DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat-GPTQ-Int8",
1445
+ DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat-GPTQ-Int8",
1446
+ },
1447
+ "XVERSE-13B-int4-Chat": {
1448
+ DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat-GPTQ-Int4",
1449
+ DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat-GPTQ-Int4",
1450
+ },
1451
+ "XVERSE-65B-int4-Chat": {
1452
+ DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat-GPTQ-Int4",
1453
+ DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat-GPTQ-Int4",
1454
+ },
1455
+ },
1456
+ template="xverse",
1457
+ )
1458
+
1459
+
1460
+ register_model_group(
1461
+ models={
1462
+ "Yayi-7B": {
1463
+ DownloadSource.DEFAULT: "wenge-research/yayi-7b-llama2",
1464
+ DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-7b-llama2",
1465
+ },
1466
+ "Yayi-13B": {
1467
+ DownloadSource.DEFAULT: "wenge-research/yayi-13b-llama2",
1468
+ DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-13b-llama2",
1469
+ },
1470
+ },
1471
+ template="yayi",
1472
+ )
1473
+
1474
+
1475
+ register_model_group(
1476
+ models={
1477
+ "Yi-6B": {
1478
+ DownloadSource.DEFAULT: "01-ai/Yi-6B",
1479
+ DownloadSource.MODELSCOPE: "01ai/Yi-6B",
1480
+ },
1481
+ "Yi-9B": {
1482
+ DownloadSource.DEFAULT: "01-ai/Yi-9B",
1483
+ DownloadSource.MODELSCOPE: "01ai/Yi-9B",
1484
+ },
1485
+ "Yi-34B": {
1486
+ DownloadSource.DEFAULT: "01-ai/Yi-34B",
1487
+ DownloadSource.MODELSCOPE: "01ai/Yi-34B",
1488
+ },
1489
+ "Yi-6B-Chat": {
1490
+ DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat",
1491
+ DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat",
1492
+ },
1493
+ "Yi-34B-Chat": {
1494
+ DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat",
1495
+ DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat",
1496
+ },
1497
+ "Yi-6B-int8-Chat": {
1498
+ DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-8bits",
1499
+ DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-8bits",
1500
+ },
1501
+ "Yi-6B-int4-Chat": {
1502
+ DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-4bits",
1503
+ DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-4bits",
1504
+ },
1505
+ "Yi-34B-int8-Chat": {
1506
+ DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-8bits",
1507
+ DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-8bits",
1508
+ },
1509
+ "Yi-34B-int4-Chat": {
1510
+ DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-4bits",
1511
+ DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-4bits",
1512
+ },
1513
+ "Yi-1.5-6B": {
1514
+ DownloadSource.DEFAULT: "01-ai/Yi-1.5-6B",
1515
+ DownloadSource.MODELSCOPE: "01ai/Yi-1.5-6B",
1516
+ },
1517
+ "Yi-1.5-9B": {
1518
+ DownloadSource.DEFAULT: "01-ai/Yi-1.5-9B",
1519
+ DownloadSource.MODELSCOPE: "01ai/Yi-1.5-9B",
1520
+ },
1521
+ "Yi-1.5-34B": {
1522
+ DownloadSource.DEFAULT: "01-ai/Yi-1.5-34B",
1523
+ DownloadSource.MODELSCOPE: "01ai/Yi-1.5-34B",
1524
+ },
1525
+ "Yi-1.5-6B-Chat": {
1526
+ DownloadSource.DEFAULT: "01-ai/Yi-1.5-6B-Chat",
1527
+ DownloadSource.MODELSCOPE: "01ai/Yi-1.5-6B-Chat",
1528
+ },
1529
+ "Yi-1.5-9B-Chat": {
1530
+ DownloadSource.DEFAULT: "01-ai/Yi-1.5-9B-Chat",
1531
+ DownloadSource.MODELSCOPE: "01ai/Yi-1.5-9B-Chat",
1532
+ },
1533
+ "Yi-1.5-34B-Chat": {
1534
+ DownloadSource.DEFAULT: "01-ai/Yi-1.5-34B-Chat",
1535
+ DownloadSource.MODELSCOPE: "01ai/Yi-1.5-34B-Chat",
1536
+ },
1537
+ },
1538
+ template="yi",
1539
+ )
1540
+
1541
+
1542
+ register_model_group(
1543
+ models={
1544
+ "YiVL-6B-Chat": {
1545
+ DownloadSource.DEFAULT: "BUAADreamer/Yi-VL-6B-hf",
1546
+ },
1547
+ "YiVL-34B-Chat": {
1548
+ DownloadSource.DEFAULT: "BUAADreamer/Yi-VL-34B-hf",
1549
+ },
1550
+ },
1551
+ template="yi_vl",
1552
+ vision=True,
1553
+ )
1554
+
1555
+
1556
+ register_model_group(
1557
+ models={
1558
+ "Yuan2-2B-Chat": {
1559
+ DownloadSource.DEFAULT: "IEITYuan/Yuan2-2B-hf",
1560
+ DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-2B-hf",
1561
+ },
1562
+ "Yuan2-51B-Chat": {
1563
+ DownloadSource.DEFAULT: "IEITYuan/Yuan2-51B-hf",
1564
+ DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-51B-hf",
1565
+ },
1566
+ "Yuan2-102B-Chat": {
1567
+ DownloadSource.DEFAULT: "IEITYuan/Yuan2-102B-hf",
1568
+ DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-102B-hf",
1569
+ },
1570
+ },
1571
+ template="yuan",
1572
+ )
1573
+
1574
+
1575
+ register_model_group(
1576
+ models={
1577
+ "Zephyr-7B-Alpha-Chat": {
1578
+ DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-alpha",
1579
+ DownloadSource.MODELSCOPE: "AI-ModelScope/zephyr-7b-alpha",
1580
+ },
1581
+ "Zephyr-7B-Beta-Chat": {
1582
+ DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-beta",
1583
+ DownloadSource.MODELSCOPE: "modelscope/zephyr-7b-beta",
1584
+ },
1585
+ "Zephyr-141B-ORPO-Chat": {
1586
+ DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1",
1587
+ },
1588
+ },
1589
+ template="zephyr",
1590
+ )
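For reference, a minimal, self-contained sketch of the registry pattern these `register_model_group` calls follow: each group flattens into name-to-download-path and name-to-template lookup tables. The global table names, the `DownloadSource` member values, and the omission of `vision` bookkeeping are illustrative assumptions, not the module's verbatim internals.

```python
# Sketch of a model-group registry; names marked below are assumptions.
from collections import OrderedDict
from enum import Enum, unique
from typing import Dict, Optional


@unique
class DownloadSource(str, Enum):
    DEFAULT = "hf"     # assumed value: Hugging Face hub
    MODELSCOPE = "ms"  # assumed value: ModelScope hub


SUPPORTED_MODELS: "OrderedDict[str, Dict[DownloadSource, str]]" = OrderedDict()
DEFAULT_TEMPLATE: Dict[str, str] = {}


def register_model_group(
    models: Dict[str, Dict[DownloadSource, str]],
    template: Optional[str] = None,
    vision: bool = False,  # multimodal bookkeeping omitted in this sketch
) -> None:
    for name, path in models.items():
        SUPPORTED_MODELS[name] = path
        if template is not None:
            DEFAULT_TEMPLATE[name] = template


register_model_group(
    models={"Qwen2-7B-Chat": {DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct"}},
    template="qwen",
)
print(SUPPORTED_MODELS["Qwen2-7B-Chat"][DownloadSource.DEFAULT])  # Qwen/Qwen2-7B-Instruct
print(DEFAULT_TEMPLATE["Qwen2-7B-Chat"])                          # qwen
```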
llama-factory/src/llamafactory/extras/env.py ADDED
@@ -0,0 +1,75 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/commands/env.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import platform
+
+import accelerate
+import datasets
+import peft
+import torch
+import transformers
+import trl
+from transformers.utils import is_torch_cuda_available, is_torch_npu_available
+
+
+VERSION = "0.8.4.dev0"
+
+
+def print_env() -> None:
+    info = {
+        "`llamafactory` version": VERSION,
+        "Platform": platform.platform(),
+        "Python version": platform.python_version(),
+        "PyTorch version": torch.__version__,
+        "Transformers version": transformers.__version__,
+        "Datasets version": datasets.__version__,
+        "Accelerate version": accelerate.__version__,
+        "PEFT version": peft.__version__,
+        "TRL version": trl.__version__,
+    }
+
+    if is_torch_cuda_available():
+        info["PyTorch version"] += " (GPU)"
+        info["GPU type"] = torch.cuda.get_device_name()
+
+    if is_torch_npu_available():
+        info["PyTorch version"] += " (NPU)"
+        info["NPU type"] = torch.npu.get_device_name()
+        info["CANN version"] = torch.version.cann
+
+    try:
+        import deepspeed  # type: ignore
+
+        info["DeepSpeed version"] = deepspeed.__version__
+    except Exception:
+        pass
+
+    try:
+        import bitsandbytes
+
+        info["Bitsandbytes version"] = bitsandbytes.__version__
+    except Exception:
+        pass
+
+    try:
+        import vllm
+
+        info["vLLM version"] = vllm.__version__
+    except Exception:
+        pass
+
+    print("\n" + "\n".join(["- {}: {}".format(key, value) for key, value in info.items()]) + "\n")
llama-factory/src/llamafactory/extras/logging.py ADDED
@@ -0,0 +1,82 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+from concurrent.futures import ThreadPoolExecutor
+
+from .constants import RUNNING_LOG
+
+
+class LoggerHandler(logging.Handler):
+    r"""
+    Logger handler used in Web UI.
+    """
+
+    def __init__(self, output_dir: str) -> None:
+        super().__init__()
+        formatter = logging.Formatter(
+            fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
+        )
+        self.setLevel(logging.INFO)
+        self.setFormatter(formatter)
+
+        os.makedirs(output_dir, exist_ok=True)
+        self.running_log = os.path.join(output_dir, RUNNING_LOG)
+        if os.path.exists(self.running_log):
+            os.remove(self.running_log)
+
+        self.thread_pool = ThreadPoolExecutor(max_workers=1)
+
+    def _write_log(self, log_entry: str) -> None:
+        with open(self.running_log, "a", encoding="utf-8") as f:
+            f.write(log_entry + "\n\n")
+
+    def emit(self, record) -> None:
+        if record.name == "httpx":
+            return
+
+        log_entry = self.format(record)
+        self.thread_pool.submit(self._write_log, log_entry)
+
+    def close(self) -> None:
+        self.thread_pool.shutdown(wait=True)
+        return super().close()
+
+
+def get_logger(name: str) -> logging.Logger:
+    r"""
+    Gets a standard logger with a stream handler to stdout.
+    """
+    formatter = logging.Formatter(
+        fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
+    )
+    handler = logging.StreamHandler(sys.stdout)
+    handler.setFormatter(formatter)
+
+    logger = logging.getLogger(name)
+    logger.setLevel(logging.INFO)
+    logger.addHandler(handler)
+
+    return logger
+
+
+def reset_logging() -> None:
+    r"""
+    Removes the basic config of the root logger. (unused in script)
+    """
+    root = logging.getLogger()
+    list(map(root.removeHandler, root.handlers))
+    list(map(root.removeFilter, root.filters))
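A sketch of wiring the Web-UI handler next to the stdout logger. Note that `emit` drops `httpx` records and defers file writes to a single-worker thread pool, so `close()` is what guarantees the log is flushed; the `"outputs"` directory name below is arbitrary.

```python
import logging

from llamafactory.extras.logging import LoggerHandler, get_logger

logger = get_logger(__name__)                     # INFO-level stream handler to stdout
ui_handler = LoggerHandler(output_dir="outputs")  # recreates the running log (file name from constants.RUNNING_LOG)
logging.getLogger().addHandler(ui_handler)        # records reach it via propagation to the root logger
logger.info("visible on stdout and appended to the running log")
ui_handler.close()                                # waits for the write queue to drain
```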
llama-factory/src/llamafactory/extras/misc.py ADDED
@@ -0,0 +1,228 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by HuggingFace's PEFT library.
+# https://github.com/huggingface/peft/blob/v0.10.0/src/peft/peft_model.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import os
+from typing import TYPE_CHECKING, Tuple, Union
+
+import torch
+import transformers.dynamic_module_utils
+from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
+from transformers.dynamic_module_utils import get_relative_imports
+from transformers.utils import (
+    is_torch_bf16_gpu_available,
+    is_torch_cuda_available,
+    is_torch_mps_available,
+    is_torch_npu_available,
+    is_torch_xpu_available,
+)
+from transformers.utils.versions import require_version
+
+from .logging import get_logger
+
+
+_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
+try:
+    _is_bf16_available = is_torch_bf16_gpu_available()
+except Exception:
+    _is_bf16_available = False
+
+
+if TYPE_CHECKING:
+    from numpy.typing import NDArray
+
+    from ..hparams import ModelArguments
+
+
+logger = get_logger(__name__)
+
+
+class AverageMeter:
+    r"""
+    Computes and stores the average and current value.
+    """
+
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+
+def check_dependencies() -> None:
+    r"""
+    Checks the versions of the required packages.
+    """
+    if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
+        logger.warning("Version checking has been disabled; this may lead to unexpected behaviors.")
+    else:
+        require_version("transformers>=4.41.2", "To fix: pip install transformers>=4.41.2")
+        require_version("datasets>=2.16.0", "To fix: pip install datasets>=2.16.0")
+        require_version("accelerate>=0.30.1", "To fix: pip install accelerate>=0.30.1")
+        require_version("peft>=0.11.1", "To fix: pip install peft>=0.11.1")
+        require_version("trl>=0.8.6", "To fix: pip install trl>=0.8.6")
+
+
+def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:
+    r"""
+    Returns the number of trainable parameters and the number of all parameters in the model.
+    """
+    trainable_params, all_param = 0, 0
+    for param in model.parameters():
+        num_params = param.numel()
+        # if using DS Zero 3 and the weights are initialized empty
+        if num_params == 0 and hasattr(param, "ds_numel"):
+            num_params = param.ds_numel
+
+        # Due to the design of 4-bit linear layers from bitsandbytes,
+        # multiply the number of parameters by 2 times the storage itemsize
+        if param.__class__.__name__ == "Params4bit":
+            if hasattr(param, "quant_storage") and hasattr(param.quant_storage, "itemsize"):
+                num_bytes = param.quant_storage.itemsize
+            elif hasattr(param, "element_size"):  # for older pytorch version
+                num_bytes = param.element_size()
+            else:
+                num_bytes = 1
+
+            num_params = num_params * 2 * num_bytes
+
+        all_param += num_params
+        if param.requires_grad:
+            trainable_params += num_params
+
+    return trainable_params, all_param
+
+
+def get_current_device() -> "torch.device":
+    r"""
+    Gets the current available device.
+    """
+    if is_torch_xpu_available():
+        device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
+    elif is_torch_npu_available():
+        device = "npu:{}".format(os.environ.get("LOCAL_RANK", "0"))
+    elif is_torch_mps_available():
+        device = "mps:{}".format(os.environ.get("LOCAL_RANK", "0"))
+    elif is_torch_cuda_available():
+        device = "cuda:{}".format(os.environ.get("LOCAL_RANK", "0"))
+    else:
+        device = "cpu"
+
+    return torch.device(device)
+
+
+def get_device_count() -> int:
+    r"""
+    Gets the number of available GPU or NPU devices.
+    """
+    if is_torch_npu_available():
+        return torch.npu.device_count()
+    elif is_torch_cuda_available():
+        return torch.cuda.device_count()
+    else:
+        return 0
+
+
+def get_logits_processor() -> "LogitsProcessorList":
+    r"""
+    Gets a logits processor that removes NaN and Inf logits.
+    """
+    logits_processor = LogitsProcessorList()
+    logits_processor.append(InfNanRemoveLogitsProcessor())
+    return logits_processor
+
+
+def has_tokenized_data(path: "os.PathLike") -> bool:
+    r"""
+    Checks if the path has a tokenized dataset.
+    """
+    return os.path.isdir(path) and len(os.listdir(path)) > 0
+
+
+def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
+    r"""
+    Infers the optimal dtype according to the model_dtype and device compatibility.
+    """
+    if _is_bf16_available and model_dtype == torch.bfloat16:
+        return torch.bfloat16
+    elif _is_fp16_available:
+        return torch.float16
+    else:
+        return torch.float32
+
+
+def is_gpu_or_npu_available() -> bool:
+    r"""
+    Checks if a GPU or NPU is available.
+    """
+    return is_torch_npu_available() or is_torch_cuda_available()
+
+
+def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":
+    if isinstance(inputs, torch.Tensor):
+        inputs = inputs.cpu()
+        if inputs.dtype == torch.bfloat16:  # numpy does not support bfloat16 until 1.21.4
+            inputs = inputs.to(torch.float32)
+
+        inputs = inputs.numpy()
+
+    return inputs
+
+
+def skip_check_imports() -> None:
+    if os.environ.get("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
+        transformers.dynamic_module_utils.check_imports = get_relative_imports
+
+
+def torch_gc() -> None:
+    r"""
+    Collects GPU or NPU memory.
+    """
+    gc.collect()
+    if is_torch_xpu_available():
+        torch.xpu.empty_cache()
+    elif is_torch_npu_available():
+        torch.npu.empty_cache()
+    elif is_torch_mps_available():
+        torch.mps.empty_cache()
+    elif is_torch_cuda_available():
+        torch.cuda.empty_cache()
+
+
+def try_download_model_from_ms(model_args: "ModelArguments") -> str:
+    if not use_modelscope() or os.path.exists(model_args.model_name_or_path):
+        return model_args.model_name_or_path
+
+    try:
+        from modelscope import snapshot_download
+
+        revision = "master" if model_args.model_revision == "main" else model_args.model_revision
+        return snapshot_download(model_args.model_name_or_path, revision=revision, cache_dir=model_args.cache_dir)
+    except ImportError:
+        raise ImportError("Please install modelscope via `pip install modelscope -U`")
+
+
+def use_modelscope() -> bool:
+    return os.environ.get("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"]
llama-factory/src/llamafactory/extras/packages.py ADDED
@@ -0,0 +1,88 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/import_utils.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib.metadata
+import importlib.util
+from functools import lru_cache
+from typing import TYPE_CHECKING
+
+from packaging import version
+
+
+if TYPE_CHECKING:
+    from packaging.version import Version
+
+
+def _is_package_available(name: str) -> bool:
+    return importlib.util.find_spec(name) is not None
+
+
+def _get_package_version(name: str) -> "Version":
+    try:
+        return version.parse(importlib.metadata.version(name))
+    except Exception:
+        return version.parse("0.0.0")
+
+
+def is_fastapi_available():
+    return _is_package_available("fastapi")
+
+
+def is_galore_available():
+    return _is_package_available("galore_torch")
+
+
+def is_gradio_available():
+    return _is_package_available("gradio")
+
+
+def is_matplotlib_available():
+    return _is_package_available("matplotlib")
+
+
+def is_pillow_available():
+    return _is_package_available("PIL")
+
+
+def is_requests_available():
+    return _is_package_available("requests")
+
+
+def is_rouge_available():
+    return _is_package_available("rouge_chinese")
+
+
+def is_starlette_available():
+    return _is_package_available("sse_starlette")
+
+
+def is_uvicorn_available():
+    return _is_package_available("uvicorn")
+
+
+def is_vllm_available():
+    return _is_package_available("vllm")
+
+
+@lru_cache
+def is_vllm_version_greater_than_0_5():
+    return _get_package_version("vllm") >= version.parse("0.5.0")
+
+
+@lru_cache
+def is_vllm_version_greater_than_0_5_1():
+    return _get_package_version("vllm") >= version.parse("0.5.1")
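These probes gate optional code paths without importing the heavy package itself; because `_get_package_version` falls back to `0.0.0`, the version comparisons are safe even when the package is absent.

```python
from llamafactory.extras.packages import is_vllm_available, is_vllm_version_greater_than_0_5

# Cheap capability check: vLLM is never actually imported here.
if is_vllm_available() and is_vllm_version_greater_than_0_5():
    print("vLLM >= 0.5.0 detected, the vLLM engine can be used")
else:
    print("falling back to the Hugging Face engine")
```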
llama-factory/src/llamafactory/extras/ploting.py ADDED
@@ -0,0 +1,101 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import math
+import os
+from typing import Any, Dict, List
+
+from transformers.trainer import TRAINER_STATE_NAME
+
+from .logging import get_logger
+from .packages import is_matplotlib_available
+
+
+if is_matplotlib_available():
+    import matplotlib.figure
+    import matplotlib.pyplot as plt
+
+
+logger = get_logger(__name__)
+
+
+def smooth(scalars: List[float]) -> List[float]:
+    r"""
+    EMA implementation according to TensorBoard.
+    """
+    if len(scalars) == 0:
+        return []
+
+    last = scalars[0]
+    smoothed = []
+    weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5)  # a sigmoid function
+    for next_val in scalars:
+        smoothed_val = last * weight + (1 - weight) * next_val
+        smoothed.append(smoothed_val)
+        last = smoothed_val
+    return smoothed
+
+
+def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figure":
+    r"""
+    Plots loss curves in LlamaBoard.
+    """
+    plt.close("all")
+    plt.switch_backend("agg")
+    fig = plt.figure()
+    ax = fig.add_subplot(111)
+    steps, losses = [], []
+    for log in trainer_log:
+        if log.get("loss", None):
+            steps.append(log["current_steps"])
+            losses.append(log["loss"])
+
+    ax.plot(steps, losses, color="#1f77b4", alpha=0.4, label="original")
+    ax.plot(steps, smooth(losses), color="#1f77b4", label="smoothed")
+    ax.legend()
+    ax.set_xlabel("step")
+    ax.set_ylabel("loss")
+    return fig
+
+
+def plot_loss(save_dictionary: os.PathLike, keys: List[str] = ["loss"]) -> None:
+    r"""
+    Plots loss curves and saves the image.
+    """
+    plt.switch_backend("agg")
+    with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
+        data = json.load(f)
+
+    for key in keys:
+        steps, metrics = [], []
+        for i in range(len(data["log_history"])):
+            if key in data["log_history"][i]:
+                steps.append(data["log_history"][i]["step"])
+                metrics.append(data["log_history"][i][key])
+
+        if len(metrics) == 0:
+            logger.warning(f"No metric {key} to plot.")
+            continue
+
+        plt.figure()
+        plt.plot(steps, metrics, color="#1f77b4", alpha=0.4, label="original")
+        plt.plot(steps, smooth(metrics), color="#1f77b4", label="smoothed")
+        plt.title("training {} of {}".format(key, save_dictionary))
+        plt.xlabel("step")
+        plt.ylabel(key)
+        plt.legend()
+        figure_path = os.path.join(save_dictionary, "training_{}.png".format(key.replace("/", "_")))
+        plt.savefig(figure_path, format="png", dpi=100)
+        print("Figure saved at:", figure_path)
llama-factory/src/llamafactory/hparams/__init__.py ADDED
@@ -0,0 +1,32 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .data_args import DataArguments
+from .evaluation_args import EvaluationArguments
+from .finetuning_args import FinetuningArguments
+from .generating_args import GeneratingArguments
+from .model_args import ModelArguments
+from .parser import get_eval_args, get_infer_args, get_train_args
+
+
+__all__ = [
+    "DataArguments",
+    "EvaluationArguments",
+    "FinetuningArguments",
+    "GeneratingArguments",
+    "ModelArguments",
+    "get_eval_args",
+    "get_infer_args",
+    "get_train_args",
+]
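The re-exports make the dataclasses usable with `transformers.HfArgumentParser`, which is presumably how the CLI consumes them; the dataset names below are purely illustrative:

```python
from transformers import HfArgumentParser

from llamafactory.hparams import DataArguments

parser = HfArgumentParser(DataArguments)
(data_args,) = parser.parse_args_into_dataclasses(
    args=["--template", "qwen", "--dataset", "alpaca_en,identity"]
)
print(data_args.dataset)  # ['alpaca_en', 'identity'] -- comma lists are split in __post_init__
```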
llama-factory/src/llamafactory/hparams/data_args.py ADDED
@@ -0,0 +1,143 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass, field
+from typing import Literal, Optional
+
+
+@dataclass
+class DataArguments:
+    r"""
+    Arguments pertaining to what data we are going to input our model for training and evaluation.
+    """
+
+    template: Optional[str] = field(
+        default=None,
+        metadata={"help": "Which template to use for constructing prompts in training and inference."},
+    )
+    dataset: Optional[str] = field(
+        default=None,
+        metadata={"help": "The name of dataset(s) to use for training. Use commas to separate multiple datasets."},
+    )
+    eval_dataset: Optional[str] = field(
+        default=None,
+        metadata={"help": "The name of dataset(s) to use for evaluation. Use commas to separate multiple datasets."},
+    )
+    dataset_dir: str = field(
+        default="data",
+        metadata={"help": "Path to the folder containing the datasets."},
+    )
+    cutoff_len: int = field(
+        default=1024,
+        metadata={"help": "The cutoff length of the tokenized inputs in the dataset."},
+    )
+    train_on_prompt: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to disable the mask on the prompt."},
+    )
+    mask_history: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to mask the history and train on the last turn only."},
+    )
+    streaming: bool = field(
+        default=False,
+        metadata={"help": "Enable dataset streaming."},
+    )
+    buffer_size: int = field(
+        default=16384,
+        metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."},
+    )
+    mix_strategy: Literal["concat", "interleave_under", "interleave_over"] = field(
+        default="concat",
+        metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."},
+    )
+    interleave_probs: Optional[str] = field(
+        default=None,
+        metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."},
+    )
+    overwrite_cache: bool = field(
+        default=False,
+        metadata={"help": "Overwrite the cached training and evaluation sets."},
+    )
+    preprocessing_num_workers: Optional[int] = field(
+        default=None,
+        metadata={"help": "The number of processes to use for the pre-processing."},
+    )
+    max_samples: Optional[int] = field(
+        default=None,
+        metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."},
+    )
+    eval_num_beams: Optional[int] = field(
+        default=None,
+        metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`."},
+    )
+    ignore_pad_token_for_loss: bool = field(
+        default=True,
+        metadata={"help": "Whether or not to ignore the tokens corresponding to the pad label in loss computation."},
+    )
+    val_size: float = field(
+        default=0.0,
+        metadata={"help": "Size of the development set; should be an integer or a float in range `[0,1)`."},
+    )
+    packing: Optional[bool] = field(
+        default=None,
+        metadata={"help": "Enable sequence packing in training. Automatically enabled in pre-training."},
+    )
+    neat_packing: bool = field(
+        default=False,
+        metadata={"help": "Enable sequence packing without cross-attention."},
+    )
+    tool_format: Optional[str] = field(
+        default=None,
+        metadata={"help": "Tool format to use for constructing function calling examples."},
+    )
+    tokenized_path: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to save or load the tokenized datasets."},
+    )
+
+    def __post_init__(self):
+        def split_arg(arg):
+            if isinstance(arg, str):
+                return [item.strip() for item in arg.split(",")]
+            return arg
+
+        self.dataset = split_arg(self.dataset)
+        self.eval_dataset = split_arg(self.eval_dataset)
+
+        if self.dataset is None and self.val_size > 1e-6:
+            raise ValueError("Cannot specify `val_size` if `dataset` is None.")
+
+        if self.eval_dataset is not None and self.val_size > 1e-6:
+            raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.")
+
+        if self.interleave_probs is not None:
+            if self.mix_strategy == "concat":
+                raise ValueError("`interleave_probs` is only valid for interleaved mixing.")
+
+            self.interleave_probs = list(map(float, split_arg(self.interleave_probs)))
+            if self.dataset is not None and len(self.dataset) != len(self.interleave_probs):
+                raise ValueError("The length of dataset and interleave probs should be identical.")
+
+            if self.eval_dataset is not None and len(self.eval_dataset) != len(self.interleave_probs):
+                raise ValueError("The length of eval dataset and interleave probs should be identical.")
+
+        if self.streaming and self.val_size > 1e-6 and self.val_size < 1:
+            raise ValueError("Streaming mode should have an integer val size.")
+
+        if self.streaming and self.max_samples is not None:
+            raise ValueError("`max_samples` is incompatible with `streaming`.")
llama-factory/src/llamafactory/hparams/evaluation_args.py ADDED
@@ -0,0 +1,62 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from dataclasses import dataclass, field
+from typing import Literal, Optional
+
+from datasets import DownloadMode
+
+
+@dataclass
+class EvaluationArguments:
+    r"""
+    Arguments pertaining to the evaluation parameters.
+    """
+
+    task: str = field(
+        metadata={"help": "Name of the evaluation task."},
+    )
+    task_dir: str = field(
+        default="evaluation",
+        metadata={"help": "Path to the folder containing the evaluation datasets."},
+    )
+    batch_size: int = field(
+        default=4,
+        metadata={"help": "The batch size per GPU for evaluation."},
+    )
+    seed: int = field(
+        default=42,
+        metadata={"help": "Random seed to be used with data loaders."},
+    )
+    lang: Literal["en", "zh"] = field(
+        default="en",
+        metadata={"help": "Language used for evaluation."},
+    )
+    n_shot: int = field(
+        default=5,
+        metadata={"help": "Number of exemplars for few-shot learning."},
+    )
+    save_dir: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to save the evaluation results."},
+    )
+    download_mode: DownloadMode = field(
+        default=DownloadMode.REUSE_DATASET_IF_EXISTS,
+        metadata={"help": "Download mode used for the evaluation datasets."},
+    )
+
+    def __post_init__(self):
+        if self.save_dir is not None and os.path.exists(self.save_dir):
+            raise ValueError("`save_dir` already exists, use another one.")
llama-factory/src/llamafactory/hparams/finetuning_args.py ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass, field
16
+ from typing import List, Literal, Optional
17
+
18
+
19
+ @dataclass
20
+ class FreezeArguments:
21
+ r"""
22
+ Arguments pertaining to the freeze (partial-parameter) training.
23
+ """
24
+
25
+ freeze_trainable_layers: int = field(
26
+ default=2,
27
+ metadata={
28
+ "help": (
29
+ "The number of trainable layers for freeze (partial-parameter) fine-tuning. "
30
+ "Positive numbers mean the last n layers are set as trainable, "
31
+ "negative numbers mean the first n layers are set as trainable."
32
+ )
33
+ },
34
+ )
35
+ freeze_trainable_modules: str = field(
36
+ default="all",
37
+ metadata={
38
+ "help": (
39
+ "Name(s) of trainable modules for freeze (partial-parameter) fine-tuning. "
40
+ "Use commas to separate multiple modules. "
41
+ "Use `all` to specify all the available modules."
42
+ )
43
+ },
44
+ )
45
+ freeze_extra_modules: Optional[str] = field(
46
+ default=None,
47
+ metadata={
48
+ "help": (
49
+ "Name(s) of modules apart from hidden layers to be set as trainable "
50
+ "for freeze (partial-parameter) fine-tuning. "
51
+ "Use commas to separate multiple modules."
52
+ )
53
+ },
54
+ )
55
+
56
+
+ @dataclass
+ class LoraArguments:
+     r"""
+     Arguments pertaining to the LoRA training.
+     """
+
+     additional_target: Optional[str] = field(
+         default=None,
+         metadata={
+             "help": (
+                 "Name(s) of modules apart from LoRA layers to be set as trainable "
+                 "and saved in the final checkpoint. "
+                 "Use commas to separate multiple modules."
+             )
+         },
+     )
+     lora_alpha: Optional[int] = field(
+         default=None,
+         metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."},
+     )
+     lora_dropout: float = field(
+         default=0.0,
+         metadata={"help": "Dropout rate for the LoRA fine-tuning."},
+     )
+     lora_rank: int = field(
+         default=8,
+         metadata={"help": "The intrinsic dimension for LoRA fine-tuning."},
+     )
+     lora_target: str = field(
+         default="all",
+         metadata={
+             "help": (
+                 "Name(s) of target modules to apply LoRA. "
+                 "Use commas to separate multiple modules. "
+                 "Use `all` to specify all the linear modules."
+             )
+         },
+     )
+     loraplus_lr_ratio: Optional[float] = field(
+         default=None,
+         metadata={"help": "LoRA plus learning rate ratio (lr_B / lr_A)."},
+     )
+     loraplus_lr_embedding: float = field(
+         default=1e-6,
+         metadata={"help": "LoRA plus learning rate for LoRA embedding layers."},
+     )
+     use_rslora: bool = field(
+         default=False,
+         metadata={"help": "Whether or not to use the rank stabilization scaling factor for the LoRA layers."},
+     )
+     use_dora: bool = field(
+         default=False,
+         metadata={"help": "Whether or not to use the weight-decomposed LoRA method (DoRA)."},
+     )
+     pissa_init: bool = field(
+         default=False,
+         metadata={"help": "Whether or not to initialize a PiSSA adapter."},
+     )
+     pissa_iter: int = field(
+         default=16,
+         metadata={"help": "The number of iteration steps performed by FSVD in PiSSA. Use -1 to disable it."},
+     )
+     pissa_convert: bool = field(
+         default=False,
+         metadata={"help": "Whether or not to convert the PiSSA adapter to a normal LoRA adapter."},
+     )
+     create_new_adapter: bool = field(
+         default=False,
+         metadata={"help": "Whether or not to create a new adapter with randomly initialized weights."},
+     )
+
+
+ @dataclass
+ class RLHFArguments:
+     r"""
+     Arguments pertaining to the PPO, DPO and KTO training.
+     """
+
+     pref_beta: float = field(
+         default=0.1,
+         metadata={"help": "The beta parameter in the preference loss."},
+     )
+     pref_ftx: float = field(
+         default=0.0,
+         metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."},
+     )
+     pref_loss: Literal["sigmoid", "hinge", "ipo", "kto_pair", "orpo", "simpo"] = field(
+         default="sigmoid",
+         metadata={"help": "The type of DPO loss to use."},
+     )
+     dpo_label_smoothing: float = field(
+         default=0.0,
+         metadata={"help": "The robust DPO label smoothing parameter in cDPO that should be between 0 and 0.5."},
+     )
+     kto_chosen_weight: float = field(
+         default=1.0,
+         metadata={"help": "The weight factor of the desirable losses in KTO training."},
+     )
+     kto_rejected_weight: float = field(
+         default=1.0,
+         metadata={"help": "The weight factor of the undesirable losses in KTO training."},
+     )
+     simpo_gamma: float = field(
+         default=0.5,
+         metadata={"help": "The target reward margin term in the SimPO loss."},
+     )
+     ppo_buffer_size: int = field(
+         default=1,
+         metadata={"help": "The number of mini-batches to make the experience buffer in a PPO optimization step."},
+     )
+     ppo_epochs: int = field(
+         default=4,
+         metadata={"help": "The number of epochs to perform in a PPO optimization step."},
+     )
+     ppo_score_norm: bool = field(
+         default=False,
+         metadata={"help": "Use score normalization in PPO training."},
+     )
+     ppo_target: float = field(
+         default=6.0,
+         metadata={"help": "Target KL value for adaptive KL control in PPO training."},
+     )
+     ppo_whiten_rewards: bool = field(
+         default=False,
+         metadata={"help": "Whiten the rewards before computing advantages in PPO training."},
+     )
+     ref_model: Optional[str] = field(
+         default=None,
+         metadata={"help": "Path to the reference model used for the PPO or DPO training."},
+     )
+     ref_model_adapters: Optional[str] = field(
+         default=None,
+         metadata={"help": "Path to the adapters of the reference model."},
+     )
+     ref_model_quantization_bit: Optional[int] = field(
+         default=None,
+         metadata={"help": "The number of bits to quantize the reference model."},
+     )
+     reward_model: Optional[str] = field(
+         default=None,
+         metadata={"help": "Path to the reward model used for the PPO training."},
+     )
+     reward_model_adapters: Optional[str] = field(
+         default=None,
+         metadata={"help": "Path to the adapters of the reward model."},
+     )
+     reward_model_quantization_bit: Optional[int] = field(
+         default=None,
+         metadata={"help": "The number of bits to quantize the reward model."},
+     )
+     reward_model_type: Literal["lora", "full", "api"] = field(
+         default="lora",
+         metadata={"help": "The type of the reward model in PPO training. A LoRA reward model only supports LoRA training."},
+     )
+
+
+ @dataclass
+ class GaloreArguments:
+     r"""
+     Arguments pertaining to the GaLore algorithm.
+     """
+
+     use_galore: bool = field(
+         default=False,
+         metadata={"help": "Whether or not to use Gradient Low-Rank Projection (GaLore)."},
+     )
+     galore_target: str = field(
+         default="all",
+         metadata={
+             "help": (
+                 "Name(s) of modules to apply GaLore. Use commas to separate multiple modules. "
+                 "Use `all` to specify all the linear modules."
+             )
+         },
+     )
+     galore_rank: int = field(
+         default=16,
+         metadata={"help": "The rank of GaLore gradients."},
+     )
+     galore_update_interval: int = field(
+         default=200,
+         metadata={"help": "Number of steps to update the GaLore projection."},
+     )
+     galore_scale: float = field(
+         default=0.25,
+         metadata={"help": "GaLore scaling coefficient."},
+     )
+     galore_proj_type: Literal["std", "reverse_std", "right", "left", "full"] = field(
+         default="std",
+         metadata={"help": "Type of GaLore projection."},
+     )
+     galore_layerwise: bool = field(
+         default=False,
+         metadata={"help": "Whether or not to enable layer-wise updates to further save memory."},
+     )
+
+
+ @dataclass
+ class BAdamArgument:
+     r"""
+     Arguments pertaining to the BAdam optimizer.
+     """
+
+     use_badam: bool = field(
+         default=False,
+         metadata={"help": "Whether or not to use the BAdam optimizer."},
+     )
+     badam_mode: Literal["layer", "ratio"] = field(
+         default="layer",
+         metadata={"help": "Whether to use the layer-wise or the ratio-wise BAdam optimizer."},
+     )
+     badam_start_block: Optional[int] = field(
+         default=None,
+         metadata={"help": "The starting block index for layer-wise BAdam."},
+     )
+     badam_switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field(
+         default="ascending",
+         metadata={"help": "The strategy for picking the block to update in layer-wise BAdam."},
+     )
+     badam_switch_interval: Optional[int] = field(
+         default=50,
+         metadata={
+             "help": "Number of steps to update the block for layer-wise BAdam. Use -1 to disable the block update."
+         },
+     )
+     badam_update_ratio: float = field(
+         default=0.05,
+         metadata={"help": "The ratio of the update for ratio-wise BAdam."},
+     )
+     badam_mask_mode: Literal["adjacent", "scatter"] = field(
+         default="adjacent",
+         metadata={
+             "help": (
+                 "The mode of the mask for the BAdam optimizer. "
+                 "`adjacent` means that the trainable parameters are adjacent to each other, "
+                 "`scatter` means that the trainable parameters are randomly chosen from the weight."
+             )
+         },
+     )
+     badam_verbose: int = field(
+         default=0,
+         metadata={
+             "help": (
+                 "The verbosity level of the BAdam optimizer. "
+                 "0 prints nothing, 1 prints the block prefix, 2 prints the trainable parameters."
+             )
+         },
+     )
+
+
+ @dataclass
+ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument):
+     r"""
+     Arguments pertaining to which fine-tuning techniques we are going to use.
+     """
+
+     pure_bf16: bool = field(
+         default=False,
+         metadata={"help": "Whether or not to train the model in purely bf16 precision (without AMP)."},
+     )
+     stage: Literal["pt", "sft", "rm", "ppo", "dpo", "kto"] = field(
+         default="sft",
+         metadata={"help": "Which stage will be performed in training."},
+     )
+     finetuning_type: Literal["lora", "freeze", "full"] = field(
+         default="lora",
+         metadata={"help": "Which fine-tuning method to use."},
+     )
+     use_llama_pro: bool = field(
+         default=False,
+         metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."},
+     )
+     freeze_vision_tower: bool = field(
+         default=True,
+         metadata={"help": "Whether or not to freeze the vision tower in MLLM training."},
+     )
+     train_mm_proj_only: bool = field(
+         default=False,
+         metadata={"help": "Whether or not to train only the multimodal projector for MLLM."},
+     )
+     compute_accuracy: bool = field(
+         default=False,
+         metadata={"help": "Whether or not to compute the token-level accuracy at evaluation."},
+     )
+     plot_loss: bool = field(
+         default=False,
+         metadata={"help": "Whether or not to save the training loss curves."},
+     )
+
+     def __post_init__(self):
+         def split_arg(arg):
+             if isinstance(arg, str):
+                 return [item.strip() for item in arg.split(",")]
+             return arg
+
+         self.freeze_trainable_modules: List[str] = split_arg(self.freeze_trainable_modules)
+         self.freeze_extra_modules: Optional[List[str]] = split_arg(self.freeze_extra_modules)
+         self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2
+         self.lora_target: List[str] = split_arg(self.lora_target)
+         self.additional_target: Optional[List[str]] = split_arg(self.additional_target)
+         self.galore_target: List[str] = split_arg(self.galore_target)
+         self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only
+         self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]
+
+         assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
+         assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
+         assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
+
+         if self.stage == "ppo" and self.reward_model is None:
+             raise ValueError("`reward_model` is necessary for PPO training.")
+
+         if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
+             raise ValueError("`reward_model_type` cannot be lora for Freeze/Full PPO training.")
+
+         if self.stage == "dpo" and self.pref_loss != "sigmoid" and self.dpo_label_smoothing > 1e-6:
+             raise ValueError("`dpo_label_smoothing` is only valid for the sigmoid loss function.")
+
+         if self.use_llama_pro and self.finetuning_type == "full":
+             raise ValueError("`use_llama_pro` is only valid for Freeze or LoRA training.")
+
+         if self.finetuning_type == "lora" and (self.use_galore or self.use_badam):
+             raise ValueError("Cannot use LoRA with GaLore or BAdam together.")
+
+         if self.use_galore and self.use_badam:
+             raise ValueError("Cannot use GaLore and BAdam together.")
+
+         if self.pissa_init and (self.stage in ["ppo", "kto"] or self.use_ref_model):
+             raise ValueError("Cannot use PiSSA for the current training stage.")
+
+         if self.train_mm_proj_only and self.finetuning_type != "full":
+             raise ValueError("`train_mm_proj_only` is only valid for full training.")
+
+         if self.finetuning_type != "lora":
+             if self.loraplus_lr_ratio is not None:
+                 raise ValueError("`loraplus_lr_ratio` is only valid for LoRA training.")
+
+             if self.use_rslora:
+                 raise ValueError("`use_rslora` is only valid for LoRA training.")
+
+             if self.use_dora:
+                 raise ValueError("`use_dora` is only valid for LoRA training.")
+
+             if self.pissa_init:
+                 raise ValueError("`pissa_init` is only valid for LoRA training.")
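For orientation, here is a minimal usage sketch of the dataclass above (hypothetical, not part of this commit; it assumes `FinetuningArguments` is re-exported from `llamafactory.hparams`, which this diff does not show):

from llamafactory.hparams import FinetuningArguments  # assumed export path

# __post_init__ normalizes the raw field values and validates the combination:
args = FinetuningArguments(stage="dpo", pref_loss="simpo")
print(args.lora_alpha)     # 16 -- defaults to lora_rank * 2 when unset
print(args.lora_target)    # ["all"] -- comma-separated strings are split into lists
print(args.use_ref_model)  # False -- SimPO is reference-free, so no ref model is needed
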
llama-factory/src/llamafactory/hparams/generating_args.py ADDED
@@ -0,0 +1,74 @@
+ # Copyright 2024 the LlamaFactory team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from dataclasses import asdict, dataclass, field
+ from typing import Any, Dict, Optional
+
+
+ @dataclass
+ class GeneratingArguments:
+     r"""
+     Arguments pertaining to the decoding parameters.
+     """
+
+     do_sample: bool = field(
+         default=True,
+         metadata={"help": "Whether or not to use sampling; use greedy decoding otherwise."},
+     )
+     temperature: float = field(
+         default=0.95,
+         metadata={"help": "The value used to modulate the next token probabilities."},
+     )
+     top_p: float = field(
+         default=0.7,
+         metadata={
+             "help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher is kept."
+         },
+     )
+     top_k: int = field(
+         default=50,
+         metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."},
+     )
+     num_beams: int = field(
+         default=1,
+         metadata={"help": "Number of beams for beam search. 1 means no beam search."},
+     )
+     max_length: int = field(
+         default=1024,
+         metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."},
+     )
+     max_new_tokens: int = field(
+         default=1024,
+         metadata={"help": "The maximum number of tokens to generate, ignoring the number of tokens in the prompt."},
+     )
+     repetition_penalty: float = field(
+         default=1.0,
+         metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."},
+     )
+     length_penalty: float = field(
+         default=1.0,
+         metadata={"help": "Exponential penalty to the length that is used with beam-based generation."},
+     )
+     default_system: Optional[str] = field(
+         default=None,
+         metadata={"help": "Default system message to use in chat completion."},
+     )
+
+     def to_dict(self) -> Dict[str, Any]:
+         args = asdict(self)
+         if args.get("max_new_tokens", -1) > 0:
+             args.pop("max_length", None)
+         else:
+             args.pop("max_new_tokens", None)
+         return args
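A small sketch of the `to_dict` precedence rule above (hypothetical usage, not part of this commit; the import path is assumed):

from llamafactory.hparams.generating_args import GeneratingArguments  # assumed module path

gen_args = GeneratingArguments(temperature=0.7, max_new_tokens=512)
decoding_kwargs = gen_args.to_dict()
assert "max_length" not in decoding_kwargs   # dropped because max_new_tokens > 0
assert decoding_kwargs["max_new_tokens"] == 512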