Spaces:
Build error
Build error
extra setup so llamafactory can run
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +1 -0
- llama-factory/README.md +645 -0
- llama-factory/pyproject.toml +33 -0
- llama-factory/requirements.txt +21 -0
- llama-factory/setup.py +92 -0
- llama-factory/src/api.py +33 -0
- llama-factory/src/llamafactory/__init__.py +41 -0
- llama-factory/src/llamafactory/api/__init__.py +0 -0
- llama-factory/src/llamafactory/api/app.py +122 -0
- llama-factory/src/llamafactory/api/chat.py +237 -0
- llama-factory/src/llamafactory/api/common.py +34 -0
- llama-factory/src/llamafactory/api/protocol.py +153 -0
- llama-factory/src/llamafactory/chat/__init__.py +19 -0
- llama-factory/src/llamafactory/chat/base_engine.py +78 -0
- llama-factory/src/llamafactory/chat/chat_model.py +155 -0
- llama-factory/src/llamafactory/chat/hf_engine.py +343 -0
- llama-factory/src/llamafactory/chat/vllm_engine.py +242 -0
- llama-factory/src/llamafactory/cli.py +121 -0
- llama-factory/src/llamafactory/data/__init__.py +31 -0
- llama-factory/src/llamafactory/data/aligner.py +239 -0
- llama-factory/src/llamafactory/data/collator.py +155 -0
- llama-factory/src/llamafactory/data/data_utils.py +87 -0
- llama-factory/src/llamafactory/data/formatter.py +140 -0
- llama-factory/src/llamafactory/data/loader.py +276 -0
- llama-factory/src/llamafactory/data/parser.py +153 -0
- llama-factory/src/llamafactory/data/preprocess.py +110 -0
- llama-factory/src/llamafactory/data/processors/__init__.py +0 -0
- llama-factory/src/llamafactory/data/processors/feedback.py +143 -0
- llama-factory/src/llamafactory/data/processors/pairwise.py +139 -0
- llama-factory/src/llamafactory/data/processors/pretrain.py +54 -0
- llama-factory/src/llamafactory/data/processors/processor_utils.py +95 -0
- llama-factory/src/llamafactory/data/processors/supervised.py +202 -0
- llama-factory/src/llamafactory/data/processors/unsupervised.py +106 -0
- llama-factory/src/llamafactory/data/template.py +905 -0
- llama-factory/src/llamafactory/data/tool_utils.py +140 -0
- llama-factory/src/llamafactory/eval/__init__.py +0 -0
- llama-factory/src/llamafactory/eval/evaluator.py +154 -0
- llama-factory/src/llamafactory/eval/template.py +81 -0
- llama-factory/src/llamafactory/extras/__init__.py +0 -0
- llama-factory/src/llamafactory/extras/constants.py +1590 -0
- llama-factory/src/llamafactory/extras/env.py +75 -0
- llama-factory/src/llamafactory/extras/logging.py +82 -0
- llama-factory/src/llamafactory/extras/misc.py +228 -0
- llama-factory/src/llamafactory/extras/packages.py +88 -0
- llama-factory/src/llamafactory/extras/ploting.py +101 -0
- llama-factory/src/llamafactory/hparams/__init__.py +32 -0
- llama-factory/src/llamafactory/hparams/data_args.py +143 -0
- llama-factory/src/llamafactory/hparams/evaluation_args.py +62 -0
- llama-factory/src/llamafactory/hparams/finetuning_args.py +400 -0
- llama-factory/src/llamafactory/hparams/generating_args.py +74 -0
.gitignore
CHANGED
@@ -150,3 +150,4 @@ dmypy.json
|
|
150 |
/huggingface_tokenizers_cache
|
151 |
/llama-factory/huggingface_tokenizers_cache
|
152 |
**/Icon?
|
|
|
|
150 |
/huggingface_tokenizers_cache
|
151 |
/llama-factory/huggingface_tokenizers_cache
|
152 |
**/Icon?
|
153 |
+
llama-factory/data/mgtv_train.json
|
llama-factory/README.md
ADDED
@@ -0,0 +1,645 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
![# LLaMA Factory](assets/logo.png)
|
2 |
+
|
3 |
+
[![GitHub Repo stars](https://img.shields.io/github/stars/hiyouga/LLaMA-Factory?style=social)](https://github.com/hiyouga/LLaMA-Factory/stargazers)
|
4 |
+
[![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE)
|
5 |
+
[![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
|
6 |
+
[![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/)
|
7 |
+
[![Citation](https://img.shields.io/badge/citation-72-green)](#projects-using-llama-factory)
|
8 |
+
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
|
9 |
+
[![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
|
10 |
+
[![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
|
11 |
+
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)
|
12 |
+
[![Open in DSW](https://gallery.pai-ml.com/assets/open-in-dsw.svg)](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
|
13 |
+
[![Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
|
14 |
+
[![Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
|
15 |
+
|
16 |
+
[![GitHub Tread](https://trendshift.io/api/badge/repositories/4535)](https://trendshift.io/repositories/4535)
|
17 |
+
|
18 |
+
👋 Join our [WeChat](assets/wechat.jpg) or [NPU user group](assets/wechat_npu.jpg).
|
19 |
+
|
20 |
+
\[ English | [中文](README_zh.md) \]
|
21 |
+
|
22 |
+
**Fine-tuning a large language model can be easy as...**
|
23 |
+
|
24 |
+
https://github.com/hiyouga/LLaMA-Factory/assets/16256802/9840a653-7e9c-41c8-ae89-7ace5698baf6
|
25 |
+
|
26 |
+
Choose your path:
|
27 |
+
|
28 |
+
- **Colab**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
|
29 |
+
- **PAI-DSW**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
|
30 |
+
- **Local machine**: Please refer to [usage](#getting-started)
|
31 |
+
|
32 |
+
## Table of Contents
|
33 |
+
|
34 |
+
- [Features](#features)
|
35 |
+
- [Benchmark](#benchmark)
|
36 |
+
- [Changelog](#changelog)
|
37 |
+
- [Supported Models](#supported-models)
|
38 |
+
- [Supported Training Approaches](#supported-training-approaches)
|
39 |
+
- [Provided Datasets](#provided-datasets)
|
40 |
+
- [Requirement](#requirement)
|
41 |
+
- [Getting Started](#getting-started)
|
42 |
+
- [Projects using LLaMA Factory](#projects-using-llama-factory)
|
43 |
+
- [License](#license)
|
44 |
+
- [Citation](#citation)
|
45 |
+
- [Acknowledgement](#acknowledgement)
|
46 |
+
|
47 |
+
## Features
|
48 |
+
|
49 |
+
- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
|
50 |
+
- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
|
51 |
+
- **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ.
|
52 |
+
- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning.
|
53 |
+
- **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
|
54 |
+
- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
|
55 |
+
- **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.
|
56 |
+
|
57 |
+
## Benchmark
|
58 |
+
|
59 |
+
Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ptuning), LLaMA Factory's LoRA tuning offers up to **3.7 times faster** training speed with a better Rouge score on the advertising text generation task. By leveraging 4-bit quantization technique, LLaMA Factory's QLoRA further improves the efficiency regarding the GPU memory.
|
60 |
+
|
61 |
+
![benchmark](assets/benchmark.svg)
|
62 |
+
|
63 |
+
<details><summary>Definitions</summary>
|
64 |
+
|
65 |
+
- **Training Speed**: the number of training samples processed per second during the training. (bs=4, cutoff_len=1024)
|
66 |
+
- **Rouge Score**: Rouge-2 score on the development set of the [advertising text generation](https://aclanthology.org/D19-1321.pdf) task. (bs=4, cutoff_len=1024)
|
67 |
+
- **GPU Memory**: Peak GPU memory usage in 4-bit quantized training. (bs=1, cutoff_len=1024)
|
68 |
+
- We adopt `pre_seq_len=128` for ChatGLM's P-Tuning and `lora_rank=32` for LLaMA Factory's LoRA tuning.
|
69 |
+
|
70 |
+
</details>
|
71 |
+
|
72 |
+
## Changelog
|
73 |
+
|
74 |
+
[24/06/16] We support **[PiSSA](https://arxiv.org/abs/2404.02948)** algorithm. See [examples](examples/README.md) for usage.
|
75 |
+
|
76 |
+
[24/06/07] We supported fine-tuning the **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** and **[GLM-4](https://github.com/THUDM/GLM-4)** models.
|
77 |
+
|
78 |
+
[24/05/26] We supported **[SimPO](https://arxiv.org/abs/2405.14734)** algorithm for preference learning. See [examples](examples/README.md) for usage.
|
79 |
+
|
80 |
+
<details><summary>Full Changelog</summary>
|
81 |
+
|
82 |
+
[24/05/20] We supported fine-tuning the **PaliGemma** series models. Note that the PaliGemma models are pre-trained models, you need to fine-tune them with `gemma` template for chat completion.
|
83 |
+
|
84 |
+
[24/05/18] We supported **[KTO](https://arxiv.org/abs/2402.01306)** algorithm for preference learning. See [examples](examples/README.md) for usage.
|
85 |
+
|
86 |
+
[24/05/14] We supported training and inference on the Ascend NPU devices. Check [installation](#installation) section for details.
|
87 |
+
|
88 |
+
[24/04/26] We supported fine-tuning the **LLaVA-1.5** multimodal LLMs. See [examples](examples/README.md) for usage.
|
89 |
+
|
90 |
+
[24/04/22] We provided a **[Colab notebook](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)** for fine-tuning the Llama-3 model on a free T4 GPU. Two Llama-3-derived models fine-tuned using LLaMA Factory are available at Hugging Face, check [Llama3-8B-Chinese-Chat](https://huggingface.co/shenzhi-wang/Llama3-8B-Chinese-Chat) and [Llama3-Chinese](https://huggingface.co/zhichen/Llama3-Chinese) for details.
|
91 |
+
|
92 |
+
[24/04/21] We supported **[Mixture-of-Depths](https://arxiv.org/abs/2404.02258)** according to [AstraMindAI's implementation](https://github.com/astramind-ai/Mixture-of-depths). See [examples](examples/README.md) for usage.
|
93 |
+
|
94 |
+
[24/04/16] We supported **[BAdam](https://arxiv.org/abs/2404.02827)**. See [examples](examples/README.md) for usage.
|
95 |
+
|
96 |
+
[24/04/16] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s long-sequence training (Llama-2-7B-56k within 24GB). It achieves **117%** speed and **50%** memory compared with FlashAttention-2, more benchmarks can be found in [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison).
|
97 |
+
|
98 |
+
[24/03/31] We supported **[ORPO](https://arxiv.org/abs/2403.07691)**. See [examples](examples/README.md) for usage.
|
99 |
+
|
100 |
+
[24/03/21] Our paper "[LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models](https://arxiv.org/abs/2403.13372)" is available at arXiv!
|
101 |
+
|
102 |
+
[24/03/20] We supported **FSDP+QLoRA** that fine-tunes a 70B model on 2x24GB GPUs. See [examples](examples/README.md) for usage.
|
103 |
+
|
104 |
+
[24/03/13] We supported **[LoRA+](https://arxiv.org/abs/2402.12354)**. See [examples](examples/README.md) for usage.
|
105 |
+
|
106 |
+
[24/03/07] We supported gradient low-rank projection (**[GaLore](https://arxiv.org/abs/2403.03507)**) algorithm. See [examples](examples/README.md) for usage.
|
107 |
+
|
108 |
+
[24/03/07] We integrated **[vLLM](https://github.com/vllm-project/vllm)** for faster and concurrent inference. Try `infer_backend: vllm` to enjoy **270%** inference speed.
|
109 |
+
|
110 |
+
[24/02/28] We supported weight-decomposed LoRA (**[DoRA](https://arxiv.org/abs/2402.09353)**). Try `use_dora: true` to activate DoRA training.
|
111 |
+
|
112 |
+
[24/02/15] We supported **block expansion** proposed by [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro). See [examples](examples/README.md) for usage.
|
113 |
+
|
114 |
+
[24/02/05] Qwen1.5 (Qwen2 beta version) series models are supported in LLaMA-Factory. Check this [blog post](https://qwenlm.github.io/blog/qwen1.5/) for details.
|
115 |
+
|
116 |
+
[24/01/18] We supported **agent tuning** for most models, equipping model with tool using abilities by fine-tuning with `dataset: glaive_toolcall_en`.
|
117 |
+
|
118 |
+
[23/12/23] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s implementation to boost LoRA tuning for the LLaMA, Mistral and Yi models. Try `use_unsloth: true` argument to activate unsloth patch. It achieves **170%** speed in our benchmark, check [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison) for details.
|
119 |
+
|
120 |
+
[23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework. See hardware requirement [here](#hardware-requirement).
|
121 |
+
|
122 |
+
[23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)** for Chinese mainland users. See [this tutorial](#download-from-modelscope-hub) for usage.
|
123 |
+
|
124 |
+
[23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `neftune_noise_alpha: 5` argument to activate NEFTune.
|
125 |
+
|
126 |
+
[23/09/27] We supported **$S^2$-Attn** proposed by [LongLoRA](https://github.com/dvlab-research/LongLoRA) for the LLaMA models. Try `shift_attn: true` argument to enable shift short attention.
|
127 |
+
|
128 |
+
[23/09/23] We integrated MMLU, C-Eval and CMMLU benchmarks in this repo. See [examples](examples/README.md) for usage.
|
129 |
+
|
130 |
+
[23/09/10] We supported **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**. Try `flash_attn: fa2` argument to enable FlashAttention-2 if you are using RTX4090, A100 or H100 GPUs.
|
131 |
+
|
132 |
+
[23/08/12] We supported **RoPE scaling** to extend the context length of the LLaMA models. Try `rope_scaling: linear` argument in training and `rope_scaling: dynamic` argument at inference to extrapolate the position embeddings.
|
133 |
+
|
134 |
+
[23/08/11] We supported **[DPO training](https://arxiv.org/abs/2305.18290)** for instruction-tuned models. See [examples](examples/README.md) for usage.
|
135 |
+
|
136 |
+
[23/07/31] We supported **dataset streaming**. Try `streaming: true` and `max_steps: 10000` arguments to load your dataset in streaming mode.
|
137 |
+
|
138 |
+
[23/07/29] We released two instruction-tuned 13B models at Hugging Face. See these Hugging Face Repos ([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/Baichuan-13B-sft)) for details.
|
139 |
+
|
140 |
+
[23/07/18] We developed an **all-in-one Web UI** for training, evaluation and inference. Try `train_web.py` to fine-tune models in your Web browser. Thank [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development.
|
141 |
+
|
142 |
+
[23/07/09] We released **[FastEdit](https://github.com/hiyouga/FastEdit)** ⚡🩹, an easy-to-use package for editing the factual knowledge of large language models efficiently. Please follow [FastEdit](https://github.com/hiyouga/FastEdit) if you are interested.
|
143 |
+
|
144 |
+
[23/06/29] We provided a **reproducible example** of training a chat model using instruction-following datasets, see [Baichuan-7B-sft](https://huggingface.co/hiyouga/Baichuan-7B-sft) for details.
|
145 |
+
|
146 |
+
[23/06/22] We aligned the [demo API](src/api_demo.py) with the [OpenAI's](https://platform.openai.com/docs/api-reference/chat) format where you can insert the fine-tuned model in **arbitrary ChatGPT-based applications**.
|
147 |
+
|
148 |
+
[23/06/03] We supported quantized training and inference (aka **[QLoRA](https://github.com/artidoro/qlora)**). See [examples](examples/README.md) for usage.
|
149 |
+
|
150 |
+
</details>
|
151 |
+
|
152 |
+
## Supported Models
|
153 |
+
|
154 |
+
| Model | Model size | Template |
|
155 |
+
| ------------------------------------------------------------ | -------------------------------- | --------- |
|
156 |
+
| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
|
157 |
+
| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
|
158 |
+
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
|
159 |
+
| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
|
160 |
+
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
|
161 |
+
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
|
162 |
+
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
|
163 |
+
| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
|
164 |
+
| [InternLM2](https://huggingface.co/internlm) | 7B/20B | intern2 |
|
165 |
+
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
|
166 |
+
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
|
167 |
+
| [Llama 3](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
|
168 |
+
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna |
|
169 |
+
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
|
170 |
+
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
|
171 |
+
| [PaliGemma](https://huggingface.co/google) | 3B | gemma |
|
172 |
+
| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
|
173 |
+
| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
|
174 |
+
| [Qwen/Qwen1.5/Qwen2 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/4B/7B/14B/32B/72B/110B | qwen |
|
175 |
+
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
|
176 |
+
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
|
177 |
+
| [Yi/Yi-1.5](https://huggingface.co/01-ai) | 6B/9B/34B | yi |
|
178 |
+
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
|
179 |
+
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
|
180 |
+
|
181 |
+
> [!NOTE]
|
182 |
+
> For the "base" models, the `template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
|
183 |
+
>
|
184 |
+
> Remember to use the **SAME** template in training and inference.
|
185 |
+
|
186 |
+
Please refer to [constants.py](src/llamafactory/extras/constants.py) for a full list of models we supported.
|
187 |
+
|
188 |
+
You also can add a custom chat template to [template.py](src/llamafactory/data/template.py).
|
189 |
+
|
190 |
+
## Supported Training Approaches
|
191 |
+
|
192 |
+
| Approach | Full-tuning | Freeze-tuning | LoRA | QLoRA |
|
193 |
+
| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
|
194 |
+
| Pre-Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
195 |
+
| Supervised Fine-Tuning | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
196 |
+
| Reward Modeling | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
197 |
+
| PPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
198 |
+
| DPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
199 |
+
| KTO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
200 |
+
| ORPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
201 |
+
| SimPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
202 |
+
|
203 |
+
## Provided Datasets
|
204 |
+
|
205 |
+
<details><summary>Pre-training datasets</summary>
|
206 |
+
|
207 |
+
- [Wiki Demo (en)](data/wiki_demo.txt)
|
208 |
+
- [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
|
209 |
+
- [RedPajama V2 (en)](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-V2)
|
210 |
+
- [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220)
|
211 |
+
- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
|
212 |
+
- [Pile (en)](https://huggingface.co/datasets/EleutherAI/pile)
|
213 |
+
- [SkyPile (zh)](https://huggingface.co/datasets/Skywork/SkyPile-150B)
|
214 |
+
- [FineWeb (en)](https://huggingface.co/datasets/HuggingFaceFW/fineweb)
|
215 |
+
- [FineWeb-Edu (en)](https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu)
|
216 |
+
- [The Stack (en)](https://huggingface.co/datasets/bigcode/the-stack)
|
217 |
+
- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
|
218 |
+
|
219 |
+
</details>
|
220 |
+
|
221 |
+
<details><summary>Supervised fine-tuning datasets</summary>
|
222 |
+
|
223 |
+
- [Identity (en&zh)](data/identity.json)
|
224 |
+
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
|
225 |
+
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca-3)
|
226 |
+
- [Alpaca GPT4 (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
227 |
+
- [Glaive Function Calling V2 (en&zh)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
|
228 |
+
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
|
229 |
+
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
|
230 |
+
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
|
231 |
+
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
|
232 |
+
- [BELLE 0.5M (zh)](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
|
233 |
+
- [BELLE Dialogue 0.4M (zh)](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
|
234 |
+
- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
|
235 |
+
- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
|
236 |
+
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
|
237 |
+
- [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus)
|
238 |
+
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
|
239 |
+
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
|
240 |
+
- [OpenOrca (en)](https://huggingface.co/datasets/Open-Orca/OpenOrca)
|
241 |
+
- [SlimOrca (en)](https://huggingface.co/datasets/Open-Orca/SlimOrca)
|
242 |
+
- [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
|
243 |
+
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
|
244 |
+
- [Wiki QA (en)](https://huggingface.co/datasets/wiki_qa)
|
245 |
+
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
|
246 |
+
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
|
247 |
+
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
248 |
+
- [deepctrl (en&zh)](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data)
|
249 |
+
- [Advertise Generating (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
|
250 |
+
- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
|
251 |
+
- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
|
252 |
+
- [UltraChat 200k (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)
|
253 |
+
- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
|
254 |
+
- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
|
255 |
+
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
|
256 |
+
- [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
|
257 |
+
- [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction)
|
258 |
+
- [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo)
|
259 |
+
- [Neo-sft (zh)](https://huggingface.co/datasets/m-a-p/neo_sft_phase2)
|
260 |
+
- [WebInstructSub (en)](https://huggingface.co/datasets/TIGER-Lab/WebInstructSub)
|
261 |
+
- [Magpie-Pro-300K-Filtered (en)](https://huggingface.co/datasets/Magpie-Align/Magpie-Pro-300K-Filtered)
|
262 |
+
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
|
263 |
+
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
|
264 |
+
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
|
265 |
+
- [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)
|
266 |
+
- [OpenSchnabeltier (de)](https://huggingface.co/datasets/mayflowergmbh/openschnabeltier_de)
|
267 |
+
- [Evol Instruct (de)](https://huggingface.co/datasets/mayflowergmbh/evol-instruct_de)
|
268 |
+
- [Dolphin (de)](https://huggingface.co/datasets/mayflowergmbh/dolphin_de)
|
269 |
+
- [Booksum (de)](https://huggingface.co/datasets/mayflowergmbh/booksum_de)
|
270 |
+
- [Airoboros (de)](https://huggingface.co/datasets/mayflowergmbh/airoboros-3.0_de)
|
271 |
+
- [Ultrachat (de)](https://huggingface.co/datasets/mayflowergmbh/ultra-chat_de)
|
272 |
+
|
273 |
+
</details>
|
274 |
+
|
275 |
+
<details><summary>Preference datasets</summary>
|
276 |
+
|
277 |
+
- [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
|
278 |
+
- [UltraFeedback (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)
|
279 |
+
- [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
|
280 |
+
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
281 |
+
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
282 |
+
- [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)
|
283 |
+
- [KTO mixed (en)](https://huggingface.co/datasets/argilla/kto-mix-15k)
|
284 |
+
|
285 |
+
</details>
|
286 |
+
|
287 |
+
Some datasets require confirmation before using them, so we recommend logging in with your Hugging Face account using these commands.
|
288 |
+
|
289 |
+
```bash
|
290 |
+
pip install --upgrade huggingface_hub
|
291 |
+
huggingface-cli login
|
292 |
+
```
|
293 |
+
|
294 |
+
## Requirement
|
295 |
+
|
296 |
+
| Mandatory | Minimum | Recommend |
|
297 |
+
| ------------ | ------- | --------- |
|
298 |
+
| python | 3.8 | 3.11 |
|
299 |
+
| torch | 1.13.1 | 2.3.0 |
|
300 |
+
| transformers | 4.41.2 | 4.41.2 |
|
301 |
+
| datasets | 2.16.0 | 2.19.2 |
|
302 |
+
| accelerate | 0.30.1 | 0.30.1 |
|
303 |
+
| peft | 0.11.1 | 0.11.1 |
|
304 |
+
| trl | 0.8.6 | 0.9.4 |
|
305 |
+
|
306 |
+
| Optional | Minimum | Recommend |
|
307 |
+
| ------------ | ------- | --------- |
|
308 |
+
| CUDA | 11.6 | 12.2 |
|
309 |
+
| deepspeed | 0.10.0 | 0.14.0 |
|
310 |
+
| bitsandbytes | 0.39.0 | 0.43.1 |
|
311 |
+
| vllm | 0.4.3 | 0.4.3 |
|
312 |
+
| flash-attn | 2.3.0 | 2.5.9 |
|
313 |
+
|
314 |
+
### Hardware Requirement
|
315 |
+
|
316 |
+
\* *estimated*
|
317 |
+
|
318 |
+
| Method | Bits | 7B | 13B | 30B | 70B | 110B | 8x7B | 8x22B |
|
319 |
+
| ----------------- | ---- | ----- | ----- | ----- | ------ | ------ | ----- | ------ |
|
320 |
+
| Full | AMP | 120GB | 240GB | 600GB | 1200GB | 2000GB | 900GB | 2400GB |
|
321 |
+
| Full | 16 | 60GB | 120GB | 300GB | 600GB | 900GB | 400GB | 1200GB |
|
322 |
+
| Freeze | 16 | 20GB | 40GB | 80GB | 200GB | 360GB | 160GB | 400GB |
|
323 |
+
| LoRA/GaLore/BAdam | 16 | 16GB | 32GB | 64GB | 160GB | 240GB | 120GB | 320GB |
|
324 |
+
| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 140GB | 60GB | 160GB |
|
325 |
+
| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 72GB | 30GB | 96GB |
|
326 |
+
| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 48GB | 18GB | 48GB |
|
327 |
+
|
328 |
+
## Getting Started
|
329 |
+
|
330 |
+
### Installation
|
331 |
+
|
332 |
+
> [!IMPORTANT]
|
333 |
+
> Installation is mandatory.
|
334 |
+
|
335 |
+
```bash
|
336 |
+
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
|
337 |
+
cd LLaMA-Factory
|
338 |
+
pip install -e ".[torch,metrics]"
|
339 |
+
```
|
340 |
+
|
341 |
+
Extra dependencies available: torch, torch-npu, metrics, deepspeed, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, qwen, modelscope, quality
|
342 |
+
|
343 |
+
> [!TIP]
|
344 |
+
> Use `pip install --no-deps -e .` to resolve package conflicts.
|
345 |
+
|
346 |
+
<details><summary>For Windows users</summary>
|
347 |
+
|
348 |
+
If you want to enable the quantized LoRA (QLoRA) on the Windows platform, you need to install a pre-built version of `bitsandbytes` library, which supports CUDA 11.1 to 12.2, please select the appropriate [release version](https://github.com/jllllll/bitsandbytes-windows-webui/releases/tag/wheels) based on your CUDA version.
|
349 |
+
|
350 |
+
```bash
|
351 |
+
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.2.post2-py3-none-win_amd64.whl
|
352 |
+
```
|
353 |
+
|
354 |
+
To enable FlashAttention-2 on the Windows platform, you need to install the precompiled `flash-attn` library, which supports CUDA 12.1 to 12.2. Please download the corresponding version from [flash-attention](https://github.com/bdashore3/flash-attention/releases) based on your requirements.
|
355 |
+
|
356 |
+
</details>
|
357 |
+
|
358 |
+
<details><summary>For Ascend NPU users</summary>
|
359 |
+
|
360 |
+
To install LLaMA Factory on Ascend NPU devices, please specify extra dependencies: `pip install -e ".[torch-npu,metrics]"`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands:
|
361 |
+
|
362 |
+
```bash
|
363 |
+
# replace the url according to your CANN version and devices
|
364 |
+
# install CANN Toolkit
|
365 |
+
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run
|
366 |
+
bash Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run --install
|
367 |
+
|
368 |
+
# install CANN Kernels
|
369 |
+
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run
|
370 |
+
bash Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run --install
|
371 |
+
|
372 |
+
# set env variables
|
373 |
+
source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
374 |
+
```
|
375 |
+
|
376 |
+
| Requirement | Minimum | Recommend |
|
377 |
+
| ------------ | ------- | ----------- |
|
378 |
+
| CANN | 8.0.RC1 | 8.0.RC1 |
|
379 |
+
| torch | 2.1.0 | 2.1.0 |
|
380 |
+
| torch-npu | 2.1.0 | 2.1.0.post3 |
|
381 |
+
| deepspeed | 0.13.2 | 0.13.2 |
|
382 |
+
|
383 |
+
Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use.
|
384 |
+
|
385 |
+
If you cannot infer model on NPU devices, try setting `do_sample: false` in the configurations.
|
386 |
+
|
387 |
+
Download the pre-built Docker images: [32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
|
388 |
+
|
389 |
+
</details>
|
390 |
+
|
391 |
+
### Data Preparation
|
392 |
+
|
393 |
+
Please refer to [data/README.md](data/README.md) for checking the details about the format of dataset files. You can either use datasets on HuggingFace / ModelScope hub or load the dataset in local disk.
|
394 |
+
|
395 |
+
> [!NOTE]
|
396 |
+
> Please update `data/dataset_info.json` to use your custom dataset.
|
397 |
+
|
398 |
+
### Quickstart
|
399 |
+
|
400 |
+
Use the following 3 commands to run LoRA **fine-tuning**, **inference** and **merging** of the Llama3-8B-Instruct model, respectively.
|
401 |
+
|
402 |
+
```bash
|
403 |
+
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
404 |
+
llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
|
405 |
+
llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
|
406 |
+
```
|
407 |
+
|
408 |
+
See [examples/README.md](examples/README.md) for advanced usage (including distributed training).
|
409 |
+
|
410 |
+
> [!TIP]
|
411 |
+
> Use `llamafactory-cli help` to show help information.
|
412 |
+
|
413 |
+
### Fine-Tuning with LLaMA Board GUI (powered by [Gradio](https://github.com/gradio-app/gradio))
|
414 |
+
|
415 |
+
```bash
|
416 |
+
llamafactory-cli webui
|
417 |
+
```
|
418 |
+
|
419 |
+
### Build Docker
|
420 |
+
|
421 |
+
For CUDA users:
|
422 |
+
|
423 |
+
```bash
|
424 |
+
cd docker/docker-cuda/
|
425 |
+
docker-compose up -d
|
426 |
+
docker-compose exec llamafactory bash
|
427 |
+
```
|
428 |
+
|
429 |
+
For Ascend NPU users:
|
430 |
+
|
431 |
+
```bash
|
432 |
+
cd docker/docker-npu/
|
433 |
+
docker-compose up -d
|
434 |
+
docker-compose exec llamafactory bash
|
435 |
+
```
|
436 |
+
|
437 |
+
<details><summary>Build without Docker Compose</summary>
|
438 |
+
|
439 |
+
For CUDA users:
|
440 |
+
|
441 |
+
```bash
|
442 |
+
docker build -f ./docker/docker-cuda/Dockerfile \
|
443 |
+
--build-arg INSTALL_BNB=false \
|
444 |
+
--build-arg INSTALL_VLLM=false \
|
445 |
+
--build-arg INSTALL_DEEPSPEED=false \
|
446 |
+
--build-arg INSTALL_FLASHATTN=false \
|
447 |
+
--build-arg PIP_INDEX=https://pypi.org/simple \
|
448 |
+
-t llamafactory:latest .
|
449 |
+
|
450 |
+
docker run -dit --gpus=all \
|
451 |
+
-v ./hf_cache:/root/.cache/huggingface \
|
452 |
+
-v ./ms_cache:/root/.cache/modelscope \
|
453 |
+
-v ./data:/app/data \
|
454 |
+
-v ./output:/app/output \
|
455 |
+
-p 7860:7860 \
|
456 |
+
-p 8000:8000 \
|
457 |
+
--shm-size 16G \
|
458 |
+
--name llamafactory \
|
459 |
+
llamafactory:latest
|
460 |
+
|
461 |
+
docker exec -it llamafactory bash
|
462 |
+
```
|
463 |
+
|
464 |
+
For Ascend NPU users:
|
465 |
+
|
466 |
+
```bash
|
467 |
+
# Choose docker image upon your environment
|
468 |
+
docker build -f ./docker/docker-npu/Dockerfile \
|
469 |
+
--build-arg INSTALL_DEEPSPEED=false \
|
470 |
+
--build-arg PIP_INDEX=https://pypi.org/simple \
|
471 |
+
-t llamafactory:latest .
|
472 |
+
|
473 |
+
# Change `device` upon your resources
|
474 |
+
docker run -dit \
|
475 |
+
-v ./hf_cache:/root/.cache/huggingface \
|
476 |
+
-v ./ms_cache:/root/.cache/modelscope \
|
477 |
+
-v ./data:/app/data \
|
478 |
+
-v ./output:/app/output \
|
479 |
+
-v /usr/local/dcmi:/usr/local/dcmi \
|
480 |
+
-v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
|
481 |
+
-v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
|
482 |
+
-v /etc/ascend_install.info:/etc/ascend_install.info \
|
483 |
+
-p 7860:7860 \
|
484 |
+
-p 8000:8000 \
|
485 |
+
--device /dev/davinci0 \
|
486 |
+
--device /dev/davinci_manager \
|
487 |
+
--device /dev/devmm_svm \
|
488 |
+
--device /dev/hisi_hdc \
|
489 |
+
--shm-size 16G \
|
490 |
+
--name llamafactory \
|
491 |
+
llamafactory:latest
|
492 |
+
|
493 |
+
docker exec -it llamafactory bash
|
494 |
+
```
|
495 |
+
|
496 |
+
</details>
|
497 |
+
|
498 |
+
<details><summary>Details about volume</summary>
|
499 |
+
|
500 |
+
- hf_cache: Utilize Hugging Face cache on the host machine. Reassignable if a cache already exists in a different directory.
|
501 |
+
- data: Place datasets on this dir of the host machine so that they can be selected on LLaMA Board GUI.
|
502 |
+
- output: Set export dir to this location so that the merged result can be accessed directly on the host machine.
|
503 |
+
|
504 |
+
</details>
|
505 |
+
|
506 |
+
### Deploy with OpenAI-style API and vLLM
|
507 |
+
|
508 |
+
```bash
|
509 |
+
API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
|
510 |
+
```
|
511 |
+
|
512 |
+
> [!TIP]
|
513 |
+
> Visit https://platform.openai.com/docs/api-reference/chat/create for API document.
|
514 |
+
|
515 |
+
### Download from ModelScope Hub
|
516 |
+
|
517 |
+
If you have trouble with downloading models and datasets from Hugging Face, you can use ModelScope.
|
518 |
+
|
519 |
+
```bash
|
520 |
+
export USE_MODELSCOPE_HUB=1 # `set USE_MODELSCOPE_HUB=1` for Windows
|
521 |
+
```
|
522 |
+
|
523 |
+
Train the model by specifying a model ID of the ModelScope Hub as the `model_name_or_path`. You can find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models), e.g., `LLM-Research/Meta-Llama-3-8B-Instruct`.
|
524 |
+
|
525 |
+
### Use W&B Logger
|
526 |
+
|
527 |
+
To use [Weights & Biases](https://wandb.ai) for logging experimental results, you need to add the following arguments to yaml files.
|
528 |
+
|
529 |
+
```yaml
|
530 |
+
report_to: wandb
|
531 |
+
run_name: test_run # optional
|
532 |
+
```
|
533 |
+
|
534 |
+
Set `WANDB_API_KEY` to [your key](https://wandb.ai/authorize) when launching training tasks to log in with your W&B account.
|
535 |
+
|
536 |
+
## Projects using LLaMA Factory
|
537 |
+
|
538 |
+
If you have a project that should be incorporated, please contact via email or create a pull request.
|
539 |
+
|
540 |
+
<details><summary>Click to show</summary>
|
541 |
+
|
542 |
+
1. Wang et al. ESRL: Efficient Sampling-based Reinforcement Learning for Sequence Generation. 2023. [[arxiv]](https://arxiv.org/abs/2308.02223)
|
543 |
+
1. Yu et al. Open, Closed, or Small Language Models for Text Classification? 2023. [[arxiv]](https://arxiv.org/abs/2308.10092)
|
544 |
+
1. Wang et al. UbiPhysio: Support Daily Functioning, Fitness, and Rehabilitation with Action Understanding and Feedback in Natural Language. 2023. [[arxiv]](https://arxiv.org/abs/2308.10526)
|
545 |
+
1. Luceri et al. Leveraging Large Language Models to Detect Influence Campaigns in Social Media. 2023. [[arxiv]](https://arxiv.org/abs/2311.07816)
|
546 |
+
1. Zhang et al. Alleviating Hallucinations of Large Language Models through Induced Hallucinations. 2023. [[arxiv]](https://arxiv.org/abs/2312.15710)
|
547 |
+
1. Wang et al. Know Your Needs Better: Towards Structured Understanding of Marketer Demands with Analogical Reasoning Augmented LLMs. KDD 2024. [[arxiv]](https://arxiv.org/abs/2401.04319)
|
548 |
+
1. Wang et al. CANDLE: Iterative Conceptualization and Instantiation Distillation from Large Language Models for Commonsense Reasoning. ACL 2024. [[arxiv]](https://arxiv.org/abs/2401.07286)
|
549 |
+
1. Choi et al. FACT-GPT: Fact-Checking Augmentation via Claim Matching with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2402.05904)
|
550 |
+
1. Zhang et al. AutoMathText: Autonomous Data Selection with Language Models for Mathematical Texts. 2024. [[arxiv]](https://arxiv.org/abs/2402.07625)
|
551 |
+
1. Lyu et al. KnowTuning: Knowledge-aware Fine-tuning for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11176)
|
552 |
+
1. Yang et al. LaCo: Large Language Model Pruning via Layer Collaps. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187)
|
553 |
+
1. Bhardwaj et al. Language Models are Homer Simpson! Safety Re-Alignment of Fine-tuned Language Models through Task Arithmetic. 2024. [[arxiv]](https://arxiv.org/abs/2402.11746)
|
554 |
+
1. Yang et al. Enhancing Empathetic Response Generation by Augmenting LLMs with Small-scale Empathetic Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11801)
|
555 |
+
1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2402.11809)
|
556 |
+
1. Cao et al. Head-wise Shareable Attention for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11819)
|
557 |
+
1. Zhang et al. Enhancing Multilingual Capabilities of Large Language Models through Self-Distillation from Resource-Rich Languages. 2024. [[arxiv]](https://arxiv.org/abs/2402.12204)
|
558 |
+
1. Kim et al. Efficient and Effective Vocabulary Expansion Towards Multilingual Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.14714)
|
559 |
+
1. Yu et al. KIEval: A Knowledge-grounded Interactive Evaluation Framework for Large Language Models. ACL 2024. [[arxiv]](https://arxiv.org/abs/2402.15043)
|
560 |
+
1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333)
|
561 |
+
1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419)
|
562 |
+
1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228)
|
563 |
+
1. Wu et al. Large Language Models are Parallel Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2403.09073)
|
564 |
+
1. Zhang et al. EDT: Improving Large Language Models' Generation by Entropy-based Dynamic Temperature Sampling. 2024. [[arxiv]](https://arxiv.org/abs/2403.14541)
|
565 |
+
1. Weller et al. FollowIR: Evaluating and Teaching Information Retrieval Models to Follow Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2403.15246)
|
566 |
+
1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. COLING 2024. [[arxiv]](https://arxiv.org/abs/2403.16008)
|
567 |
+
1. Zan et al. CodeS: Natural Language to Code Repository via Multi-Layer Sketch. 2024. [[arxiv]](https://arxiv.org/abs/2403.16443)
|
568 |
+
1. Liu et al. Extensive Self-Contrast Enables Feedback-Free Language Model Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2404.00604)
|
569 |
+
1. Luo et al. BAdam: A Memory Efficient Full Parameter Training Method for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.02827)
|
570 |
+
1. Du et al. Chinese Tiny LLM: Pretraining a Chinese-Centric Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2404.04167)
|
571 |
+
1. Ma et al. Parameter Efficient Quasi-Orthogonal Fine-Tuning via Givens Rotation. ICML 2024. [[arxiv]](https://arxiv.org/abs/2404.04316)
|
572 |
+
1. Liu et al. Dynamic Generation of Personalities with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.07084)
|
573 |
+
1. Shang et al. How Far Have We Gone in Stripped Binary Code Understanding Using Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.09836)
|
574 |
+
1. Huang et al. LLMTune: Accelerate Database Knob Tuning with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.11581)
|
575 |
+
1. Deng et al. Text-Tuple-Table: Towards Information Integration in Text-to-Table Generation via Global Tuple Extraction. 2024. [[arxiv]](https://arxiv.org/abs/2404.14215)
|
576 |
+
1. Acikgoz et al. Hippocrates: An Open-Source Framework for Advancing Large Language Models in Healthcare. 2024. [[arxiv]](https://arxiv.org/abs/2404.16621)
|
577 |
+
1. Zhang et al. Small Language Models Need Strong Verifiers to Self-Correct Reasoning. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2404.17140)
|
578 |
+
1. Zhou et al. FREB-TQA: A Fine-Grained Robustness Evaluation Benchmark for Table Question Answering. NAACL 2024. [[arxiv]](https://arxiv.org/abs/2404.18585)
|
579 |
+
1. Xu et al. Large Language Models for Cyber Security: A Systematic Literature Review. 2024. [[arxiv]](https://arxiv.org/abs/2405.04760)
|
580 |
+
1. Dammu et al. "They are uncultured": Unveiling Covert Harms and Social Threats in LLM Generated Conversations. 2024. [[arxiv]](https://arxiv.org/abs/2405.05378)
|
581 |
+
1. Yi et al. A safety realignment framework via subspace-oriented model fusion for large language models. 2024. [[arxiv]](https://arxiv.org/abs/2405.09055)
|
582 |
+
1. Lou et al. SPO: Multi-Dimensional Preference Sequential Alignment With Implicit Reward Modeling. 2024. [[arxiv]](https://arxiv.org/abs/2405.12739)
|
583 |
+
1. Zhang et al. Getting More from Less: Large Language Models are Good Spontaneous Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2405.13816)
|
584 |
+
1. Zhang et al. TS-Align: A Teacher-Student Collaborative Framework for Scalable Iterative Finetuning of Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2405.20215)
|
585 |
+
1. Zihong Chen. Sentence Segmentation and Sentence Punctuation Based on XunziALLM. 2024. [[paper]](https://aclanthology.org/2024.lt4hala-1.30)
|
586 |
+
1. Gao et al. The Best of Both Worlds: Toward an Honest and Helpful Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2406.00380)
|
587 |
+
1. Wang and Song. MARS: Benchmarking the Metaphysical Reasoning Abilities of Language Models with a Multi-task Evaluation Dataset. 2024. [[arxiv]](https://arxiv.org/abs/2406.02106)
|
588 |
+
1. Hu et al. Computational Limits of Low-Rank Adaptation (LoRA) for Transformer-Based Models. 2024. [[arxiv]](https://arxiv.org/abs/2406.03136)
|
589 |
+
1. Ge et al. Time Sensitive Knowledge Editing through Efficient Finetuning. ACL 2024. [[arxiv]](https://arxiv.org/abs/2406.04496)
|
590 |
+
1. Tan et al. Peer Review as A Multi-Turn and Long-Context Dialogue with Role-Based Interactions. 2024. [[arxiv]](https://arxiv.org/abs/2406.05688)
|
591 |
+
1. Song et al. Turbo Sparse: Achieving LLM SOTA Performance with Minimal Activated Parameters. 2024. [[arxiv]](https://arxiv.org/abs/2406.05955)
|
592 |
+
1. Gu et al. RWKV-CLIP: A Robust Vision-Language Representation Learner. 2024. [[arxiv]](https://arxiv.org/abs/2406.06973)
|
593 |
+
1. Chen et al. Advancing Tool-Augmented Large Language Models: Integrating Insights from Errors in Inference Trees. 2024. [[arxiv]](https://arxiv.org/abs/2406.07115)
|
594 |
+
1. Zhu et al. Are Large Language Models Good Statisticians?. 2024. [[arxiv]](https://arxiv.org/abs/2406.07815)
|
595 |
+
1. Li et al. Know the Unknown: An Uncertainty-Sensitive Method for LLM Instruction Tuning. 2024. [[arxiv]](https://arxiv.org/abs/2406.10099)
|
596 |
+
1. Ding et al. IntentionQA: A Benchmark for Evaluating Purchase Intention Comprehension Abilities of Language Models in E-commerce. 2024. [[arxiv]](https://arxiv.org/abs/2406.10173)
|
597 |
+
1. He et al. COMMUNITY-CROSS-INSTRUCT: Unsupervised Instruction Generation for Aligning Large Language Models to Online Communities. 2024. [[arxiv]](https://arxiv.org/abs/2406.12074)
|
598 |
+
1. Lin et al. FVEL: Interactive Formal Verification Environment with Large Language Models via Theorem Proving. 2024. [[arxiv]](https://arxiv.org/abs/2406.14408)
|
599 |
+
1. Treutlein et al. Connecting the Dots: LLMs can Infer and Verbalize Latent Structure from Disparate Training Data. 2024. [[arxiv]](https://arxiv.org/abs/2406.14546)
|
600 |
+
1. Feng et al. SS-Bench: A Benchmark for Social Story Generation and Evaluation. 2024. [[arxiv]](https://arxiv.org/abs/2406.15695)
|
601 |
+
1. Feng et al. Self-Constructed Context Decompilation with Fined-grained Alignment Enhancement. 2024. [[arxiv]](https://arxiv.org/abs/2406.17233)
|
602 |
+
1. Liu et al. Large Language Models for Cuffless Blood Pressure Measurement From Wearable Biosignals. 2024. [[arxiv]](https://arxiv.org/abs/2406.18069)
|
603 |
+
1. Iyer et al. Exploring Very Low-Resource Translation with LLMs: The University of Edinburgh’s Submission to AmericasNLP 2024 Translation Task. AmericasNLP 2024. [[paper]](https://aclanthology.org/2024.americasnlp-1.25)
|
604 |
+
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B.
|
605 |
+
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in Chinese legal domain, based on Baichuan-13B, is capable of retrieving and reasoning on legal knowledge.
|
606 |
+
1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
|
607 |
+
1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for Chinese medical domain, based on LLaMA2-7B and Baichuan-13B.
|
608 |
+
1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI Personality large language models, capable of giving any LLM 16 different personality types based on different datasets and training methods.
|
609 |
+
1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**: A large language model specialized in generate metadata for stable diffusion. [[🤗Demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
|
610 |
+
1. **[Chinese-LLaVA-Med](https://github.com/BUAADreamer/Chinese-LLaVA-Med)**: A multimodal large language model specialized in Chinese medical domain, based on LLaVA-1.5-7B.
|
611 |
+
1. **[AutoRE](https://github.com/THUDM/AutoRE)**: A document-level relation extraction system based on large language models.
|
612 |
+
1. **[NVIDIA RTX AI Toolkit](https://github.com/NVIDIA/RTX-AI-Toolkit)**: SDKs for fine-tuning LLMs on Windows PC for NVIDIA RTX.
|
613 |
+
1. **[LazyLLM](https://github.com/LazyAGI/LazyLLM)**: An easy and lazy way for building multi-agent LLMs applications and supports model fine-tuning via LLaMA Factory.
|
614 |
+
|
615 |
+
</details>
|
616 |
+
|
617 |
+
## License
|
618 |
+
|
619 |
+
This repository is licensed under the [Apache-2.0 License](LICENSE).
|
620 |
+
|
621 |
+
Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
|
622 |
+
|
623 |
+
## Citation
|
624 |
+
|
625 |
+
If this work is helpful, please kindly cite as:
|
626 |
+
|
627 |
+
```bibtex
|
628 |
+
@inproceedings{zheng2024llamafactory,
|
629 |
+
title={LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models},
|
630 |
+
author={Yaowei Zheng and Richong Zhang and Junhao Zhang and Yanhan Ye and Zheyan Luo and Zhangchi Feng and Yongqiang Ma},
|
631 |
+
booktitle={Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 3: System Demonstrations)},
|
632 |
+
address={Bangkok, Thailand},
|
633 |
+
publisher={Association for Computational Linguistics},
|
634 |
+
year={2024},
|
635 |
+
url={http://arxiv.org/abs/2403.13372}
|
636 |
+
}
|
637 |
+
```
|
638 |
+
|
639 |
+
## Acknowledgement
|
640 |
+
|
641 |
+
This repo benefits from [PEFT](https://github.com/huggingface/peft), [TRL](https://github.com/huggingface/trl), [QLoRA](https://github.com/artidoro/qlora) and [FastChat](https://github.com/lm-sys/FastChat). Thanks for their wonderful works.
|
642 |
+
|
643 |
+
## Star History
|
644 |
+
|
645 |
+
![Star History Chart](https://api.star-history.com/svg?repos=hiyouga/LLaMA-Factory&type=Date)
|
llama-factory/pyproject.toml
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[build-system]
|
2 |
+
requires = ["setuptools>=61.0"]
|
3 |
+
build-backend = "setuptools.build_meta"
|
4 |
+
|
5 |
+
[tool.ruff]
|
6 |
+
target-version = "py38"
|
7 |
+
line-length = 119
|
8 |
+
indent-width = 4
|
9 |
+
|
10 |
+
[tool.ruff.lint]
|
11 |
+
ignore = ["C408", "C901", "E501", "E731", "E741", "W605"]
|
12 |
+
select = ["C", "E", "F", "I", "W"]
|
13 |
+
|
14 |
+
[tool.ruff.lint.isort]
|
15 |
+
lines-after-imports = 2
|
16 |
+
known-first-party = ["llamafactory"]
|
17 |
+
known-third-party = [
|
18 |
+
"accelerate",
|
19 |
+
"datasets",
|
20 |
+
"gradio",
|
21 |
+
"numpy",
|
22 |
+
"peft",
|
23 |
+
"torch",
|
24 |
+
"transformers",
|
25 |
+
"trl"
|
26 |
+
]
|
27 |
+
|
28 |
+
[tool.ruff.format]
|
29 |
+
quote-style = "double"
|
30 |
+
indent-style = "space"
|
31 |
+
docstring-code-format = true
|
32 |
+
skip-magic-trailing-comma = false
|
33 |
+
line-ending = "auto"
|
llama-factory/requirements.txt
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
transformers>=4.41.2
|
2 |
+
datasets>=2.16.0
|
3 |
+
accelerate>=0.30.1
|
4 |
+
peft>=0.11.1
|
5 |
+
trl>=0.8.6
|
6 |
+
gradio>=4.0.0
|
7 |
+
pandas>=2.0.0
|
8 |
+
scipy
|
9 |
+
einops
|
10 |
+
sentencepiece
|
11 |
+
tiktoken
|
12 |
+
protobuf
|
13 |
+
uvicorn
|
14 |
+
pydantic
|
15 |
+
fastapi
|
16 |
+
sse-starlette
|
17 |
+
matplotlib>=3.7.0
|
18 |
+
fire
|
19 |
+
packaging
|
20 |
+
pyyaml
|
21 |
+
numpy<2.0.0
|
llama-factory/setup.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import os
|
16 |
+
import re
|
17 |
+
|
18 |
+
from setuptools import find_packages, setup
|
19 |
+
|
20 |
+
|
21 |
+
def get_version():
|
22 |
+
with open(os.path.join("src", "llamafactory", "extras", "env.py"), "r", encoding="utf-8") as f:
|
23 |
+
file_content = f.read()
|
24 |
+
pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")
|
25 |
+
(version,) = re.findall(pattern, file_content)
|
26 |
+
return version
|
27 |
+
|
28 |
+
|
29 |
+
def get_requires():
|
30 |
+
with open("requirements.txt", "r", encoding="utf-8") as f:
|
31 |
+
file_content = f.read()
|
32 |
+
lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
|
33 |
+
return lines
|
34 |
+
|
35 |
+
|
36 |
+
extra_require = {
|
37 |
+
"torch": ["torch>=1.13.1"],
|
38 |
+
"torch-npu": ["torch==2.1.0", "torch-npu==2.1.0.post3", "decorator"],
|
39 |
+
"metrics": ["nltk", "jieba", "rouge-chinese"],
|
40 |
+
"deepspeed": ["deepspeed>=0.10.0"],
|
41 |
+
"bitsandbytes": ["bitsandbytes>=0.39.0"],
|
42 |
+
"hqq": ["hqq"],
|
43 |
+
"eetq": ["eetq"],
|
44 |
+
"gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
|
45 |
+
"awq": ["autoawq"],
|
46 |
+
"aqlm": ["aqlm[gpu]>=1.1.0"],
|
47 |
+
"vllm": ["vllm>=0.4.3"],
|
48 |
+
"galore": ["galore-torch"],
|
49 |
+
"badam": ["badam>=1.2.1"],
|
50 |
+
"qwen": ["transformers_stream_generator"],
|
51 |
+
"modelscope": ["modelscope"],
|
52 |
+
"dev": ["ruff", "pytest"],
|
53 |
+
}
|
54 |
+
|
55 |
+
|
56 |
+
def main():
|
57 |
+
setup(
|
58 |
+
name="llamafactory",
|
59 |
+
version=get_version(),
|
60 |
+
author="hiyouga",
|
61 |
+
author_email="hiyouga" "@" "buaa.edu.cn",
|
62 |
+
description="Easy-to-use LLM fine-tuning framework",
|
63 |
+
long_description=open("README.md", "r", encoding="utf-8").read(),
|
64 |
+
long_description_content_type="text/markdown",
|
65 |
+
keywords=["LLaMA", "BLOOM", "Falcon", "LLM", "ChatGPT", "transformer", "pytorch", "deep learning"],
|
66 |
+
license="Apache 2.0 License",
|
67 |
+
url="https://github.com/hiyouga/LLaMA-Factory",
|
68 |
+
package_dir={"": "src"},
|
69 |
+
packages=find_packages("src"),
|
70 |
+
python_requires=">=3.8.0",
|
71 |
+
install_requires=get_requires(),
|
72 |
+
extras_require=extra_require,
|
73 |
+
entry_points={"console_scripts": ["llamafactory-cli = llamafactory.cli:main"]},
|
74 |
+
classifiers=[
|
75 |
+
"Development Status :: 4 - Beta",
|
76 |
+
"Intended Audience :: Developers",
|
77 |
+
"Intended Audience :: Education",
|
78 |
+
"Intended Audience :: Science/Research",
|
79 |
+
"License :: OSI Approved :: Apache Software License",
|
80 |
+
"Operating System :: OS Independent",
|
81 |
+
"Programming Language :: Python :: 3",
|
82 |
+
"Programming Language :: Python :: 3.8",
|
83 |
+
"Programming Language :: Python :: 3.9",
|
84 |
+
"Programming Language :: Python :: 3.10",
|
85 |
+
"Programming Language :: Python :: 3.11",
|
86 |
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
87 |
+
],
|
88 |
+
)
|
89 |
+
|
90 |
+
|
91 |
+
if __name__ == "__main__":
|
92 |
+
main()
|
llama-factory/src/api.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import os
|
16 |
+
|
17 |
+
import uvicorn
|
18 |
+
|
19 |
+
from llamafactory.api.app import create_app
|
20 |
+
from llamafactory.chat import ChatModel
|
21 |
+
|
22 |
+
|
23 |
+
def main():
|
24 |
+
chat_model = ChatModel()
|
25 |
+
app = create_app(chat_model)
|
26 |
+
api_host = os.environ.get("API_HOST", "0.0.0.0")
|
27 |
+
api_port = int(os.environ.get("API_PORT", "8000"))
|
28 |
+
print("Visit http://localhost:{}/docs for API document.".format(api_port))
|
29 |
+
uvicorn.run(app, host=api_host, port=api_port)
|
30 |
+
|
31 |
+
|
32 |
+
if __name__ == "__main__":
|
33 |
+
main()
|
llama-factory/src/llamafactory/__init__.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
r"""
|
16 |
+
Efficient fine-tuning of large language models.
|
17 |
+
|
18 |
+
Level:
|
19 |
+
api, webui > chat, eval, train > data, model > hparams > extras
|
20 |
+
|
21 |
+
Dependency graph:
|
22 |
+
main:
|
23 |
+
transformers>=4.41.2
|
24 |
+
datasets>=2.16.0
|
25 |
+
accelerate>=0.30.1
|
26 |
+
peft>=0.11.1
|
27 |
+
trl>=0.8.6
|
28 |
+
attention:
|
29 |
+
transformers>=4.42.4 (gemma+fa2)
|
30 |
+
longlora:
|
31 |
+
transformers>=4.41.2,<=4.42.4
|
32 |
+
packing:
|
33 |
+
transformers>=4.41.2,<=4.42.4
|
34 |
+
patcher:
|
35 |
+
transformers==4.41.2 (chatglm)
|
36 |
+
"""
|
37 |
+
|
38 |
+
from .cli import VERSION
|
39 |
+
|
40 |
+
|
41 |
+
__version__ = VERSION
|
llama-factory/src/llamafactory/api/__init__.py
ADDED
File without changes
|
llama-factory/src/llamafactory/api/app.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import os
|
16 |
+
from contextlib import asynccontextmanager
|
17 |
+
from typing import Optional
|
18 |
+
|
19 |
+
from typing_extensions import Annotated
|
20 |
+
|
21 |
+
from ..chat import ChatModel
|
22 |
+
from ..extras.misc import torch_gc
|
23 |
+
from ..extras.packages import is_fastapi_available, is_starlette_available, is_uvicorn_available
|
24 |
+
from .chat import (
|
25 |
+
create_chat_completion_response,
|
26 |
+
create_score_evaluation_response,
|
27 |
+
create_stream_chat_completion_response,
|
28 |
+
)
|
29 |
+
from .protocol import (
|
30 |
+
ChatCompletionRequest,
|
31 |
+
ChatCompletionResponse,
|
32 |
+
ModelCard,
|
33 |
+
ModelList,
|
34 |
+
ScoreEvaluationRequest,
|
35 |
+
ScoreEvaluationResponse,
|
36 |
+
)
|
37 |
+
|
38 |
+
|
39 |
+
if is_fastapi_available():
|
40 |
+
from fastapi import Depends, FastAPI, HTTPException, status
|
41 |
+
from fastapi.middleware.cors import CORSMiddleware
|
42 |
+
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
|
43 |
+
|
44 |
+
|
45 |
+
if is_starlette_available():
|
46 |
+
from sse_starlette import EventSourceResponse
|
47 |
+
|
48 |
+
|
49 |
+
if is_uvicorn_available():
|
50 |
+
import uvicorn
|
51 |
+
|
52 |
+
|
53 |
+
@asynccontextmanager
|
54 |
+
async def lifespan(app: "FastAPI"): # collects GPU memory
|
55 |
+
yield
|
56 |
+
torch_gc()
|
57 |
+
|
58 |
+
|
59 |
+
def create_app(chat_model: "ChatModel") -> "FastAPI":
|
60 |
+
app = FastAPI(lifespan=lifespan)
|
61 |
+
app.add_middleware(
|
62 |
+
CORSMiddleware,
|
63 |
+
allow_origins=["*"],
|
64 |
+
allow_credentials=True,
|
65 |
+
allow_methods=["*"],
|
66 |
+
allow_headers=["*"],
|
67 |
+
)
|
68 |
+
api_key = os.environ.get("API_KEY")
|
69 |
+
security = HTTPBearer(auto_error=False)
|
70 |
+
|
71 |
+
async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
|
72 |
+
if api_key and (auth is None or auth.credentials != api_key):
|
73 |
+
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key.")
|
74 |
+
|
75 |
+
@app.get(
|
76 |
+
"/v1/models",
|
77 |
+
response_model=ModelList,
|
78 |
+
status_code=status.HTTP_200_OK,
|
79 |
+
dependencies=[Depends(verify_api_key)],
|
80 |
+
)
|
81 |
+
async def list_models():
|
82 |
+
model_card = ModelCard(id="gpt-3.5-turbo")
|
83 |
+
return ModelList(data=[model_card])
|
84 |
+
|
85 |
+
@app.post(
|
86 |
+
"/v1/chat/completions",
|
87 |
+
response_model=ChatCompletionResponse,
|
88 |
+
status_code=status.HTTP_200_OK,
|
89 |
+
dependencies=[Depends(verify_api_key)],
|
90 |
+
)
|
91 |
+
async def create_chat_completion(request: ChatCompletionRequest):
|
92 |
+
if not chat_model.engine.can_generate:
|
93 |
+
raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
|
94 |
+
|
95 |
+
if request.stream:
|
96 |
+
generate = create_stream_chat_completion_response(request, chat_model)
|
97 |
+
return EventSourceResponse(generate, media_type="text/event-stream")
|
98 |
+
else:
|
99 |
+
return await create_chat_completion_response(request, chat_model)
|
100 |
+
|
101 |
+
@app.post(
|
102 |
+
"/v1/score/evaluation",
|
103 |
+
response_model=ScoreEvaluationResponse,
|
104 |
+
status_code=status.HTTP_200_OK,
|
105 |
+
dependencies=[Depends(verify_api_key)],
|
106 |
+
)
|
107 |
+
async def create_score_evaluation(request: ScoreEvaluationRequest):
|
108 |
+
if chat_model.engine.can_generate:
|
109 |
+
raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
|
110 |
+
|
111 |
+
return await create_score_evaluation_response(request, chat_model)
|
112 |
+
|
113 |
+
return app
|
114 |
+
|
115 |
+
|
116 |
+
def run_api() -> None:
|
117 |
+
chat_model = ChatModel()
|
118 |
+
app = create_app(chat_model)
|
119 |
+
api_host = os.environ.get("API_HOST", "0.0.0.0")
|
120 |
+
api_port = int(os.environ.get("API_PORT", "8000"))
|
121 |
+
print("Visit http://localhost:{}/docs for API document.".format(api_port))
|
122 |
+
uvicorn.run(app, host=api_host, port=api_port)
|
llama-factory/src/llamafactory/api/chat.py
ADDED
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import base64
|
16 |
+
import io
|
17 |
+
import json
|
18 |
+
import os
|
19 |
+
import uuid
|
20 |
+
from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
|
21 |
+
|
22 |
+
from ..data import Role as DataRole
|
23 |
+
from ..extras.logging import get_logger
|
24 |
+
from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
|
25 |
+
from .common import dictify, jsonify
|
26 |
+
from .protocol import (
|
27 |
+
ChatCompletionMessage,
|
28 |
+
ChatCompletionResponse,
|
29 |
+
ChatCompletionResponseChoice,
|
30 |
+
ChatCompletionResponseUsage,
|
31 |
+
ChatCompletionStreamResponse,
|
32 |
+
ChatCompletionStreamResponseChoice,
|
33 |
+
Finish,
|
34 |
+
Function,
|
35 |
+
FunctionCall,
|
36 |
+
Role,
|
37 |
+
ScoreEvaluationResponse,
|
38 |
+
)
|
39 |
+
|
40 |
+
|
41 |
+
if is_fastapi_available():
|
42 |
+
from fastapi import HTTPException, status
|
43 |
+
|
44 |
+
|
45 |
+
if is_pillow_available():
|
46 |
+
from PIL import Image
|
47 |
+
|
48 |
+
|
49 |
+
if is_requests_available():
|
50 |
+
import requests
|
51 |
+
|
52 |
+
|
53 |
+
if TYPE_CHECKING:
|
54 |
+
from numpy.typing import NDArray
|
55 |
+
|
56 |
+
from ..chat import ChatModel
|
57 |
+
from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
|
58 |
+
|
59 |
+
|
60 |
+
logger = get_logger(__name__)
|
61 |
+
ROLE_MAPPING = {
|
62 |
+
Role.USER: DataRole.USER.value,
|
63 |
+
Role.ASSISTANT: DataRole.ASSISTANT.value,
|
64 |
+
Role.SYSTEM: DataRole.SYSTEM.value,
|
65 |
+
Role.FUNCTION: DataRole.FUNCTION.value,
|
66 |
+
Role.TOOL: DataRole.OBSERVATION.value,
|
67 |
+
}
|
68 |
+
|
69 |
+
|
70 |
+
def _process_request(
|
71 |
+
request: "ChatCompletionRequest",
|
72 |
+
) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["NDArray"]]:
|
73 |
+
logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False)))
|
74 |
+
|
75 |
+
if len(request.messages) == 0:
|
76 |
+
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
|
77 |
+
|
78 |
+
if request.messages[0].role == Role.SYSTEM:
|
79 |
+
system = request.messages.pop(0).content
|
80 |
+
else:
|
81 |
+
system = None
|
82 |
+
|
83 |
+
if len(request.messages) % 2 == 0:
|
84 |
+
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
|
85 |
+
|
86 |
+
input_messages = []
|
87 |
+
image = None
|
88 |
+
for i, message in enumerate(request.messages):
|
89 |
+
if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
|
90 |
+
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
|
91 |
+
elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]:
|
92 |
+
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
|
93 |
+
|
94 |
+
if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
|
95 |
+
tool_calls = [
|
96 |
+
{"name": tool_call.function.name, "arguments": tool_call.function.arguments}
|
97 |
+
for tool_call in message.tool_calls
|
98 |
+
]
|
99 |
+
content = json.dumps(tool_calls, ensure_ascii=False)
|
100 |
+
input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
|
101 |
+
elif isinstance(message.content, list):
|
102 |
+
for input_item in message.content:
|
103 |
+
if input_item.type == "text":
|
104 |
+
input_messages.append({"role": ROLE_MAPPING[message.role], "content": input_item.text})
|
105 |
+
else:
|
106 |
+
image_url = input_item.image_url.url
|
107 |
+
if image_url.startswith("data:image"): # base64 image
|
108 |
+
image_data = base64.b64decode(image_url.split(",", maxsplit=1)[1])
|
109 |
+
image_path = io.BytesIO(image_data)
|
110 |
+
elif os.path.isfile(image_url): # local file
|
111 |
+
image_path = open(image_url, "rb")
|
112 |
+
else: # web uri
|
113 |
+
image_path = requests.get(image_url, stream=True).raw
|
114 |
+
|
115 |
+
image = Image.open(image_path).convert("RGB")
|
116 |
+
else:
|
117 |
+
input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
|
118 |
+
|
119 |
+
tool_list = request.tools
|
120 |
+
if isinstance(tool_list, list) and len(tool_list):
|
121 |
+
try:
|
122 |
+
tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False)
|
123 |
+
except json.JSONDecodeError:
|
124 |
+
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
|
125 |
+
else:
|
126 |
+
tools = None
|
127 |
+
|
128 |
+
return input_messages, system, tools, image
|
129 |
+
|
130 |
+
|
131 |
+
def _create_stream_chat_completion_chunk(
|
132 |
+
completion_id: str,
|
133 |
+
model: str,
|
134 |
+
delta: "ChatCompletionMessage",
|
135 |
+
index: Optional[int] = 0,
|
136 |
+
finish_reason: Optional["Finish"] = None,
|
137 |
+
) -> str:
|
138 |
+
choice_data = ChatCompletionStreamResponseChoice(index=index, delta=delta, finish_reason=finish_reason)
|
139 |
+
chunk = ChatCompletionStreamResponse(id=completion_id, model=model, choices=[choice_data])
|
140 |
+
return jsonify(chunk)
|
141 |
+
|
142 |
+
|
143 |
+
async def create_chat_completion_response(
|
144 |
+
request: "ChatCompletionRequest", chat_model: "ChatModel"
|
145 |
+
) -> "ChatCompletionResponse":
|
146 |
+
completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
|
147 |
+
input_messages, system, tools, image = _process_request(request)
|
148 |
+
responses = await chat_model.achat(
|
149 |
+
input_messages,
|
150 |
+
system,
|
151 |
+
tools,
|
152 |
+
image,
|
153 |
+
do_sample=request.do_sample,
|
154 |
+
temperature=request.temperature,
|
155 |
+
top_p=request.top_p,
|
156 |
+
max_new_tokens=request.max_tokens,
|
157 |
+
num_return_sequences=request.n,
|
158 |
+
stop=request.stop,
|
159 |
+
)
|
160 |
+
|
161 |
+
prompt_length, response_length = 0, 0
|
162 |
+
choices = []
|
163 |
+
for i, response in enumerate(responses):
|
164 |
+
if tools:
|
165 |
+
result = chat_model.engine.template.extract_tool(response.response_text)
|
166 |
+
else:
|
167 |
+
result = response.response_text
|
168 |
+
|
169 |
+
if isinstance(result, list):
|
170 |
+
tool_calls = []
|
171 |
+
for tool in result:
|
172 |
+
function = Function(name=tool[0], arguments=tool[1])
|
173 |
+
tool_calls.append(FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function))
|
174 |
+
|
175 |
+
response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
|
176 |
+
finish_reason = Finish.TOOL
|
177 |
+
else:
|
178 |
+
response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result)
|
179 |
+
finish_reason = Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
|
180 |
+
|
181 |
+
choices.append(ChatCompletionResponseChoice(index=i, message=response_message, finish_reason=finish_reason))
|
182 |
+
prompt_length = response.prompt_length
|
183 |
+
response_length += response.response_length
|
184 |
+
|
185 |
+
usage = ChatCompletionResponseUsage(
|
186 |
+
prompt_tokens=prompt_length,
|
187 |
+
completion_tokens=response_length,
|
188 |
+
total_tokens=prompt_length + response_length,
|
189 |
+
)
|
190 |
+
|
191 |
+
return ChatCompletionResponse(id=completion_id, model=request.model, choices=choices, usage=usage)
|
192 |
+
|
193 |
+
|
194 |
+
async def create_stream_chat_completion_response(
|
195 |
+
request: "ChatCompletionRequest", chat_model: "ChatModel"
|
196 |
+
) -> AsyncGenerator[str, None]:
|
197 |
+
completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
|
198 |
+
input_messages, system, tools, image = _process_request(request)
|
199 |
+
if tools:
|
200 |
+
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
|
201 |
+
|
202 |
+
if request.n > 1:
|
203 |
+
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream multiple responses.")
|
204 |
+
|
205 |
+
yield _create_stream_chat_completion_chunk(
|
206 |
+
completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(role=Role.ASSISTANT, content="")
|
207 |
+
)
|
208 |
+
async for new_token in chat_model.astream_chat(
|
209 |
+
input_messages,
|
210 |
+
system,
|
211 |
+
tools,
|
212 |
+
image,
|
213 |
+
do_sample=request.do_sample,
|
214 |
+
temperature=request.temperature,
|
215 |
+
top_p=request.top_p,
|
216 |
+
max_new_tokens=request.max_tokens,
|
217 |
+
stop=request.stop,
|
218 |
+
):
|
219 |
+
if len(new_token) != 0:
|
220 |
+
yield _create_stream_chat_completion_chunk(
|
221 |
+
completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(content=new_token)
|
222 |
+
)
|
223 |
+
|
224 |
+
yield _create_stream_chat_completion_chunk(
|
225 |
+
completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(), finish_reason=Finish.STOP
|
226 |
+
)
|
227 |
+
yield "[DONE]"
|
228 |
+
|
229 |
+
|
230 |
+
async def create_score_evaluation_response(
|
231 |
+
request: "ScoreEvaluationRequest", chat_model: "ChatModel"
|
232 |
+
) -> "ScoreEvaluationResponse":
|
233 |
+
if len(request.messages) == 0:
|
234 |
+
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
|
235 |
+
|
236 |
+
scores = await chat_model.aget_scores(request.messages, max_length=request.max_length)
|
237 |
+
return ScoreEvaluationResponse(model=request.model, scores=scores)
|
llama-factory/src/llamafactory/api/common.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import json
|
16 |
+
from typing import TYPE_CHECKING, Any, Dict
|
17 |
+
|
18 |
+
|
19 |
+
if TYPE_CHECKING:
|
20 |
+
from pydantic import BaseModel
|
21 |
+
|
22 |
+
|
23 |
+
def dictify(data: "BaseModel") -> Dict[str, Any]:
|
24 |
+
try: # pydantic v2
|
25 |
+
return data.model_dump(exclude_unset=True)
|
26 |
+
except AttributeError: # pydantic v1
|
27 |
+
return data.dict(exclude_unset=True)
|
28 |
+
|
29 |
+
|
30 |
+
def jsonify(data: "BaseModel") -> str:
|
31 |
+
try: # pydantic v2
|
32 |
+
return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
|
33 |
+
except AttributeError: # pydantic v1
|
34 |
+
return data.json(exclude_unset=True, ensure_ascii=False)
|
llama-factory/src/llamafactory/api/protocol.py
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import time
|
16 |
+
from enum import Enum, unique
|
17 |
+
from typing import Any, Dict, List, Optional, Union
|
18 |
+
|
19 |
+
from pydantic import BaseModel, Field
|
20 |
+
from typing_extensions import Literal
|
21 |
+
|
22 |
+
|
23 |
+
@unique
|
24 |
+
class Role(str, Enum):
|
25 |
+
USER = "user"
|
26 |
+
ASSISTANT = "assistant"
|
27 |
+
SYSTEM = "system"
|
28 |
+
FUNCTION = "function"
|
29 |
+
TOOL = "tool"
|
30 |
+
|
31 |
+
|
32 |
+
@unique
|
33 |
+
class Finish(str, Enum):
|
34 |
+
STOP = "stop"
|
35 |
+
LENGTH = "length"
|
36 |
+
TOOL = "tool_calls"
|
37 |
+
|
38 |
+
|
39 |
+
class ModelCard(BaseModel):
|
40 |
+
id: str
|
41 |
+
object: Literal["model"] = "model"
|
42 |
+
created: int = Field(default_factory=lambda: int(time.time()))
|
43 |
+
owned_by: Literal["owner"] = "owner"
|
44 |
+
|
45 |
+
|
46 |
+
class ModelList(BaseModel):
|
47 |
+
object: Literal["list"] = "list"
|
48 |
+
data: List[ModelCard] = []
|
49 |
+
|
50 |
+
|
51 |
+
class Function(BaseModel):
|
52 |
+
name: str
|
53 |
+
arguments: str
|
54 |
+
|
55 |
+
|
56 |
+
class FunctionDefinition(BaseModel):
|
57 |
+
name: str
|
58 |
+
description: str
|
59 |
+
parameters: Dict[str, Any]
|
60 |
+
|
61 |
+
|
62 |
+
class FunctionAvailable(BaseModel):
|
63 |
+
type: Literal["function", "code_interpreter"] = "function"
|
64 |
+
function: Optional[FunctionDefinition] = None
|
65 |
+
|
66 |
+
|
67 |
+
class FunctionCall(BaseModel):
|
68 |
+
id: str
|
69 |
+
type: Literal["function"] = "function"
|
70 |
+
function: Function
|
71 |
+
|
72 |
+
|
73 |
+
class ImageURL(BaseModel):
|
74 |
+
url: str
|
75 |
+
|
76 |
+
|
77 |
+
class MultimodalInputItem(BaseModel):
|
78 |
+
type: Literal["text", "image_url"]
|
79 |
+
text: Optional[str] = None
|
80 |
+
image_url: Optional[ImageURL] = None
|
81 |
+
|
82 |
+
|
83 |
+
class ChatMessage(BaseModel):
|
84 |
+
role: Role
|
85 |
+
content: Optional[Union[str, List[MultimodalInputItem]]] = None
|
86 |
+
tool_calls: Optional[List[FunctionCall]] = None
|
87 |
+
|
88 |
+
|
89 |
+
class ChatCompletionMessage(BaseModel):
|
90 |
+
role: Optional[Role] = None
|
91 |
+
content: Optional[str] = None
|
92 |
+
tool_calls: Optional[List[FunctionCall]] = None
|
93 |
+
|
94 |
+
|
95 |
+
class ChatCompletionRequest(BaseModel):
|
96 |
+
model: str
|
97 |
+
messages: List[ChatMessage]
|
98 |
+
tools: Optional[List[FunctionAvailable]] = None
|
99 |
+
do_sample: Optional[bool] = None
|
100 |
+
temperature: Optional[float] = None
|
101 |
+
top_p: Optional[float] = None
|
102 |
+
n: int = 1
|
103 |
+
max_tokens: Optional[int] = None
|
104 |
+
stop: Optional[Union[str, List[str]]] = None
|
105 |
+
stream: bool = False
|
106 |
+
|
107 |
+
|
108 |
+
class ChatCompletionResponseChoice(BaseModel):
|
109 |
+
index: int
|
110 |
+
message: ChatCompletionMessage
|
111 |
+
finish_reason: Finish
|
112 |
+
|
113 |
+
|
114 |
+
class ChatCompletionStreamResponseChoice(BaseModel):
|
115 |
+
index: int
|
116 |
+
delta: ChatCompletionMessage
|
117 |
+
finish_reason: Optional[Finish] = None
|
118 |
+
|
119 |
+
|
120 |
+
class ChatCompletionResponseUsage(BaseModel):
|
121 |
+
prompt_tokens: int
|
122 |
+
completion_tokens: int
|
123 |
+
total_tokens: int
|
124 |
+
|
125 |
+
|
126 |
+
class ChatCompletionResponse(BaseModel):
|
127 |
+
id: str
|
128 |
+
object: Literal["chat.completion"] = "chat.completion"
|
129 |
+
created: int = Field(default_factory=lambda: int(time.time()))
|
130 |
+
model: str
|
131 |
+
choices: List[ChatCompletionResponseChoice]
|
132 |
+
usage: ChatCompletionResponseUsage
|
133 |
+
|
134 |
+
|
135 |
+
class ChatCompletionStreamResponse(BaseModel):
|
136 |
+
id: str
|
137 |
+
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
|
138 |
+
created: int = Field(default_factory=lambda: int(time.time()))
|
139 |
+
model: str
|
140 |
+
choices: List[ChatCompletionStreamResponseChoice]
|
141 |
+
|
142 |
+
|
143 |
+
class ScoreEvaluationRequest(BaseModel):
|
144 |
+
model: str
|
145 |
+
messages: List[str]
|
146 |
+
max_length: Optional[int] = None
|
147 |
+
|
148 |
+
|
149 |
+
class ScoreEvaluationResponse(BaseModel):
|
150 |
+
id: str
|
151 |
+
object: Literal["score.evaluation"] = "score.evaluation"
|
152 |
+
model: str
|
153 |
+
scores: List[float]
|
llama-factory/src/llamafactory/chat/__init__.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from .base_engine import BaseEngine
|
16 |
+
from .chat_model import ChatModel
|
17 |
+
|
18 |
+
|
19 |
+
__all__ = ["BaseEngine", "ChatModel"]
|
llama-factory/src/llamafactory/chat/base_engine.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from abc import ABC, abstractmethod
|
16 |
+
from dataclasses import dataclass
|
17 |
+
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Union
|
18 |
+
|
19 |
+
|
20 |
+
if TYPE_CHECKING:
|
21 |
+
from numpy.typing import NDArray
|
22 |
+
from transformers import PreTrainedModel, PreTrainedTokenizer
|
23 |
+
from vllm import AsyncLLMEngine
|
24 |
+
|
25 |
+
from ..data import Template
|
26 |
+
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
27 |
+
|
28 |
+
|
29 |
+
@dataclass
|
30 |
+
class Response:
|
31 |
+
response_text: str
|
32 |
+
response_length: int
|
33 |
+
prompt_length: int
|
34 |
+
finish_reason: Literal["stop", "length"]
|
35 |
+
|
36 |
+
|
37 |
+
class BaseEngine(ABC):
|
38 |
+
model: Union["PreTrainedModel", "AsyncLLMEngine"]
|
39 |
+
tokenizer: "PreTrainedTokenizer"
|
40 |
+
can_generate: bool
|
41 |
+
template: "Template"
|
42 |
+
generating_args: Dict[str, Any]
|
43 |
+
|
44 |
+
@abstractmethod
|
45 |
+
def __init__(
|
46 |
+
self,
|
47 |
+
model_args: "ModelArguments",
|
48 |
+
data_args: "DataArguments",
|
49 |
+
finetuning_args: "FinetuningArguments",
|
50 |
+
generating_args: "GeneratingArguments",
|
51 |
+
) -> None: ...
|
52 |
+
|
53 |
+
@abstractmethod
|
54 |
+
async def chat(
|
55 |
+
self,
|
56 |
+
messages: Sequence[Dict[str, str]],
|
57 |
+
system: Optional[str] = None,
|
58 |
+
tools: Optional[str] = None,
|
59 |
+
image: Optional["NDArray"] = None,
|
60 |
+
**input_kwargs,
|
61 |
+
) -> List["Response"]: ...
|
62 |
+
|
63 |
+
@abstractmethod
|
64 |
+
async def stream_chat(
|
65 |
+
self,
|
66 |
+
messages: Sequence[Dict[str, str]],
|
67 |
+
system: Optional[str] = None,
|
68 |
+
tools: Optional[str] = None,
|
69 |
+
image: Optional["NDArray"] = None,
|
70 |
+
**input_kwargs,
|
71 |
+
) -> AsyncGenerator[str, None]: ...
|
72 |
+
|
73 |
+
@abstractmethod
|
74 |
+
async def get_scores(
|
75 |
+
self,
|
76 |
+
batch_input: List[str],
|
77 |
+
**input_kwargs,
|
78 |
+
) -> List[float]: ...
|
llama-factory/src/llamafactory/chat/chat_model.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 THUDM and the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# This code is inspired by the THUDM's ChatGLM implementation.
|
4 |
+
# https://github.com/THUDM/ChatGLM-6B/blob/main/cli_demo.py
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
|
18 |
+
import asyncio
|
19 |
+
import os
|
20 |
+
from threading import Thread
|
21 |
+
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
|
22 |
+
|
23 |
+
from ..extras.misc import torch_gc
|
24 |
+
from ..hparams import get_infer_args
|
25 |
+
from .hf_engine import HuggingfaceEngine
|
26 |
+
from .vllm_engine import VllmEngine
|
27 |
+
|
28 |
+
|
29 |
+
if TYPE_CHECKING:
|
30 |
+
from numpy.typing import NDArray
|
31 |
+
|
32 |
+
from .base_engine import BaseEngine, Response
|
33 |
+
|
34 |
+
|
35 |
+
def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
|
36 |
+
asyncio.set_event_loop(loop)
|
37 |
+
loop.run_forever()
|
38 |
+
|
39 |
+
|
40 |
+
class ChatModel:
|
41 |
+
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
|
42 |
+
model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
|
43 |
+
if model_args.infer_backend == "huggingface":
|
44 |
+
self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
|
45 |
+
elif model_args.infer_backend == "vllm":
|
46 |
+
self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
|
47 |
+
else:
|
48 |
+
raise NotImplementedError("Unknown backend: {}".format(model_args.infer_backend))
|
49 |
+
|
50 |
+
self._loop = asyncio.new_event_loop()
|
51 |
+
self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
|
52 |
+
self._thread.start()
|
53 |
+
|
54 |
+
def chat(
|
55 |
+
self,
|
56 |
+
messages: Sequence[Dict[str, str]],
|
57 |
+
system: Optional[str] = None,
|
58 |
+
tools: Optional[str] = None,
|
59 |
+
image: Optional["NDArray"] = None,
|
60 |
+
**input_kwargs,
|
61 |
+
) -> List["Response"]:
|
62 |
+
task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, image, **input_kwargs), self._loop)
|
63 |
+
return task.result()
|
64 |
+
|
65 |
+
async def achat(
|
66 |
+
self,
|
67 |
+
messages: Sequence[Dict[str, str]],
|
68 |
+
system: Optional[str] = None,
|
69 |
+
tools: Optional[str] = None,
|
70 |
+
image: Optional["NDArray"] = None,
|
71 |
+
**input_kwargs,
|
72 |
+
) -> List["Response"]:
|
73 |
+
return await self.engine.chat(messages, system, tools, image, **input_kwargs)
|
74 |
+
|
75 |
+
def stream_chat(
|
76 |
+
self,
|
77 |
+
messages: Sequence[Dict[str, str]],
|
78 |
+
system: Optional[str] = None,
|
79 |
+
tools: Optional[str] = None,
|
80 |
+
image: Optional["NDArray"] = None,
|
81 |
+
**input_kwargs,
|
82 |
+
) -> Generator[str, None, None]:
|
83 |
+
generator = self.astream_chat(messages, system, tools, image, **input_kwargs)
|
84 |
+
while True:
|
85 |
+
try:
|
86 |
+
task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
|
87 |
+
yield task.result()
|
88 |
+
except StopAsyncIteration:
|
89 |
+
break
|
90 |
+
|
91 |
+
async def astream_chat(
|
92 |
+
self,
|
93 |
+
messages: Sequence[Dict[str, str]],
|
94 |
+
system: Optional[str] = None,
|
95 |
+
tools: Optional[str] = None,
|
96 |
+
image: Optional["NDArray"] = None,
|
97 |
+
**input_kwargs,
|
98 |
+
) -> AsyncGenerator[str, None]:
|
99 |
+
async for new_token in self.engine.stream_chat(messages, system, tools, image, **input_kwargs):
|
100 |
+
yield new_token
|
101 |
+
|
102 |
+
def get_scores(
|
103 |
+
self,
|
104 |
+
batch_input: List[str],
|
105 |
+
**input_kwargs,
|
106 |
+
) -> List[float]:
|
107 |
+
task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
|
108 |
+
return task.result()
|
109 |
+
|
110 |
+
async def aget_scores(
|
111 |
+
self,
|
112 |
+
batch_input: List[str],
|
113 |
+
**input_kwargs,
|
114 |
+
) -> List[float]:
|
115 |
+
return await self.engine.get_scores(batch_input, **input_kwargs)
|
116 |
+
|
117 |
+
|
118 |
+
def run_chat() -> None:
|
119 |
+
if os.name != "nt":
|
120 |
+
try:
|
121 |
+
import readline # noqa: F401
|
122 |
+
except ImportError:
|
123 |
+
print("Install `readline` for a better experience.")
|
124 |
+
|
125 |
+
chat_model = ChatModel()
|
126 |
+
messages = []
|
127 |
+
print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
|
128 |
+
|
129 |
+
while True:
|
130 |
+
try:
|
131 |
+
query = input("\nUser: ")
|
132 |
+
except UnicodeDecodeError:
|
133 |
+
print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
|
134 |
+
continue
|
135 |
+
except Exception:
|
136 |
+
raise
|
137 |
+
|
138 |
+
if query.strip() == "exit":
|
139 |
+
break
|
140 |
+
|
141 |
+
if query.strip() == "clear":
|
142 |
+
messages = []
|
143 |
+
torch_gc()
|
144 |
+
print("History has been removed.")
|
145 |
+
continue
|
146 |
+
|
147 |
+
messages.append({"role": "user", "content": query})
|
148 |
+
print("Assistant: ", end="", flush=True)
|
149 |
+
|
150 |
+
response = ""
|
151 |
+
for new_text in chat_model.stream_chat(messages):
|
152 |
+
print(new_text, end="", flush=True)
|
153 |
+
response += new_text
|
154 |
+
print()
|
155 |
+
messages.append({"role": "assistant", "content": response})
|
llama-factory/src/llamafactory/chat/hf_engine.py
ADDED
@@ -0,0 +1,343 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import asyncio
|
16 |
+
import concurrent.futures
|
17 |
+
import os
|
18 |
+
from threading import Thread
|
19 |
+
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple, Union
|
20 |
+
|
21 |
+
import torch
|
22 |
+
from transformers import GenerationConfig, TextIteratorStreamer
|
23 |
+
|
24 |
+
from ..data import get_template_and_fix_tokenizer
|
25 |
+
from ..extras.logging import get_logger
|
26 |
+
from ..extras.misc import get_logits_processor
|
27 |
+
from ..model import load_model, load_tokenizer
|
28 |
+
from .base_engine import BaseEngine, Response
|
29 |
+
|
30 |
+
|
31 |
+
if TYPE_CHECKING:
|
32 |
+
from numpy.typing import NDArray
|
33 |
+
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
|
34 |
+
from transformers.image_processing_utils import BaseImageProcessor
|
35 |
+
from trl import PreTrainedModelWrapper
|
36 |
+
|
37 |
+
from ..data import Template
|
38 |
+
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
39 |
+
|
40 |
+
|
41 |
+
logger = get_logger(__name__)
|
42 |
+
|
43 |
+
|
44 |
+
class HuggingfaceEngine(BaseEngine):
|
45 |
+
def __init__(
|
46 |
+
self,
|
47 |
+
model_args: "ModelArguments",
|
48 |
+
data_args: "DataArguments",
|
49 |
+
finetuning_args: "FinetuningArguments",
|
50 |
+
generating_args: "GeneratingArguments",
|
51 |
+
) -> None:
|
52 |
+
self.can_generate = finetuning_args.stage == "sft"
|
53 |
+
tokenizer_module = load_tokenizer(model_args)
|
54 |
+
self.tokenizer = tokenizer_module["tokenizer"]
|
55 |
+
self.processor = tokenizer_module["processor"]
|
56 |
+
self.tokenizer.padding_side = "left" if self.can_generate else "right"
|
57 |
+
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template, data_args.tool_format)
|
58 |
+
self.model = load_model(
|
59 |
+
self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
|
60 |
+
) # must after fixing tokenizer to resize vocab
|
61 |
+
self.generating_args = generating_args.to_dict()
|
62 |
+
try:
|
63 |
+
asyncio.get_event_loop()
|
64 |
+
except RuntimeError:
|
65 |
+
logger.warning("There is no current event loop, creating a new one.")
|
66 |
+
loop = asyncio.new_event_loop()
|
67 |
+
asyncio.set_event_loop(loop)
|
68 |
+
|
69 |
+
self.semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", "1")))
|
70 |
+
|
71 |
+
@staticmethod
|
72 |
+
def _process_args(
|
73 |
+
model: "PreTrainedModel",
|
74 |
+
tokenizer: "PreTrainedTokenizer",
|
75 |
+
processor: Optional["ProcessorMixin"],
|
76 |
+
template: "Template",
|
77 |
+
generating_args: Dict[str, Any],
|
78 |
+
messages: Sequence[Dict[str, str]],
|
79 |
+
system: Optional[str] = None,
|
80 |
+
tools: Optional[str] = None,
|
81 |
+
image: Optional["NDArray"] = None,
|
82 |
+
input_kwargs: Optional[Dict[str, Any]] = {},
|
83 |
+
) -> Tuple[Dict[str, Any], int]:
|
84 |
+
if (
|
85 |
+
processor is not None
|
86 |
+
and image is not None
|
87 |
+
and not hasattr(processor, "image_seq_length")
|
88 |
+
and template.image_token not in messages[0]["content"]
|
89 |
+
): # llava-like models
|
90 |
+
messages[0]["content"] = template.image_token + messages[0]["content"]
|
91 |
+
|
92 |
+
paired_messages = messages + [{"role": "assistant", "content": ""}]
|
93 |
+
system = system or generating_args["default_system"]
|
94 |
+
pixel_values = None
|
95 |
+
prompt_ids, _ = template.encode_oneturn(
|
96 |
+
tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools
|
97 |
+
)
|
98 |
+
if processor is not None and image is not None: # add image features
|
99 |
+
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
100 |
+
batch_feature = image_processor(image, return_tensors="pt")
|
101 |
+
pixel_values = batch_feature.to(model.device)["pixel_values"] # shape (B, C, H, W)
|
102 |
+
if hasattr(processor, "image_seq_length"): # paligemma models
|
103 |
+
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
|
104 |
+
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
|
105 |
+
|
106 |
+
prompt_length = len(prompt_ids)
|
107 |
+
inputs = torch.tensor([prompt_ids], device=model.device)
|
108 |
+
attention_mask = torch.ones_like(inputs, dtype=torch.bool)
|
109 |
+
|
110 |
+
do_sample: Optional[bool] = input_kwargs.pop("do_sample", None)
|
111 |
+
temperature: Optional[float] = input_kwargs.pop("temperature", None)
|
112 |
+
top_p: Optional[float] = input_kwargs.pop("top_p", None)
|
113 |
+
top_k: Optional[float] = input_kwargs.pop("top_k", None)
|
114 |
+
num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
|
115 |
+
repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
|
116 |
+
length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
|
117 |
+
max_length: Optional[int] = input_kwargs.pop("max_length", None)
|
118 |
+
max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
|
119 |
+
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
|
120 |
+
|
121 |
+
if stop is not None:
|
122 |
+
logger.warning("Stop parameter is not supported by the huggingface engine yet.")
|
123 |
+
|
124 |
+
generating_args = generating_args.copy()
|
125 |
+
generating_args.update(
|
126 |
+
dict(
|
127 |
+
do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
|
128 |
+
temperature=temperature if temperature is not None else generating_args["temperature"],
|
129 |
+
top_p=top_p if top_p is not None else generating_args["top_p"],
|
130 |
+
top_k=top_k if top_k is not None else generating_args["top_k"],
|
131 |
+
num_return_sequences=num_return_sequences,
|
132 |
+
repetition_penalty=repetition_penalty
|
133 |
+
if repetition_penalty is not None
|
134 |
+
else generating_args["repetition_penalty"],
|
135 |
+
length_penalty=length_penalty if length_penalty is not None else generating_args["length_penalty"],
|
136 |
+
eos_token_id=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
|
137 |
+
pad_token_id=tokenizer.pad_token_id,
|
138 |
+
)
|
139 |
+
)
|
140 |
+
|
141 |
+
if isinstance(num_return_sequences, int) and num_return_sequences > 1: # do_sample needs temperature > 0
|
142 |
+
generating_args["do_sample"] = True
|
143 |
+
generating_args["temperature"] = generating_args["temperature"] or 1.0
|
144 |
+
|
145 |
+
if not generating_args["temperature"]:
|
146 |
+
generating_args["do_sample"] = False
|
147 |
+
|
148 |
+
if not generating_args["do_sample"]:
|
149 |
+
generating_args.pop("temperature", None)
|
150 |
+
generating_args.pop("top_p", None)
|
151 |
+
|
152 |
+
if max_length:
|
153 |
+
generating_args.pop("max_new_tokens", None)
|
154 |
+
generating_args["max_length"] = max_length
|
155 |
+
|
156 |
+
if max_new_tokens:
|
157 |
+
generating_args.pop("max_length", None)
|
158 |
+
generating_args["max_new_tokens"] = max_new_tokens
|
159 |
+
|
160 |
+
gen_kwargs = dict(
|
161 |
+
inputs=inputs,
|
162 |
+
attention_mask=attention_mask,
|
163 |
+
generation_config=GenerationConfig(**generating_args),
|
164 |
+
logits_processor=get_logits_processor(),
|
165 |
+
)
|
166 |
+
|
167 |
+
if pixel_values is not None:
|
168 |
+
gen_kwargs["pixel_values"] = pixel_values
|
169 |
+
|
170 |
+
return gen_kwargs, prompt_length
|
171 |
+
|
172 |
+
@staticmethod
|
173 |
+
@torch.inference_mode()
|
174 |
+
def _chat(
|
175 |
+
model: "PreTrainedModel",
|
176 |
+
tokenizer: "PreTrainedTokenizer",
|
177 |
+
processor: Optional["ProcessorMixin"],
|
178 |
+
template: "Template",
|
179 |
+
generating_args: Dict[str, Any],
|
180 |
+
messages: Sequence[Dict[str, str]],
|
181 |
+
system: Optional[str] = None,
|
182 |
+
tools: Optional[str] = None,
|
183 |
+
image: Optional["NDArray"] = None,
|
184 |
+
input_kwargs: Optional[Dict[str, Any]] = {},
|
185 |
+
) -> List["Response"]:
|
186 |
+
gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
|
187 |
+
model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
|
188 |
+
)
|
189 |
+
generate_output = model.generate(**gen_kwargs)
|
190 |
+
response_ids = generate_output[:, prompt_length:]
|
191 |
+
response = tokenizer.batch_decode(response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
192 |
+
results = []
|
193 |
+
for i in range(len(response)):
|
194 |
+
eos_index = (response_ids[i] == tokenizer.eos_token_id).nonzero()
|
195 |
+
response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i])
|
196 |
+
results.append(
|
197 |
+
Response(
|
198 |
+
response_text=response[i],
|
199 |
+
response_length=response_length,
|
200 |
+
prompt_length=prompt_length,
|
201 |
+
finish_reason="stop" if len(eos_index) else "length",
|
202 |
+
)
|
203 |
+
)
|
204 |
+
|
205 |
+
return results
|
206 |
+
|
207 |
+
@staticmethod
|
208 |
+
@torch.inference_mode()
|
209 |
+
def _stream_chat(
|
210 |
+
model: "PreTrainedModel",
|
211 |
+
tokenizer: "PreTrainedTokenizer",
|
212 |
+
processor: Optional["ProcessorMixin"],
|
213 |
+
template: "Template",
|
214 |
+
generating_args: Dict[str, Any],
|
215 |
+
messages: Sequence[Dict[str, str]],
|
216 |
+
system: Optional[str] = None,
|
217 |
+
tools: Optional[str] = None,
|
218 |
+
image: Optional["NDArray"] = None,
|
219 |
+
input_kwargs: Optional[Dict[str, Any]] = {},
|
220 |
+
) -> Callable[[], str]:
|
221 |
+
gen_kwargs, _ = HuggingfaceEngine._process_args(
|
222 |
+
model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
|
223 |
+
)
|
224 |
+
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
225 |
+
gen_kwargs["streamer"] = streamer
|
226 |
+
thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
|
227 |
+
thread.start()
|
228 |
+
|
229 |
+
def stream():
|
230 |
+
try:
|
231 |
+
return streamer.__next__()
|
232 |
+
except StopIteration:
|
233 |
+
raise StopAsyncIteration()
|
234 |
+
|
235 |
+
return stream
|
236 |
+
|
237 |
+
@staticmethod
|
238 |
+
@torch.inference_mode()
|
239 |
+
def _get_scores(
|
240 |
+
model: "PreTrainedModelWrapper",
|
241 |
+
tokenizer: "PreTrainedTokenizer",
|
242 |
+
batch_input: List[str],
|
243 |
+
input_kwargs: Optional[Dict[str, Any]] = {},
|
244 |
+
) -> List[float]:
|
245 |
+
max_length = input_kwargs.pop("max_length", None)
|
246 |
+
device = getattr(model.pretrained_model, "device", "cuda")
|
247 |
+
inputs = tokenizer(
|
248 |
+
batch_input,
|
249 |
+
padding=True,
|
250 |
+
truncation=True,
|
251 |
+
max_length=max_length or getattr(model.config, "max_position_embeddings", 1024),
|
252 |
+
return_tensors="pt",
|
253 |
+
add_special_tokens=True,
|
254 |
+
).to(device)
|
255 |
+
|
256 |
+
input_ids: torch.Tensor = inputs["input_ids"]
|
257 |
+
_, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
|
258 |
+
|
259 |
+
if getattr(model.config, "model_type", None) == "chatglm":
|
260 |
+
values = torch.transpose(values, 0, 1)
|
261 |
+
|
262 |
+
scores = []
|
263 |
+
for i in range(input_ids.size(0)):
|
264 |
+
end_indexes = (input_ids[i] != tokenizer.pad_token_id).nonzero()
|
265 |
+
end_index = end_indexes[-1].item() if len(end_indexes) else 0
|
266 |
+
scores.append(values[i, end_index].nan_to_num().item())
|
267 |
+
|
268 |
+
return scores
|
269 |
+
|
270 |
+
async def chat(
|
271 |
+
self,
|
272 |
+
messages: Sequence[Dict[str, str]],
|
273 |
+
system: Optional[str] = None,
|
274 |
+
tools: Optional[str] = None,
|
275 |
+
image: Optional["NDArray"] = None,
|
276 |
+
**input_kwargs,
|
277 |
+
) -> List["Response"]:
|
278 |
+
if not self.can_generate:
|
279 |
+
raise ValueError("The current model does not support `chat`.")
|
280 |
+
|
281 |
+
loop = asyncio.get_running_loop()
|
282 |
+
input_args = (
|
283 |
+
self.model,
|
284 |
+
self.tokenizer,
|
285 |
+
self.processor,
|
286 |
+
self.template,
|
287 |
+
self.generating_args,
|
288 |
+
messages,
|
289 |
+
system,
|
290 |
+
tools,
|
291 |
+
image,
|
292 |
+
input_kwargs,
|
293 |
+
)
|
294 |
+
async with self.semaphore:
|
295 |
+
with concurrent.futures.ThreadPoolExecutor() as pool:
|
296 |
+
return await loop.run_in_executor(pool, self._chat, *input_args)
|
297 |
+
|
298 |
+
async def stream_chat(
|
299 |
+
self,
|
300 |
+
messages: Sequence[Dict[str, str]],
|
301 |
+
system: Optional[str] = None,
|
302 |
+
tools: Optional[str] = None,
|
303 |
+
image: Optional["NDArray"] = None,
|
304 |
+
**input_kwargs,
|
305 |
+
) -> AsyncGenerator[str, None]:
|
306 |
+
if not self.can_generate:
|
307 |
+
raise ValueError("The current model does not support `stream_chat`.")
|
308 |
+
|
309 |
+
loop = asyncio.get_running_loop()
|
310 |
+
input_args = (
|
311 |
+
self.model,
|
312 |
+
self.tokenizer,
|
313 |
+
self.processor,
|
314 |
+
self.template,
|
315 |
+
self.generating_args,
|
316 |
+
messages,
|
317 |
+
system,
|
318 |
+
tools,
|
319 |
+
image,
|
320 |
+
input_kwargs,
|
321 |
+
)
|
322 |
+
async with self.semaphore:
|
323 |
+
with concurrent.futures.ThreadPoolExecutor() as pool:
|
324 |
+
stream = self._stream_chat(*input_args)
|
325 |
+
while True:
|
326 |
+
try:
|
327 |
+
yield await loop.run_in_executor(pool, stream)
|
328 |
+
except StopAsyncIteration:
|
329 |
+
break
|
330 |
+
|
331 |
+
async def get_scores(
|
332 |
+
self,
|
333 |
+
batch_input: List[str],
|
334 |
+
**input_kwargs,
|
335 |
+
) -> List[float]:
|
336 |
+
if self.can_generate:
|
337 |
+
raise ValueError("Cannot get scores using an auto-regressive model.")
|
338 |
+
|
339 |
+
loop = asyncio.get_running_loop()
|
340 |
+
input_args = (self.model, self.tokenizer, batch_input, input_kwargs)
|
341 |
+
async with self.semaphore:
|
342 |
+
with concurrent.futures.ThreadPoolExecutor() as pool:
|
343 |
+
return await loop.run_in_executor(pool, self._get_scores, *input_args)
|
llama-factory/src/llamafactory/chat/vllm_engine.py
ADDED
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import uuid
|
16 |
+
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
|
17 |
+
|
18 |
+
from ..data import get_template_and_fix_tokenizer
|
19 |
+
from ..extras.logging import get_logger
|
20 |
+
from ..extras.misc import get_device_count
|
21 |
+
from ..extras.packages import is_vllm_available, is_vllm_version_greater_than_0_5, is_vllm_version_greater_than_0_5_1
|
22 |
+
from ..model import load_config, load_tokenizer
|
23 |
+
from ..model.model_utils.quantization import QuantizationMethod
|
24 |
+
from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
|
25 |
+
from .base_engine import BaseEngine, Response
|
26 |
+
|
27 |
+
|
28 |
+
if is_vllm_available():
|
29 |
+
from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
|
30 |
+
from vllm.lora.request import LoRARequest
|
31 |
+
|
32 |
+
if is_vllm_version_greater_than_0_5_1():
|
33 |
+
pass
|
34 |
+
elif is_vllm_version_greater_than_0_5():
|
35 |
+
from vllm.multimodal.image import ImagePixelData
|
36 |
+
else:
|
37 |
+
from vllm.sequence import MultiModalData
|
38 |
+
|
39 |
+
|
40 |
+
if TYPE_CHECKING:
|
41 |
+
from numpy.typing import NDArray
|
42 |
+
from transformers.image_processing_utils import BaseImageProcessor
|
43 |
+
|
44 |
+
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
45 |
+
|
46 |
+
|
47 |
+
logger = get_logger(__name__)
|
48 |
+
|
49 |
+
|
50 |
+
class VllmEngine(BaseEngine):
|
51 |
+
def __init__(
|
52 |
+
self,
|
53 |
+
model_args: "ModelArguments",
|
54 |
+
data_args: "DataArguments",
|
55 |
+
finetuning_args: "FinetuningArguments",
|
56 |
+
generating_args: "GeneratingArguments",
|
57 |
+
) -> None:
|
58 |
+
config = load_config(model_args) # may download model from ms hub
|
59 |
+
if getattr(config, "quantization_config", None): # gptq models should use float16
|
60 |
+
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
|
61 |
+
quant_method = quantization_config.get("quant_method", "")
|
62 |
+
if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto":
|
63 |
+
model_args.infer_dtype = "float16"
|
64 |
+
|
65 |
+
self.can_generate = finetuning_args.stage == "sft"
|
66 |
+
tokenizer_module = load_tokenizer(model_args)
|
67 |
+
self.tokenizer = tokenizer_module["tokenizer"]
|
68 |
+
self.processor = tokenizer_module["processor"]
|
69 |
+
self.tokenizer.padding_side = "left"
|
70 |
+
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template, data_args.tool_format)
|
71 |
+
self.generating_args = generating_args.to_dict()
|
72 |
+
|
73 |
+
engine_args = {
|
74 |
+
"model": model_args.model_name_or_path,
|
75 |
+
"trust_remote_code": True,
|
76 |
+
"download_dir": model_args.cache_dir,
|
77 |
+
"dtype": model_args.infer_dtype,
|
78 |
+
"max_model_len": model_args.vllm_maxlen,
|
79 |
+
"tensor_parallel_size": get_device_count() or 1,
|
80 |
+
"gpu_memory_utilization": model_args.vllm_gpu_util,
|
81 |
+
"disable_log_stats": True,
|
82 |
+
"disable_log_requests": True,
|
83 |
+
"enforce_eager": model_args.vllm_enforce_eager,
|
84 |
+
"enable_lora": model_args.adapter_name_or_path is not None,
|
85 |
+
"max_lora_rank": model_args.vllm_max_lora_rank,
|
86 |
+
}
|
87 |
+
|
88 |
+
if model_args.visual_inputs:
|
89 |
+
image_size = config.vision_config.image_size
|
90 |
+
patch_size = config.vision_config.patch_size
|
91 |
+
self.image_feature_size = (image_size // patch_size) ** 2
|
92 |
+
engine_args["image_input_type"] = "pixel_values"
|
93 |
+
engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids(self.template.image_token)
|
94 |
+
engine_args["image_input_shape"] = "1,3,{},{}".format(image_size, image_size)
|
95 |
+
engine_args["image_feature_size"] = self.image_feature_size
|
96 |
+
if getattr(config, "is_yi_vl_derived_model", None):
|
97 |
+
import vllm.model_executor.models.llava
|
98 |
+
|
99 |
+
logger.info("Detected Yi-VL model, applying projector patch.")
|
100 |
+
vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM
|
101 |
+
|
102 |
+
self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
|
103 |
+
if model_args.adapter_name_or_path is not None:
|
104 |
+
self.lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
|
105 |
+
else:
|
106 |
+
self.lora_request = None
|
107 |
+
|
108 |
+
async def _generate(
|
109 |
+
self,
|
110 |
+
messages: Sequence[Dict[str, str]],
|
111 |
+
system: Optional[str] = None,
|
112 |
+
tools: Optional[str] = None,
|
113 |
+
image: Optional["NDArray"] = None,
|
114 |
+
**input_kwargs,
|
115 |
+
) -> AsyncIterator["RequestOutput"]:
|
116 |
+
request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
|
117 |
+
|
118 |
+
if (
|
119 |
+
self.processor is not None
|
120 |
+
and image is not None
|
121 |
+
and not hasattr(self.processor, "image_seq_length")
|
122 |
+
and self.template.image_token not in messages[0]["content"]
|
123 |
+
): # llava-like models (TODO: paligemma models)
|
124 |
+
messages[0]["content"] = self.template.image_token * self.image_feature_size + messages[0]["content"]
|
125 |
+
|
126 |
+
paired_messages = messages + [{"role": "assistant", "content": ""}]
|
127 |
+
system = system or self.generating_args["default_system"]
|
128 |
+
prompt_ids, _ = self.template.encode_oneturn(
|
129 |
+
tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
|
130 |
+
)
|
131 |
+
|
132 |
+
if self.processor is not None and image is not None: # add image features
|
133 |
+
image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
|
134 |
+
pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]
|
135 |
+
if is_vllm_version_greater_than_0_5_1():
|
136 |
+
multi_modal_data = {"image": pixel_values}
|
137 |
+
elif is_vllm_version_greater_than_0_5():
|
138 |
+
multi_modal_data = ImagePixelData(image=pixel_values)
|
139 |
+
else: # TODO: remove vllm 0.4.3 support
|
140 |
+
multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
|
141 |
+
else:
|
142 |
+
multi_modal_data = None
|
143 |
+
|
144 |
+
prompt_length = len(prompt_ids)
|
145 |
+
|
146 |
+
use_beam_search: bool = self.generating_args["num_beams"] > 1
|
147 |
+
temperature: Optional[float] = input_kwargs.pop("temperature", None)
|
148 |
+
top_p: Optional[float] = input_kwargs.pop("top_p", None)
|
149 |
+
top_k: Optional[float] = input_kwargs.pop("top_k", None)
|
150 |
+
num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
|
151 |
+
repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
|
152 |
+
length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
|
153 |
+
max_length: Optional[int] = input_kwargs.pop("max_length", None)
|
154 |
+
max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
|
155 |
+
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
|
156 |
+
|
157 |
+
if "max_new_tokens" in self.generating_args:
|
158 |
+
max_tokens = self.generating_args["max_new_tokens"]
|
159 |
+
elif "max_length" in self.generating_args:
|
160 |
+
if self.generating_args["max_length"] > prompt_length:
|
161 |
+
max_tokens = self.generating_args["max_length"] - prompt_length
|
162 |
+
else:
|
163 |
+
max_tokens = 1
|
164 |
+
|
165 |
+
if max_length:
|
166 |
+
max_tokens = max_length - prompt_length if max_length > prompt_length else 1
|
167 |
+
|
168 |
+
if max_new_tokens:
|
169 |
+
max_tokens = max_new_tokens
|
170 |
+
|
171 |
+
sampling_params = SamplingParams(
|
172 |
+
n=num_return_sequences,
|
173 |
+
repetition_penalty=(
|
174 |
+
repetition_penalty if repetition_penalty is not None else self.generating_args["repetition_penalty"]
|
175 |
+
)
|
176 |
+
or 1.0, # repetition_penalty must > 0
|
177 |
+
temperature=temperature if temperature is not None else self.generating_args["temperature"],
|
178 |
+
top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0, # top_p must > 0
|
179 |
+
top_k=top_k if top_k is not None else self.generating_args["top_k"],
|
180 |
+
use_beam_search=use_beam_search,
|
181 |
+
length_penalty=length_penalty if length_penalty is not None else self.generating_args["length_penalty"],
|
182 |
+
stop=stop,
|
183 |
+
stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
|
184 |
+
max_tokens=max_tokens,
|
185 |
+
skip_special_tokens=True,
|
186 |
+
)
|
187 |
+
|
188 |
+
result_generator = self.model.generate(
|
189 |
+
inputs={"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
|
190 |
+
sampling_params=sampling_params,
|
191 |
+
request_id=request_id,
|
192 |
+
lora_request=self.lora_request,
|
193 |
+
)
|
194 |
+
return result_generator
|
195 |
+
|
196 |
+
async def chat(
|
197 |
+
self,
|
198 |
+
messages: Sequence[Dict[str, str]],
|
199 |
+
system: Optional[str] = None,
|
200 |
+
tools: Optional[str] = None,
|
201 |
+
image: Optional["NDArray"] = None,
|
202 |
+
**input_kwargs,
|
203 |
+
) -> List["Response"]:
|
204 |
+
final_output = None
|
205 |
+
generator = await self._generate(messages, system, tools, image, **input_kwargs)
|
206 |
+
async for request_output in generator:
|
207 |
+
final_output = request_output
|
208 |
+
|
209 |
+
results = []
|
210 |
+
for output in final_output.outputs:
|
211 |
+
results.append(
|
212 |
+
Response(
|
213 |
+
response_text=output.text,
|
214 |
+
response_length=len(output.token_ids),
|
215 |
+
prompt_length=len(final_output.prompt_token_ids),
|
216 |
+
finish_reason=output.finish_reason,
|
217 |
+
)
|
218 |
+
)
|
219 |
+
|
220 |
+
return results
|
221 |
+
|
222 |
+
async def stream_chat(
|
223 |
+
self,
|
224 |
+
messages: Sequence[Dict[str, str]],
|
225 |
+
system: Optional[str] = None,
|
226 |
+
tools: Optional[str] = None,
|
227 |
+
image: Optional["NDArray"] = None,
|
228 |
+
**input_kwargs,
|
229 |
+
) -> AsyncGenerator[str, None]:
|
230 |
+
generated_text = ""
|
231 |
+
generator = await self._generate(messages, system, tools, image, **input_kwargs)
|
232 |
+
async for result in generator:
|
233 |
+
delta_text = result.outputs[0].text[len(generated_text) :]
|
234 |
+
generated_text = result.outputs[0].text
|
235 |
+
yield delta_text
|
236 |
+
|
237 |
+
async def get_scores(
|
238 |
+
self,
|
239 |
+
batch_input: List[str],
|
240 |
+
**input_kwargs,
|
241 |
+
) -> List[float]:
|
242 |
+
raise NotImplementedError("vLLM engine does not support get_scores.")
|
llama-factory/src/llamafactory/cli.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import os
|
16 |
+
import random
|
17 |
+
import subprocess
|
18 |
+
import sys
|
19 |
+
from enum import Enum, unique
|
20 |
+
|
21 |
+
from . import launcher
|
22 |
+
from .api.app import run_api
|
23 |
+
from .chat.chat_model import run_chat
|
24 |
+
from .eval.evaluator import run_eval
|
25 |
+
from .extras.env import VERSION, print_env
|
26 |
+
from .extras.logging import get_logger
|
27 |
+
from .extras.misc import get_device_count
|
28 |
+
from .train.tuner import export_model, run_exp
|
29 |
+
from .webui.interface import run_web_demo, run_web_ui
|
30 |
+
|
31 |
+
|
32 |
+
USAGE = (
|
33 |
+
"-" * 70
|
34 |
+
+ "\n"
|
35 |
+
+ "| Usage: |\n"
|
36 |
+
+ "| llamafactory-cli api -h: launch an OpenAI-style API server |\n"
|
37 |
+
+ "| llamafactory-cli chat -h: launch a chat interface in CLI |\n"
|
38 |
+
+ "| llamafactory-cli eval -h: evaluate models |\n"
|
39 |
+
+ "| llamafactory-cli export -h: merge LoRA adapters and export model |\n"
|
40 |
+
+ "| llamafactory-cli train -h: train models |\n"
|
41 |
+
+ "| llamafactory-cli webchat -h: launch a chat interface in Web UI |\n"
|
42 |
+
+ "| llamafactory-cli webui: launch LlamaBoard |\n"
|
43 |
+
+ "| llamafactory-cli version: show version info |\n"
|
44 |
+
+ "-" * 70
|
45 |
+
)
|
46 |
+
|
47 |
+
WELCOME = (
|
48 |
+
"-" * 58
|
49 |
+
+ "\n"
|
50 |
+
+ "| Welcome to LLaMA Factory, version {}".format(VERSION)
|
51 |
+
+ " " * (21 - len(VERSION))
|
52 |
+
+ "|\n|"
|
53 |
+
+ " " * 56
|
54 |
+
+ "|\n"
|
55 |
+
+ "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n"
|
56 |
+
+ "-" * 58
|
57 |
+
)
|
58 |
+
|
59 |
+
logger = get_logger(__name__)
|
60 |
+
|
61 |
+
|
62 |
+
@unique
|
63 |
+
class Command(str, Enum):
|
64 |
+
API = "api"
|
65 |
+
CHAT = "chat"
|
66 |
+
ENV = "env"
|
67 |
+
EVAL = "eval"
|
68 |
+
EXPORT = "export"
|
69 |
+
TRAIN = "train"
|
70 |
+
WEBDEMO = "webchat"
|
71 |
+
WEBUI = "webui"
|
72 |
+
VER = "version"
|
73 |
+
HELP = "help"
|
74 |
+
|
75 |
+
|
76 |
+
def main():
|
77 |
+
command = sys.argv.pop(1) if len(sys.argv) != 1 else Command.HELP
|
78 |
+
if command == Command.API:
|
79 |
+
run_api()
|
80 |
+
elif command == Command.CHAT:
|
81 |
+
run_chat()
|
82 |
+
elif command == Command.ENV:
|
83 |
+
print_env()
|
84 |
+
elif command == Command.EVAL:
|
85 |
+
run_eval()
|
86 |
+
elif command == Command.EXPORT:
|
87 |
+
export_model()
|
88 |
+
elif command == Command.TRAIN:
|
89 |
+
force_torchrun = os.environ.get("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
|
90 |
+
if force_torchrun or get_device_count() > 1:
|
91 |
+
master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
|
92 |
+
master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
|
93 |
+
logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port))
|
94 |
+
process = subprocess.run(
|
95 |
+
(
|
96 |
+
"torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
|
97 |
+
"--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
|
98 |
+
).format(
|
99 |
+
nnodes=os.environ.get("NNODES", "1"),
|
100 |
+
node_rank=os.environ.get("RANK", "0"),
|
101 |
+
nproc_per_node=os.environ.get("NPROC_PER_NODE", str(get_device_count())),
|
102 |
+
master_addr=master_addr,
|
103 |
+
master_port=master_port,
|
104 |
+
file_name=launcher.__file__,
|
105 |
+
args=" ".join(sys.argv[1:]),
|
106 |
+
),
|
107 |
+
shell=True,
|
108 |
+
)
|
109 |
+
sys.exit(process.returncode)
|
110 |
+
else:
|
111 |
+
run_exp()
|
112 |
+
elif command == Command.WEBDEMO:
|
113 |
+
run_web_demo()
|
114 |
+
elif command == Command.WEBUI:
|
115 |
+
run_web_ui()
|
116 |
+
elif command == Command.VER:
|
117 |
+
print(WELCOME)
|
118 |
+
elif command == Command.HELP:
|
119 |
+
print(USAGE)
|
120 |
+
else:
|
121 |
+
raise NotImplementedError("Unknown command: {}".format(command))
|
llama-factory/src/llamafactory/data/__init__.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding, SFTDataCollatorWith4DAttentionMask
|
16 |
+
from .data_utils import Role, split_dataset
|
17 |
+
from .loader import get_dataset
|
18 |
+
from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
|
19 |
+
|
20 |
+
|
21 |
+
__all__ = [
|
22 |
+
"KTODataCollatorWithPadding",
|
23 |
+
"PairwiseDataCollatorWithPadding",
|
24 |
+
"SFTDataCollatorWith4DAttentionMask",
|
25 |
+
"Role",
|
26 |
+
"split_dataset",
|
27 |
+
"get_dataset",
|
28 |
+
"TEMPLATES",
|
29 |
+
"Template",
|
30 |
+
"get_template_and_fix_tokenizer",
|
31 |
+
]
|
llama-factory/src/llamafactory/data/aligner.py
ADDED
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import os
|
16 |
+
from functools import partial
|
17 |
+
from typing import TYPE_CHECKING, Any, Dict, List, Union
|
18 |
+
|
19 |
+
from datasets import Features
|
20 |
+
|
21 |
+
from ..extras.logging import get_logger
|
22 |
+
from .data_utils import Role
|
23 |
+
|
24 |
+
|
25 |
+
if TYPE_CHECKING:
|
26 |
+
from datasets import Dataset, IterableDataset
|
27 |
+
from transformers import Seq2SeqTrainingArguments
|
28 |
+
|
29 |
+
from ..hparams import DataArguments
|
30 |
+
from .parser import DatasetAttr
|
31 |
+
|
32 |
+
|
33 |
+
logger = get_logger(__name__)
|
34 |
+
|
35 |
+
|
36 |
+
def _convert_images(images: List[Any], dataset_attr: "DatasetAttr", data_args: "DataArguments") -> List[Any]:
|
37 |
+
r"""
|
38 |
+
Optionally concatenates image path to dataset dir when loading from local disk.
|
39 |
+
"""
|
40 |
+
outputs = []
|
41 |
+
if dataset_attr.load_from in ["script", "file"]:
|
42 |
+
for image in images:
|
43 |
+
if isinstance(image, str) and os.path.isfile(os.path.join(data_args.dataset_dir, image)):
|
44 |
+
outputs.append(os.path.join(data_args.dataset_dir, image))
|
45 |
+
else:
|
46 |
+
outputs.append(image)
|
47 |
+
|
48 |
+
return outputs
|
49 |
+
|
50 |
+
|
51 |
+
def convert_alpaca(
|
52 |
+
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
|
53 |
+
) -> Dict[str, List[Any]]:
|
54 |
+
r"""
|
55 |
+
Converts alpaca format dataset to the standard format.
|
56 |
+
"""
|
57 |
+
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
|
58 |
+
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
|
59 |
+
for i in range(len(examples[dataset_attr.prompt])):
|
60 |
+
prompt = []
|
61 |
+
if dataset_attr.history and isinstance(examples[dataset_attr.history][i], list):
|
62 |
+
for old_prompt, old_response in examples[dataset_attr.history][i]:
|
63 |
+
prompt.append({"role": Role.USER.value, "content": old_prompt})
|
64 |
+
prompt.append({"role": Role.ASSISTANT.value, "content": old_response})
|
65 |
+
|
66 |
+
content = []
|
67 |
+
if dataset_attr.prompt and examples[dataset_attr.prompt][i]:
|
68 |
+
content.append(examples[dataset_attr.prompt][i])
|
69 |
+
|
70 |
+
if dataset_attr.query and examples[dataset_attr.query][i]:
|
71 |
+
content.append(examples[dataset_attr.query][i])
|
72 |
+
|
73 |
+
prompt.append({"role": Role.USER.value, "content": "\n".join(content)}) # "prompt\nquery"
|
74 |
+
|
75 |
+
if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag][i], bool): # kto example
|
76 |
+
response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
|
77 |
+
if examples[dataset_attr.kto_tag][i]:
|
78 |
+
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
|
79 |
+
else:
|
80 |
+
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
|
81 |
+
elif (
|
82 |
+
dataset_attr.ranking
|
83 |
+
and isinstance(examples[dataset_attr.chosen][i], str)
|
84 |
+
and isinstance(examples[dataset_attr.rejected][i], str)
|
85 |
+
): # pairwise example
|
86 |
+
response = [
|
87 |
+
{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.chosen][i]},
|
88 |
+
{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.rejected][i]},
|
89 |
+
]
|
90 |
+
elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str): # normal example
|
91 |
+
response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
|
92 |
+
else: # unsupervised
|
93 |
+
response = []
|
94 |
+
|
95 |
+
outputs["prompt"].append(prompt)
|
96 |
+
outputs["response"].append(response)
|
97 |
+
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
|
98 |
+
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
|
99 |
+
outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
|
100 |
+
|
101 |
+
return outputs
|
102 |
+
|
103 |
+
|
104 |
+
def convert_sharegpt(
|
105 |
+
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
|
106 |
+
) -> Dict[str, List[Any]]:
|
107 |
+
r"""
|
108 |
+
Converts sharegpt format dataset to the standard format.
|
109 |
+
"""
|
110 |
+
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
|
111 |
+
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
|
112 |
+
tag_mapping = {
|
113 |
+
dataset_attr.user_tag: Role.USER.value,
|
114 |
+
dataset_attr.assistant_tag: Role.ASSISTANT.value,
|
115 |
+
dataset_attr.observation_tag: Role.OBSERVATION.value,
|
116 |
+
dataset_attr.function_tag: Role.FUNCTION.value,
|
117 |
+
dataset_attr.system_tag: Role.SYSTEM.value,
|
118 |
+
}
|
119 |
+
odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)
|
120 |
+
even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
|
121 |
+
accept_tags = (odd_tags, even_tags)
|
122 |
+
for i, messages in enumerate(examples[dataset_attr.messages]):
|
123 |
+
if dataset_attr.system_tag and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag:
|
124 |
+
system = messages[0][dataset_attr.content_tag]
|
125 |
+
messages = messages[1:]
|
126 |
+
else:
|
127 |
+
system = examples[dataset_attr.system][i] if dataset_attr.system else ""
|
128 |
+
|
129 |
+
if len(messages) == 0:
|
130 |
+
continue
|
131 |
+
|
132 |
+
aligned_messages = []
|
133 |
+
broken_data = False
|
134 |
+
for turn_idx, message in enumerate(messages):
|
135 |
+
if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
|
136 |
+
logger.warning("Invalid role tag in {}.".format(messages))
|
137 |
+
broken_data = True
|
138 |
+
|
139 |
+
aligned_messages.append(
|
140 |
+
{"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
|
141 |
+
)
|
142 |
+
|
143 |
+
if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
|
144 |
+
dataset_attr.ranking and len(aligned_messages) % 2 == 0
|
145 |
+
):
|
146 |
+
logger.warning("Invalid message count in {}.".format(messages))
|
147 |
+
broken_data = True
|
148 |
+
|
149 |
+
if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag][i], bool): # kto example
|
150 |
+
prompt = aligned_messages[:-1]
|
151 |
+
response = aligned_messages[-1:]
|
152 |
+
if examples[dataset_attr.kto_tag][i]:
|
153 |
+
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
|
154 |
+
else:
|
155 |
+
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
|
156 |
+
elif (
|
157 |
+
dataset_attr.ranking
|
158 |
+
and isinstance(examples[dataset_attr.chosen][i], dict)
|
159 |
+
and isinstance(examples[dataset_attr.rejected][i], dict)
|
160 |
+
): # pairwise example
|
161 |
+
chosen = examples[dataset_attr.chosen][i]
|
162 |
+
rejected = examples[dataset_attr.rejected][i]
|
163 |
+
if (
|
164 |
+
chosen[dataset_attr.role_tag] not in accept_tags[-1]
|
165 |
+
or rejected[dataset_attr.role_tag] not in accept_tags[-1]
|
166 |
+
):
|
167 |
+
logger.warning("Invalid role tag in {}.".format([chosen, rejected]))
|
168 |
+
broken_data = True
|
169 |
+
|
170 |
+
prompt = aligned_messages
|
171 |
+
response = [
|
172 |
+
{"role": tag_mapping[chosen[dataset_attr.role_tag]], "content": chosen[dataset_attr.content_tag]},
|
173 |
+
{"role": tag_mapping[rejected[dataset_attr.role_tag]], "content": rejected[dataset_attr.content_tag]},
|
174 |
+
]
|
175 |
+
else: # normal example
|
176 |
+
prompt = aligned_messages[:-1]
|
177 |
+
response = aligned_messages[-1:]
|
178 |
+
|
179 |
+
if broken_data:
|
180 |
+
logger.warning("Skipping this abnormal example.")
|
181 |
+
continue
|
182 |
+
|
183 |
+
outputs["prompt"].append(prompt)
|
184 |
+
outputs["response"].append(response)
|
185 |
+
outputs["system"].append(system)
|
186 |
+
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
|
187 |
+
outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
|
188 |
+
|
189 |
+
return outputs
|
190 |
+
|
191 |
+
|
192 |
+
def align_dataset(
|
193 |
+
dataset: Union["Dataset", "IterableDataset"],
|
194 |
+
dataset_attr: "DatasetAttr",
|
195 |
+
data_args: "DataArguments",
|
196 |
+
training_args: "Seq2SeqTrainingArguments",
|
197 |
+
) -> Union["Dataset", "IterableDataset"]:
|
198 |
+
r"""
|
199 |
+
Aligned dataset:
|
200 |
+
prompt: [{"role": "user", "content": "..."}] * (2T - 1)
|
201 |
+
response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
|
202 |
+
system: "..."
|
203 |
+
tools: "...",
|
204 |
+
images: [],
|
205 |
+
"""
|
206 |
+
if dataset_attr.formatting == "alpaca":
|
207 |
+
convert_func = partial(convert_alpaca, dataset_attr=dataset_attr, data_args=data_args)
|
208 |
+
else:
|
209 |
+
convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr, data_args=data_args)
|
210 |
+
|
211 |
+
column_names = list(next(iter(dataset)).keys())
|
212 |
+
features = Features.from_dict(
|
213 |
+
{
|
214 |
+
"prompt": [
|
215 |
+
{"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
|
216 |
+
],
|
217 |
+
"response": [
|
218 |
+
{"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
|
219 |
+
],
|
220 |
+
"system": {"dtype": "string", "_type": "Value"},
|
221 |
+
"tools": {"dtype": "string", "_type": "Value"},
|
222 |
+
"images": [{"_type": "Image"}],
|
223 |
+
}
|
224 |
+
)
|
225 |
+
kwargs = {}
|
226 |
+
if not data_args.streaming:
|
227 |
+
kwargs = dict(
|
228 |
+
num_proc=data_args.preprocessing_num_workers,
|
229 |
+
load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
|
230 |
+
desc="Converting format of dataset",
|
231 |
+
)
|
232 |
+
|
233 |
+
return dataset.map(
|
234 |
+
convert_func,
|
235 |
+
batched=True,
|
236 |
+
remove_columns=column_names,
|
237 |
+
features=features,
|
238 |
+
**kwargs,
|
239 |
+
)
|
llama-factory/src/llamafactory/data/collator.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 OpenAccess AI Collective and the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# This code is inspired by the OpenAccess AI Collective's axolotl library.
|
4 |
+
# https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/monkeypatch/utils.py
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
|
18 |
+
from dataclasses import dataclass
|
19 |
+
from typing import Any, Dict, Literal, Sequence
|
20 |
+
|
21 |
+
import torch
|
22 |
+
from transformers import DataCollatorForSeq2Seq
|
23 |
+
|
24 |
+
|
25 |
+
def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
|
26 |
+
r"""
|
27 |
+
Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
|
28 |
+
while handles packed sequences and transforms the mask to lower triangular form to prevent future peeking.
|
29 |
+
|
30 |
+
e.g.
|
31 |
+
```python
|
32 |
+
# input
|
33 |
+
[[1, 1, 2, 2, 2, 0]]
|
34 |
+
# output
|
35 |
+
[
|
36 |
+
[
|
37 |
+
[
|
38 |
+
[o, x, x, x, x, x],
|
39 |
+
[o, o, x, x, x, x],
|
40 |
+
[x, x, o, x, x, x],
|
41 |
+
[x, x, o, o, x, x],
|
42 |
+
[x, x, o, o, o, x],
|
43 |
+
[x, x, x, x, x, x],
|
44 |
+
]
|
45 |
+
]
|
46 |
+
]
|
47 |
+
```
|
48 |
+
where `o` equals to `0.0`, `x` equals to `min_dtype`.
|
49 |
+
"""
|
50 |
+
bsz, seq_len = attention_mask_with_indices.size()
|
51 |
+
min_dtype = torch.finfo(dtype).min
|
52 |
+
expanded_mask = attention_mask_with_indices[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
|
53 |
+
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
|
54 |
+
padding_mask = torch.where(expanded_mask != 0, 1, 0)
|
55 |
+
# Create a block-diagonal mask.
|
56 |
+
attention_mask_4d = torch.eq(expanded_mask, expanded_mask.transpose(-1, -2)).int() * padding_mask
|
57 |
+
# Use the lower triangular mask to zero out the upper triangular part
|
58 |
+
attention_mask_4d *= torch.tril(torch.ones((seq_len, seq_len), dtype=torch.long))
|
59 |
+
# Invert the attention mask.
|
60 |
+
attention_mask_4d = torch.where(attention_mask_4d != 0, torch.tensor(0, dtype=dtype), min_dtype)
|
61 |
+
return attention_mask_4d
|
62 |
+
|
63 |
+
|
64 |
+
@dataclass
|
65 |
+
class SFTDataCollatorWith4DAttentionMask(DataCollatorForSeq2Seq):
|
66 |
+
r"""
|
67 |
+
Data collator for 4d attention mask.
|
68 |
+
"""
|
69 |
+
|
70 |
+
block_diag_attn: bool = False
|
71 |
+
attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
|
72 |
+
compute_dtype: "torch.dtype" = torch.float32
|
73 |
+
|
74 |
+
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
|
75 |
+
features = super().__call__(features)
|
76 |
+
if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
|
77 |
+
features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
|
78 |
+
|
79 |
+
return features
|
80 |
+
|
81 |
+
|
82 |
+
@dataclass
|
83 |
+
class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
|
84 |
+
r"""
|
85 |
+
Data collator for pairwise data.
|
86 |
+
"""
|
87 |
+
|
88 |
+
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
|
89 |
+
r"""
|
90 |
+
Pads batched data to the longest sequence in the batch.
|
91 |
+
|
92 |
+
We generate 2 * n examples where the first n examples represent chosen examples and
|
93 |
+
the last n examples represent rejected examples.
|
94 |
+
"""
|
95 |
+
concatenated_features = []
|
96 |
+
for key in ("chosen", "rejected"):
|
97 |
+
for feature in features:
|
98 |
+
target_feature = {
|
99 |
+
"input_ids": feature["{}_input_ids".format(key)],
|
100 |
+
"attention_mask": feature["{}_attention_mask".format(key)],
|
101 |
+
"labels": feature["{}_labels".format(key)],
|
102 |
+
}
|
103 |
+
if "pixel_values" in feature:
|
104 |
+
target_feature["pixel_values"] = feature["pixel_values"]
|
105 |
+
|
106 |
+
if "{}_token_type_ids".format(key) in feature:
|
107 |
+
target_feature["token_type_ids"] = feature["{}_token_type_ids".format(key)]
|
108 |
+
|
109 |
+
concatenated_features.append(target_feature)
|
110 |
+
|
111 |
+
return super().__call__(concatenated_features)
|
112 |
+
|
113 |
+
|
114 |
+
@dataclass
|
115 |
+
class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
|
116 |
+
r"""
|
117 |
+
Data collator for KTO data.
|
118 |
+
"""
|
119 |
+
|
120 |
+
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
|
121 |
+
target_features = []
|
122 |
+
kl_features = []
|
123 |
+
kto_tags = []
|
124 |
+
for feature in features:
|
125 |
+
target_feature = {
|
126 |
+
"input_ids": feature["input_ids"],
|
127 |
+
"attention_mask": feature["attention_mask"],
|
128 |
+
"labels": feature["labels"],
|
129 |
+
}
|
130 |
+
kl_feature = {
|
131 |
+
"input_ids": feature["kl_input_ids"],
|
132 |
+
"attention_mask": feature["kl_attention_mask"],
|
133 |
+
"labels": feature["kl_labels"],
|
134 |
+
}
|
135 |
+
if "pixel_values" in feature:
|
136 |
+
target_feature["pixel_values"] = feature["pixel_values"]
|
137 |
+
|
138 |
+
if "token_type_ids" in feature:
|
139 |
+
target_feature["token_type_ids"] = feature["token_type_ids"]
|
140 |
+
kl_feature["token_type_ids"] = feature["kl_token_type_ids"]
|
141 |
+
|
142 |
+
target_features.append(target_feature)
|
143 |
+
kl_features.append(kl_feature)
|
144 |
+
kto_tags.append(feature["kto_tags"])
|
145 |
+
|
146 |
+
batch = super().__call__(target_features)
|
147 |
+
kl_batch = super().__call__(kl_features)
|
148 |
+
batch["kl_input_ids"] = kl_batch["input_ids"]
|
149 |
+
batch["kl_attention_mask"] = kl_batch["attention_mask"]
|
150 |
+
batch["kl_labels"] = kl_batch["labels"]
|
151 |
+
if "token_type_ids" in batch:
|
152 |
+
batch["kl_token_type_ids"] = kl_batch["token_type_ids"]
|
153 |
+
|
154 |
+
batch["kto_tags"] = torch.tensor(kto_tags)
|
155 |
+
return batch
|
llama-factory/src/llamafactory/data/data_utils.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from enum import Enum, unique
|
16 |
+
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, TypedDict, Union
|
17 |
+
|
18 |
+
from datasets import DatasetDict, concatenate_datasets, interleave_datasets
|
19 |
+
|
20 |
+
from ..extras.logging import get_logger
|
21 |
+
|
22 |
+
|
23 |
+
if TYPE_CHECKING:
|
24 |
+
from datasets import Dataset, IterableDataset
|
25 |
+
|
26 |
+
from ..hparams import DataArguments
|
27 |
+
|
28 |
+
|
29 |
+
logger = get_logger(__name__)
|
30 |
+
|
31 |
+
|
32 |
+
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
|
33 |
+
|
34 |
+
|
35 |
+
@unique
|
36 |
+
class Role(str, Enum):
|
37 |
+
USER = "user"
|
38 |
+
ASSISTANT = "assistant"
|
39 |
+
SYSTEM = "system"
|
40 |
+
FUNCTION = "function"
|
41 |
+
OBSERVATION = "observation"
|
42 |
+
|
43 |
+
|
44 |
+
class DatasetModule(TypedDict):
|
45 |
+
train_dataset: Optional[Union["Dataset", "IterableDataset"]]
|
46 |
+
eval_dataset: Optional[Union["Dataset", "IterableDataset"]]
|
47 |
+
|
48 |
+
|
49 |
+
def merge_dataset(
|
50 |
+
all_datasets: List[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int
|
51 |
+
) -> Union["Dataset", "IterableDataset"]:
|
52 |
+
if len(all_datasets) == 1:
|
53 |
+
return all_datasets[0]
|
54 |
+
elif data_args.mix_strategy == "concat":
|
55 |
+
if data_args.streaming:
|
56 |
+
logger.warning("The samples between different datasets will not be mixed in streaming mode.")
|
57 |
+
|
58 |
+
return concatenate_datasets(all_datasets)
|
59 |
+
elif data_args.mix_strategy.startswith("interleave"):
|
60 |
+
if not data_args.streaming:
|
61 |
+
logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
|
62 |
+
|
63 |
+
return interleave_datasets(
|
64 |
+
datasets=all_datasets,
|
65 |
+
probabilities=data_args.interleave_probs,
|
66 |
+
seed=seed,
|
67 |
+
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
|
68 |
+
)
|
69 |
+
else:
|
70 |
+
raise ValueError("Unknown mixing strategy.")
|
71 |
+
|
72 |
+
|
73 |
+
def split_dataset(
|
74 |
+
dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", seed: int
|
75 |
+
) -> "DatasetDict":
|
76 |
+
r"""
|
77 |
+
Splits the dataset and returns a dataset dict containing train set (required) and validation set (optional).
|
78 |
+
"""
|
79 |
+
if data_args.streaming:
|
80 |
+
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)
|
81 |
+
val_set = dataset.take(int(data_args.val_size))
|
82 |
+
train_set = dataset.skip(int(data_args.val_size))
|
83 |
+
return DatasetDict({"train": train_set, "validation": val_set})
|
84 |
+
else:
|
85 |
+
val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
|
86 |
+
dataset = dataset.train_test_split(test_size=val_size, seed=seed)
|
87 |
+
return DatasetDict({"train": dataset["train"], "validation": dataset["test"]})
|
llama-factory/src/llamafactory/data/formatter.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import json
|
16 |
+
import re
|
17 |
+
from abc import ABC, abstractmethod
|
18 |
+
from dataclasses import dataclass, field
|
19 |
+
from typing import List, Literal, Optional, Tuple, Union
|
20 |
+
|
21 |
+
from .data_utils import SLOTS
|
22 |
+
from .tool_utils import DefaultToolUtils, GLM4ToolUtils
|
23 |
+
|
24 |
+
|
25 |
+
@dataclass
|
26 |
+
class Formatter(ABC):
|
27 |
+
slots: SLOTS = field(default_factory=list)
|
28 |
+
tool_format: Optional[Literal["default", "glm4"]] = None
|
29 |
+
|
30 |
+
@abstractmethod
|
31 |
+
def apply(self, **kwargs) -> SLOTS: ...
|
32 |
+
|
33 |
+
def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
|
34 |
+
raise NotImplementedError
|
35 |
+
|
36 |
+
|
37 |
+
@dataclass
|
38 |
+
class EmptyFormatter(Formatter):
|
39 |
+
def __post_init__(self):
|
40 |
+
has_placeholder = False
|
41 |
+
for slot in filter(lambda s: isinstance(s, str), self.slots):
|
42 |
+
if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot):
|
43 |
+
has_placeholder = True
|
44 |
+
|
45 |
+
if has_placeholder:
|
46 |
+
raise ValueError("Empty formatter should not contain any placeholder.")
|
47 |
+
|
48 |
+
def apply(self, **kwargs) -> SLOTS:
|
49 |
+
return self.slots
|
50 |
+
|
51 |
+
|
52 |
+
@dataclass
|
53 |
+
class StringFormatter(Formatter):
|
54 |
+
def __post_init__(self):
|
55 |
+
has_placeholder = False
|
56 |
+
for slot in filter(lambda s: isinstance(s, str), self.slots):
|
57 |
+
if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot):
|
58 |
+
has_placeholder = True
|
59 |
+
|
60 |
+
if not has_placeholder:
|
61 |
+
raise ValueError("A placeholder is required in the string formatter.")
|
62 |
+
|
63 |
+
def apply(self, **kwargs) -> SLOTS:
|
64 |
+
elements = []
|
65 |
+
for slot in self.slots:
|
66 |
+
if isinstance(slot, str):
|
67 |
+
for name, value in kwargs.items():
|
68 |
+
if not isinstance(value, str):
|
69 |
+
raise RuntimeError("Expected a string, got {}".format(value))
|
70 |
+
|
71 |
+
slot = slot.replace("{{" + name + "}}", value, 1)
|
72 |
+
elements.append(slot)
|
73 |
+
elif isinstance(slot, (dict, set)):
|
74 |
+
elements.append(slot)
|
75 |
+
else:
|
76 |
+
raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
|
77 |
+
|
78 |
+
return elements
|
79 |
+
|
80 |
+
|
81 |
+
@dataclass
|
82 |
+
class FunctionFormatter(Formatter):
|
83 |
+
def __post_init__(self):
|
84 |
+
if self.tool_format == "default":
|
85 |
+
self.slots = DefaultToolUtils.get_function_slots() + self.slots
|
86 |
+
elif self.tool_format == "glm4":
|
87 |
+
self.slots = GLM4ToolUtils.get_function_slots() + self.slots
|
88 |
+
else:
|
89 |
+
raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))
|
90 |
+
|
91 |
+
def apply(self, **kwargs) -> SLOTS:
|
92 |
+
content = kwargs.pop("content")
|
93 |
+
functions: List[Tuple[str, str]] = []
|
94 |
+
try:
|
95 |
+
tool_calls = json.loads(content)
|
96 |
+
if not isinstance(tool_calls, list): # parallel function call
|
97 |
+
tool_calls = [tool_calls]
|
98 |
+
|
99 |
+
for tool_call in tool_calls:
|
100 |
+
functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
|
101 |
+
|
102 |
+
except json.JSONDecodeError:
|
103 |
+
functions = []
|
104 |
+
|
105 |
+
elements = []
|
106 |
+
for name, arguments in functions:
|
107 |
+
for slot in self.slots:
|
108 |
+
if isinstance(slot, str):
|
109 |
+
slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
|
110 |
+
elements.append(slot)
|
111 |
+
elif isinstance(slot, (dict, set)):
|
112 |
+
elements.append(slot)
|
113 |
+
else:
|
114 |
+
raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
|
115 |
+
|
116 |
+
return elements
|
117 |
+
|
118 |
+
|
119 |
+
@dataclass
|
120 |
+
class ToolFormatter(Formatter):
|
121 |
+
def __post_init__(self):
|
122 |
+
if self.tool_format == "default":
|
123 |
+
self._tool_formatter = DefaultToolUtils.tool_formatter
|
124 |
+
self._tool_extractor = DefaultToolUtils.tool_extractor
|
125 |
+
elif self.tool_format == "glm4":
|
126 |
+
self._tool_formatter = GLM4ToolUtils.tool_formatter
|
127 |
+
self._tool_extractor = GLM4ToolUtils.tool_extractor
|
128 |
+
else:
|
129 |
+
raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))
|
130 |
+
|
131 |
+
def apply(self, **kwargs) -> SLOTS:
|
132 |
+
content = kwargs.pop("content")
|
133 |
+
try:
|
134 |
+
tools = json.loads(content)
|
135 |
+
return [self._tool_formatter(tools) if len(tools) != 0 else ""]
|
136 |
+
except json.JSONDecodeError:
|
137 |
+
return [""]
|
138 |
+
|
139 |
+
def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
|
140 |
+
return self._tool_extractor(content)
|
llama-factory/src/llamafactory/data/loader.py
ADDED
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import os
|
16 |
+
import sys
|
17 |
+
from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union
|
18 |
+
|
19 |
+
import numpy as np
|
20 |
+
from datasets import DatasetDict, load_dataset, load_from_disk
|
21 |
+
from transformers.utils.versions import require_version
|
22 |
+
|
23 |
+
from ..extras.constants import FILEEXT2TYPE
|
24 |
+
from ..extras.logging import get_logger
|
25 |
+
from ..extras.misc import has_tokenized_data
|
26 |
+
from .aligner import align_dataset
|
27 |
+
from .data_utils import merge_dataset, split_dataset
|
28 |
+
from .parser import get_dataset_list
|
29 |
+
from .preprocess import get_preprocess_and_print_func
|
30 |
+
from .template import get_template_and_fix_tokenizer
|
31 |
+
|
32 |
+
|
33 |
+
if TYPE_CHECKING:
|
34 |
+
from datasets import Dataset, IterableDataset
|
35 |
+
from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments
|
36 |
+
|
37 |
+
from ..hparams import DataArguments, ModelArguments
|
38 |
+
from .data_utils import DatasetModule
|
39 |
+
from .parser import DatasetAttr
|
40 |
+
from .template import Template
|
41 |
+
|
42 |
+
|
43 |
+
logger = get_logger(__name__)
|
44 |
+
|
45 |
+
|
46 |
+
def _load_single_dataset(
|
47 |
+
dataset_attr: "DatasetAttr",
|
48 |
+
model_args: "ModelArguments",
|
49 |
+
data_args: "DataArguments",
|
50 |
+
training_args: "Seq2SeqTrainingArguments",
|
51 |
+
) -> Union["Dataset", "IterableDataset"]:
|
52 |
+
logger.info("Loading dataset {}...".format(dataset_attr))
|
53 |
+
data_path, data_name, data_dir, data_files = None, None, None, None
|
54 |
+
if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
|
55 |
+
data_path = dataset_attr.dataset_name
|
56 |
+
data_name = dataset_attr.subset
|
57 |
+
data_dir = dataset_attr.folder
|
58 |
+
|
59 |
+
elif dataset_attr.load_from == "script":
|
60 |
+
data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
|
61 |
+
data_name = dataset_attr.subset
|
62 |
+
data_dir = dataset_attr.folder
|
63 |
+
|
64 |
+
elif dataset_attr.load_from == "file":
|
65 |
+
data_files = []
|
66 |
+
local_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
|
67 |
+
if os.path.isdir(local_path): # is directory
|
68 |
+
for file_name in os.listdir(local_path):
|
69 |
+
data_files.append(os.path.join(local_path, file_name))
|
70 |
+
if data_path is None:
|
71 |
+
data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
|
72 |
+
elif data_path != FILEEXT2TYPE.get(file_name.split(".")[-1], None):
|
73 |
+
raise ValueError("File types should be identical.")
|
74 |
+
elif os.path.isfile(local_path): # is file
|
75 |
+
data_files.append(local_path)
|
76 |
+
data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
|
77 |
+
else:
|
78 |
+
raise ValueError("File {} not found.".format(local_path))
|
79 |
+
|
80 |
+
if data_path is None:
|
81 |
+
raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))
|
82 |
+
else:
|
83 |
+
raise NotImplementedError("Unknown load type: {}.".format(dataset_attr.load_from))
|
84 |
+
|
85 |
+
if dataset_attr.load_from == "ms_hub":
|
86 |
+
require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
|
87 |
+
from modelscope import MsDataset
|
88 |
+
from modelscope.utils.config_ds import MS_DATASETS_CACHE
|
89 |
+
|
90 |
+
cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
|
91 |
+
dataset = MsDataset.load(
|
92 |
+
dataset_name=data_path,
|
93 |
+
subset_name=data_name,
|
94 |
+
data_dir=data_dir,
|
95 |
+
data_files=data_files,
|
96 |
+
split=dataset_attr.split,
|
97 |
+
cache_dir=cache_dir,
|
98 |
+
token=model_args.ms_hub_token,
|
99 |
+
use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
|
100 |
+
)
|
101 |
+
if isinstance(dataset, MsDataset):
|
102 |
+
dataset = dataset.to_hf_dataset()
|
103 |
+
else:
|
104 |
+
dataset = load_dataset(
|
105 |
+
path=data_path,
|
106 |
+
name=data_name,
|
107 |
+
data_dir=data_dir,
|
108 |
+
data_files=data_files,
|
109 |
+
split=dataset_attr.split,
|
110 |
+
cache_dir=model_args.cache_dir,
|
111 |
+
token=model_args.hf_hub_token,
|
112 |
+
streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
|
113 |
+
trust_remote_code=True,
|
114 |
+
)
|
115 |
+
|
116 |
+
if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
|
117 |
+
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
|
118 |
+
|
119 |
+
if dataset_attr.num_samples is not None and not data_args.streaming:
|
120 |
+
target_num = dataset_attr.num_samples
|
121 |
+
indexes = np.random.permutation(len(dataset))[:target_num]
|
122 |
+
target_num -= len(indexes)
|
123 |
+
if target_num > 0:
|
124 |
+
expand_indexes = np.random.choice(len(dataset), target_num)
|
125 |
+
indexes = np.concatenate((indexes, expand_indexes), axis=0)
|
126 |
+
|
127 |
+
assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched."
|
128 |
+
dataset = dataset.select(indexes)
|
129 |
+
logger.info("Sampled {} examples from dataset {}.".format(dataset_attr.num_samples, dataset_attr))
|
130 |
+
|
131 |
+
if data_args.max_samples is not None: # truncate dataset
|
132 |
+
max_samples = min(data_args.max_samples, len(dataset))
|
133 |
+
dataset = dataset.select(range(max_samples))
|
134 |
+
|
135 |
+
return align_dataset(dataset, dataset_attr, data_args, training_args)
|
136 |
+
|
137 |
+
|
138 |
+
def _get_merged_dataset(
|
139 |
+
dataset_names: Optional[Sequence[str]],
|
140 |
+
model_args: "ModelArguments",
|
141 |
+
data_args: "DataArguments",
|
142 |
+
training_args: "Seq2SeqTrainingArguments",
|
143 |
+
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
144 |
+
) -> Optional[Union["Dataset", "IterableDataset"]]:
|
145 |
+
if dataset_names is None:
|
146 |
+
return None
|
147 |
+
|
148 |
+
datasets = []
|
149 |
+
for dataset_attr in get_dataset_list(dataset_names, data_args.dataset_dir):
|
150 |
+
if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
|
151 |
+
raise ValueError("The dataset is not applicable in the current training stage.")
|
152 |
+
|
153 |
+
datasets.append(_load_single_dataset(dataset_attr, model_args, data_args, training_args))
|
154 |
+
|
155 |
+
return merge_dataset(datasets, data_args, seed=training_args.seed)
|
156 |
+
|
157 |
+
|
158 |
+
def _get_preprocessed_dataset(
|
159 |
+
dataset: Optional[Union["Dataset", "IterableDataset"]],
|
160 |
+
data_args: "DataArguments",
|
161 |
+
training_args: "Seq2SeqTrainingArguments",
|
162 |
+
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
163 |
+
template: "Template",
|
164 |
+
tokenizer: "PreTrainedTokenizer",
|
165 |
+
processor: Optional["ProcessorMixin"] = None,
|
166 |
+
is_eval: bool = False,
|
167 |
+
) -> Optional[Union["Dataset", "IterableDataset"]]:
|
168 |
+
if dataset is None:
|
169 |
+
return None
|
170 |
+
|
171 |
+
preprocess_func, print_function = get_preprocess_and_print_func(
|
172 |
+
data_args, stage, template, tokenizer, processor, do_generate=(training_args.predict_with_generate and is_eval)
|
173 |
+
)
|
174 |
+
column_names = list(next(iter(dataset)).keys())
|
175 |
+
kwargs = {}
|
176 |
+
if not data_args.streaming:
|
177 |
+
kwargs = dict(
|
178 |
+
num_proc=data_args.preprocessing_num_workers,
|
179 |
+
load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
|
180 |
+
desc="Running tokenizer on dataset",
|
181 |
+
)
|
182 |
+
|
183 |
+
dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)
|
184 |
+
|
185 |
+
if training_args.should_log:
|
186 |
+
try:
|
187 |
+
print("eval example:" if is_eval else "training example:")
|
188 |
+
print_function(next(iter(dataset)))
|
189 |
+
except StopIteration:
|
190 |
+
if stage == "pt":
|
191 |
+
raise RuntimeError("Cannot find sufficient samples, consider increasing dataset size.")
|
192 |
+
else:
|
193 |
+
raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")
|
194 |
+
|
195 |
+
return dataset
|
196 |
+
|
197 |
+
|
198 |
+
def get_dataset(
|
199 |
+
model_args: "ModelArguments",
|
200 |
+
data_args: "DataArguments",
|
201 |
+
training_args: "Seq2SeqTrainingArguments",
|
202 |
+
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
203 |
+
tokenizer: "PreTrainedTokenizer",
|
204 |
+
processor: Optional["ProcessorMixin"] = None,
|
205 |
+
) -> "DatasetModule":
|
206 |
+
template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format)
|
207 |
+
if data_args.train_on_prompt and template.efficient_eos:
|
208 |
+
raise ValueError("Current template does not support `train_on_prompt`.")
|
209 |
+
|
210 |
+
# Load tokenized dataset
|
211 |
+
if data_args.tokenized_path is not None:
|
212 |
+
if has_tokenized_data(data_args.tokenized_path):
|
213 |
+
logger.warning("Loading dataset from disk will ignore other data arguments.")
|
214 |
+
dataset_dict: "DatasetDict" = load_from_disk(data_args.tokenized_path)
|
215 |
+
logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))
|
216 |
+
|
217 |
+
dataset_module: Dict[str, "Dataset"] = {}
|
218 |
+
if "train" in dataset_dict:
|
219 |
+
dataset_module["train_dataset"] = dataset_dict["train"]
|
220 |
+
if "validation" in dataset_dict:
|
221 |
+
dataset_module["eval_dataset"] = dataset_dict["validation"]
|
222 |
+
|
223 |
+
if data_args.streaming:
|
224 |
+
dataset_module = {k: v.to_iterable_dataset() for k, v in dataset_module.items()}
|
225 |
+
|
226 |
+
return dataset_module
|
227 |
+
|
228 |
+
if data_args.streaming:
|
229 |
+
raise ValueError("Turn off `streaming` when saving dataset to disk.")
|
230 |
+
|
231 |
+
# Load and preprocess dataset
|
232 |
+
with training_args.main_process_first(desc="load dataset"):
|
233 |
+
dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage)
|
234 |
+
eval_dataset = _get_merged_dataset(data_args.eval_dataset, model_args, data_args, training_args, stage)
|
235 |
+
|
236 |
+
with training_args.main_process_first(desc="pre-process dataset"):
|
237 |
+
dataset = _get_preprocessed_dataset(
|
238 |
+
dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=False
|
239 |
+
)
|
240 |
+
eval_dataset = _get_preprocessed_dataset(
|
241 |
+
eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
|
242 |
+
)
|
243 |
+
|
244 |
+
if data_args.val_size > 1e-6:
|
245 |
+
dataset_dict = split_dataset(dataset, data_args, seed=training_args.seed)
|
246 |
+
else:
|
247 |
+
dataset_dict = {}
|
248 |
+
if dataset is not None:
|
249 |
+
if data_args.streaming:
|
250 |
+
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
|
251 |
+
|
252 |
+
dataset_dict["train"] = dataset
|
253 |
+
|
254 |
+
if eval_dataset is not None:
|
255 |
+
if data_args.streaming:
|
256 |
+
eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
|
257 |
+
|
258 |
+
dataset_dict["validation"] = eval_dataset
|
259 |
+
|
260 |
+
dataset_dict = DatasetDict(dataset_dict)
|
261 |
+
|
262 |
+
if data_args.tokenized_path is not None:
|
263 |
+
if training_args.should_save:
|
264 |
+
dataset_dict.save_to_disk(data_args.tokenized_path)
|
265 |
+
logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
|
266 |
+
logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path))
|
267 |
+
|
268 |
+
sys.exit(0)
|
269 |
+
|
270 |
+
dataset_module = {}
|
271 |
+
if "train" in dataset_dict:
|
272 |
+
dataset_module["train_dataset"] = dataset_dict["train"]
|
273 |
+
if "validation" in dataset_dict:
|
274 |
+
dataset_module["eval_dataset"] = dataset_dict["validation"]
|
275 |
+
|
276 |
+
return dataset_module
|
llama-factory/src/llamafactory/data/parser.py
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import json
|
16 |
+
import os
|
17 |
+
from dataclasses import dataclass
|
18 |
+
from typing import Any, Dict, List, Literal, Optional, Sequence
|
19 |
+
|
20 |
+
from transformers.utils import cached_file
|
21 |
+
|
22 |
+
from ..extras.constants import DATA_CONFIG
|
23 |
+
from ..extras.misc import use_modelscope
|
24 |
+
|
25 |
+
|
26 |
+
@dataclass
|
27 |
+
class DatasetAttr:
|
28 |
+
r"""
|
29 |
+
Dataset attributes.
|
30 |
+
"""
|
31 |
+
|
32 |
+
# basic configs
|
33 |
+
load_from: Literal["hf_hub", "ms_hub", "script", "file"]
|
34 |
+
dataset_name: str
|
35 |
+
formatting: Literal["alpaca", "sharegpt"] = "alpaca"
|
36 |
+
ranking: bool = False
|
37 |
+
# extra configs
|
38 |
+
subset: Optional[str] = None
|
39 |
+
split: str = "train"
|
40 |
+
folder: Optional[str] = None
|
41 |
+
num_samples: Optional[int] = None
|
42 |
+
# common columns
|
43 |
+
system: Optional[str] = None
|
44 |
+
tools: Optional[str] = None
|
45 |
+
images: Optional[str] = None
|
46 |
+
# rlhf columns
|
47 |
+
chosen: Optional[str] = None
|
48 |
+
rejected: Optional[str] = None
|
49 |
+
kto_tag: Optional[str] = None
|
50 |
+
# alpaca columns
|
51 |
+
prompt: Optional[str] = "instruction"
|
52 |
+
query: Optional[str] = "input"
|
53 |
+
response: Optional[str] = "output"
|
54 |
+
history: Optional[str] = None
|
55 |
+
# sharegpt columns
|
56 |
+
messages: Optional[str] = "conversations"
|
57 |
+
# sharegpt tags
|
58 |
+
role_tag: Optional[str] = "from"
|
59 |
+
content_tag: Optional[str] = "value"
|
60 |
+
user_tag: Optional[str] = "human"
|
61 |
+
assistant_tag: Optional[str] = "gpt"
|
62 |
+
observation_tag: Optional[str] = "observation"
|
63 |
+
function_tag: Optional[str] = "function_call"
|
64 |
+
system_tag: Optional[str] = "system"
|
65 |
+
|
66 |
+
def __repr__(self) -> str:
|
67 |
+
return self.dataset_name
|
68 |
+
|
69 |
+
def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None:
|
70 |
+
setattr(self, key, obj.get(key, default))
|
71 |
+
|
72 |
+
|
73 |
+
def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> List["DatasetAttr"]:
|
74 |
+
r"""
|
75 |
+
Gets the attributes of the datasets.
|
76 |
+
"""
|
77 |
+
if dataset_names is None:
|
78 |
+
dataset_names = []
|
79 |
+
|
80 |
+
if dataset_dir == "ONLINE":
|
81 |
+
dataset_info = None
|
82 |
+
else:
|
83 |
+
if dataset_dir.startswith("REMOTE:"):
|
84 |
+
config_path = cached_file(path_or_repo_id=dataset_dir[7:], filename=DATA_CONFIG, repo_type="dataset")
|
85 |
+
else:
|
86 |
+
config_path = os.path.join(dataset_dir, DATA_CONFIG)
|
87 |
+
|
88 |
+
try:
|
89 |
+
with open(config_path, "r") as f:
|
90 |
+
dataset_info = json.load(f)
|
91 |
+
except Exception as err:
|
92 |
+
if len(dataset_names) != 0:
|
93 |
+
raise ValueError("Cannot open {} due to {}.".format(config_path, str(err)))
|
94 |
+
|
95 |
+
dataset_info = None
|
96 |
+
|
97 |
+
dataset_list: List["DatasetAttr"] = []
|
98 |
+
for name in dataset_names:
|
99 |
+
if dataset_info is None: # dataset_dir is ONLINE
|
100 |
+
load_from = "ms_hub" if use_modelscope() else "hf_hub"
|
101 |
+
dataset_attr = DatasetAttr(load_from, dataset_name=name)
|
102 |
+
dataset_list.append(dataset_attr)
|
103 |
+
continue
|
104 |
+
|
105 |
+
if name not in dataset_info:
|
106 |
+
raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))
|
107 |
+
|
108 |
+
has_hf_url = "hf_hub_url" in dataset_info[name]
|
109 |
+
has_ms_url = "ms_hub_url" in dataset_info[name]
|
110 |
+
|
111 |
+
if has_hf_url or has_ms_url:
|
112 |
+
if (use_modelscope() and has_ms_url) or (not has_hf_url):
|
113 |
+
dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"])
|
114 |
+
else:
|
115 |
+
dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
|
116 |
+
elif "script_url" in dataset_info[name]:
|
117 |
+
dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
|
118 |
+
else:
|
119 |
+
dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])
|
120 |
+
|
121 |
+
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
|
122 |
+
dataset_attr.set_attr("ranking", dataset_info[name], default=False)
|
123 |
+
dataset_attr.set_attr("subset", dataset_info[name])
|
124 |
+
dataset_attr.set_attr("split", dataset_info[name], default="train")
|
125 |
+
dataset_attr.set_attr("folder", dataset_info[name])
|
126 |
+
dataset_attr.set_attr("num_samples", dataset_info[name])
|
127 |
+
|
128 |
+
if "columns" in dataset_info[name]:
|
129 |
+
column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"]
|
130 |
+
if dataset_attr.formatting == "alpaca":
|
131 |
+
column_names.extend(["prompt", "query", "response", "history"])
|
132 |
+
else:
|
133 |
+
column_names.extend(["messages"])
|
134 |
+
|
135 |
+
for column_name in column_names:
|
136 |
+
dataset_attr.set_attr(column_name, dataset_info[name]["columns"])
|
137 |
+
|
138 |
+
if dataset_attr.formatting == "sharegpt" and "tags" in dataset_info[name]:
|
139 |
+
tag_names = (
|
140 |
+
"role_tag",
|
141 |
+
"content_tag",
|
142 |
+
"user_tag",
|
143 |
+
"assistant_tag",
|
144 |
+
"observation_tag",
|
145 |
+
"function_tag",
|
146 |
+
"system_tag",
|
147 |
+
)
|
148 |
+
for tag in tag_names:
|
149 |
+
dataset_attr.set_attr(tag, dataset_info[name]["tags"])
|
150 |
+
|
151 |
+
dataset_list.append(dataset_attr)
|
152 |
+
|
153 |
+
return dataset_list
|
llama-factory/src/llamafactory/data/preprocess.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from functools import partial
|
16 |
+
from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple
|
17 |
+
|
18 |
+
from .processors.feedback import preprocess_feedback_dataset
|
19 |
+
from .processors.pairwise import preprocess_pairwise_dataset, print_pairwise_dataset_example
|
20 |
+
from .processors.pretrain import preprocess_pretrain_dataset
|
21 |
+
from .processors.supervised import (
|
22 |
+
preprocess_packed_supervised_dataset,
|
23 |
+
preprocess_supervised_dataset,
|
24 |
+
print_supervised_dataset_example,
|
25 |
+
)
|
26 |
+
from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsupervised_dataset_example
|
27 |
+
|
28 |
+
|
29 |
+
if TYPE_CHECKING:
|
30 |
+
from transformers import PreTrainedTokenizer, ProcessorMixin
|
31 |
+
|
32 |
+
from ..hparams import DataArguments
|
33 |
+
from .template import Template
|
34 |
+
|
35 |
+
|
36 |
+
def get_preprocess_and_print_func(
|
37 |
+
data_args: "DataArguments",
|
38 |
+
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
39 |
+
template: "Template",
|
40 |
+
tokenizer: "PreTrainedTokenizer",
|
41 |
+
processor: Optional["ProcessorMixin"],
|
42 |
+
do_generate: bool = False,
|
43 |
+
) -> Tuple[Callable, Callable]:
|
44 |
+
if stage == "pt":
|
45 |
+
preprocess_func = partial(
|
46 |
+
preprocess_pretrain_dataset,
|
47 |
+
tokenizer=tokenizer,
|
48 |
+
data_args=data_args,
|
49 |
+
)
|
50 |
+
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
|
51 |
+
elif stage == "sft" and not do_generate:
|
52 |
+
if data_args.packing:
|
53 |
+
if data_args.neat_packing:
|
54 |
+
from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence
|
55 |
+
|
56 |
+
def __init__(self, data, **kwargs):
|
57 |
+
return TypedSequence.__init__(
|
58 |
+
self,
|
59 |
+
data,
|
60 |
+
type=kwargs.pop("type", None),
|
61 |
+
try_type=kwargs.pop("try_type", None),
|
62 |
+
optimized_int_type=kwargs.pop("optimized_int_type", None),
|
63 |
+
)
|
64 |
+
|
65 |
+
OptimizedTypedSequence.__init__ = __init__
|
66 |
+
preprocess_func = partial(
|
67 |
+
preprocess_packed_supervised_dataset,
|
68 |
+
template=template,
|
69 |
+
tokenizer=tokenizer,
|
70 |
+
data_args=data_args,
|
71 |
+
)
|
72 |
+
else:
|
73 |
+
preprocess_func = partial(
|
74 |
+
preprocess_supervised_dataset,
|
75 |
+
template=template,
|
76 |
+
tokenizer=tokenizer,
|
77 |
+
processor=processor,
|
78 |
+
data_args=data_args,
|
79 |
+
)
|
80 |
+
|
81 |
+
print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
|
82 |
+
elif stage == "rm":
|
83 |
+
preprocess_func = partial(
|
84 |
+
preprocess_pairwise_dataset,
|
85 |
+
template=template,
|
86 |
+
tokenizer=tokenizer,
|
87 |
+
processor=processor,
|
88 |
+
data_args=data_args,
|
89 |
+
)
|
90 |
+
print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
|
91 |
+
elif stage == "kto":
|
92 |
+
preprocess_func = partial(
|
93 |
+
preprocess_feedback_dataset,
|
94 |
+
template=template,
|
95 |
+
tokenizer=tokenizer,
|
96 |
+
processor=processor,
|
97 |
+
data_args=data_args,
|
98 |
+
)
|
99 |
+
print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
|
100 |
+
else:
|
101 |
+
preprocess_func = partial(
|
102 |
+
preprocess_unsupervised_dataset,
|
103 |
+
template=template,
|
104 |
+
tokenizer=tokenizer,
|
105 |
+
processor=processor,
|
106 |
+
data_args=data_args,
|
107 |
+
)
|
108 |
+
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
|
109 |
+
|
110 |
+
return preprocess_func, print_function
|
llama-factory/src/llamafactory/data/processors/__init__.py
ADDED
File without changes
|
llama-factory/src/llamafactory/data/processors/feedback.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
16 |
+
|
17 |
+
from ...extras.constants import IGNORE_INDEX
|
18 |
+
from ...extras.logging import get_logger
|
19 |
+
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
|
20 |
+
|
21 |
+
|
22 |
+
if TYPE_CHECKING:
|
23 |
+
from transformers import PreTrainedTokenizer, ProcessorMixin
|
24 |
+
|
25 |
+
from ...hparams import DataArguments
|
26 |
+
from ..template import Template
|
27 |
+
|
28 |
+
|
29 |
+
logger = get_logger(__name__)
|
30 |
+
|
31 |
+
|
32 |
+
def _encode_feedback_example(
|
33 |
+
prompt: Sequence[Dict[str, str]],
|
34 |
+
response: Sequence[Dict[str, str]],
|
35 |
+
kl_response: Sequence[Dict[str, str]],
|
36 |
+
system: Optional[str],
|
37 |
+
tools: Optional[str],
|
38 |
+
template: "Template",
|
39 |
+
tokenizer: "PreTrainedTokenizer",
|
40 |
+
processor: Optional["ProcessorMixin"],
|
41 |
+
data_args: "DataArguments",
|
42 |
+
) -> Tuple[List[int], List[int], List[int], List[int], bool]:
|
43 |
+
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
|
44 |
+
prompt[0]["content"] = template.image_token + prompt[0]["content"]
|
45 |
+
|
46 |
+
if response[0]["content"]: # desired example
|
47 |
+
kto_tag = True
|
48 |
+
messages = prompt + [response[0]]
|
49 |
+
else: # undesired example
|
50 |
+
kto_tag = False
|
51 |
+
messages = prompt + [response[1]]
|
52 |
+
|
53 |
+
if kl_response[0]["content"]:
|
54 |
+
kl_messages = prompt + [kl_response[0]]
|
55 |
+
else:
|
56 |
+
kl_messages = prompt + [kl_response[1]]
|
57 |
+
|
58 |
+
prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools)
|
59 |
+
kl_prompt_ids, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools)
|
60 |
+
|
61 |
+
if template.efficient_eos:
|
62 |
+
response_ids += [tokenizer.eos_token_id]
|
63 |
+
kl_response_ids += [tokenizer.eos_token_id]
|
64 |
+
|
65 |
+
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
|
66 |
+
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
|
67 |
+
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
|
68 |
+
kl_prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + kl_prompt_ids
|
69 |
+
|
70 |
+
source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), data_args.cutoff_len)
|
71 |
+
prompt_ids = prompt_ids[:source_len]
|
72 |
+
response_ids = response_ids[:target_len]
|
73 |
+
kl_source_len, kl_target_len = infer_seqlen(len(kl_prompt_ids), len(kl_response_ids), data_args.cutoff_len)
|
74 |
+
kl_prompt_ids = kl_prompt_ids[:kl_source_len]
|
75 |
+
kl_response_ids = kl_response_ids[:kl_target_len]
|
76 |
+
|
77 |
+
input_ids = prompt_ids + response_ids
|
78 |
+
labels = [IGNORE_INDEX] * source_len + response_ids
|
79 |
+
kl_input_ids = kl_prompt_ids + kl_response_ids
|
80 |
+
kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids
|
81 |
+
|
82 |
+
return input_ids, labels, kl_input_ids, kl_labels, kto_tag
|
83 |
+
|
84 |
+
|
85 |
+
def preprocess_feedback_dataset(
|
86 |
+
examples: Dict[str, List[Any]],
|
87 |
+
template: "Template",
|
88 |
+
tokenizer: "PreTrainedTokenizer",
|
89 |
+
processor: Optional["ProcessorMixin"],
|
90 |
+
data_args: "DataArguments",
|
91 |
+
) -> Dict[str, List[List[int]]]:
|
92 |
+
# create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
|
93 |
+
kl_response = examples["response"][::-1]
|
94 |
+
model_inputs = {
|
95 |
+
"input_ids": [],
|
96 |
+
"attention_mask": [],
|
97 |
+
"labels": [],
|
98 |
+
"kl_input_ids": [],
|
99 |
+
"kl_attention_mask": [],
|
100 |
+
"kl_labels": [],
|
101 |
+
"kto_tags": [],
|
102 |
+
}
|
103 |
+
if processor is not None:
|
104 |
+
model_inputs["pixel_values"] = []
|
105 |
+
if hasattr(processor, "image_seq_length"): # paligemma models
|
106 |
+
model_inputs["token_type_ids"] = []
|
107 |
+
model_inputs["kl_token_type_ids"] = []
|
108 |
+
|
109 |
+
for i in range(len(examples["prompt"])):
|
110 |
+
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
|
111 |
+
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
112 |
+
continue
|
113 |
+
|
114 |
+
input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
|
115 |
+
prompt=examples["prompt"][i],
|
116 |
+
response=examples["response"][i],
|
117 |
+
kl_response=kl_response[i],
|
118 |
+
system=examples["system"][i],
|
119 |
+
tools=examples["tools"][i],
|
120 |
+
template=template,
|
121 |
+
tokenizer=tokenizer,
|
122 |
+
processor=processor,
|
123 |
+
data_args=data_args,
|
124 |
+
)
|
125 |
+
model_inputs["input_ids"].append(input_ids)
|
126 |
+
model_inputs["attention_mask"].append([1] * len(input_ids))
|
127 |
+
model_inputs["labels"].append(labels)
|
128 |
+
model_inputs["kl_input_ids"].append(kl_input_ids)
|
129 |
+
model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
|
130 |
+
model_inputs["kl_labels"].append(kl_labels)
|
131 |
+
model_inputs["kto_tags"].append(kto_tag)
|
132 |
+
if processor is not None:
|
133 |
+
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
|
134 |
+
if hasattr(processor, "image_seq_length"): # paligemma models
|
135 |
+
model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
|
136 |
+
model_inputs["kl_token_type_ids"].append(get_paligemma_token_type_ids(len(kl_input_ids), processor))
|
137 |
+
|
138 |
+
desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
|
139 |
+
undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
|
140 |
+
if desirable_num == 0 or undesirable_num == 0:
|
141 |
+
logger.warning("Your dataset only has one preference type.")
|
142 |
+
|
143 |
+
return model_inputs
|
llama-factory/src/llamafactory/data/processors/pairwise.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
16 |
+
|
17 |
+
from ...extras.constants import IGNORE_INDEX
|
18 |
+
from ...extras.logging import get_logger
|
19 |
+
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
|
20 |
+
|
21 |
+
|
22 |
+
if TYPE_CHECKING:
|
23 |
+
from transformers import PreTrainedTokenizer, ProcessorMixin
|
24 |
+
|
25 |
+
from ...hparams import DataArguments
|
26 |
+
from ..template import Template
|
27 |
+
|
28 |
+
|
29 |
+
logger = get_logger(__name__)
|
30 |
+
|
31 |
+
|
32 |
+
def _encode_pairwise_example(
|
33 |
+
prompt: Sequence[Dict[str, str]],
|
34 |
+
response: Sequence[Dict[str, str]],
|
35 |
+
system: Optional[str],
|
36 |
+
tools: Optional[str],
|
37 |
+
template: "Template",
|
38 |
+
tokenizer: "PreTrainedTokenizer",
|
39 |
+
processor: Optional["ProcessorMixin"],
|
40 |
+
data_args: "DataArguments",
|
41 |
+
) -> Tuple[List[int], List[int], List[int], List[int]]:
|
42 |
+
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
|
43 |
+
prompt[0]["content"] = template.image_token + prompt[0]["content"]
|
44 |
+
|
45 |
+
chosen_messages = prompt + [response[0]]
|
46 |
+
rejected_messages = prompt + [response[1]]
|
47 |
+
prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools)
|
48 |
+
_, rejected_ids = template.encode_oneturn(tokenizer, rejected_messages, system, tools)
|
49 |
+
|
50 |
+
if template.efficient_eos:
|
51 |
+
chosen_ids += [tokenizer.eos_token_id]
|
52 |
+
rejected_ids += [tokenizer.eos_token_id]
|
53 |
+
|
54 |
+
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
|
55 |
+
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
|
56 |
+
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
|
57 |
+
|
58 |
+
source_len, target_len = infer_seqlen(
|
59 |
+
len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), data_args.cutoff_len
|
60 |
+
) # consider the response is more important
|
61 |
+
prompt_ids = prompt_ids[:source_len]
|
62 |
+
chosen_ids = chosen_ids[:target_len]
|
63 |
+
rejected_ids = rejected_ids[:target_len]
|
64 |
+
|
65 |
+
chosen_input_ids = prompt_ids + chosen_ids
|
66 |
+
chosen_labels = [IGNORE_INDEX] * source_len + chosen_ids
|
67 |
+
rejected_input_ids = prompt_ids + rejected_ids
|
68 |
+
rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids
|
69 |
+
|
70 |
+
return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels
|
71 |
+
|
72 |
+
|
73 |
+
def preprocess_pairwise_dataset(
|
74 |
+
examples: Dict[str, List[Any]],
|
75 |
+
template: "Template",
|
76 |
+
tokenizer: "PreTrainedTokenizer",
|
77 |
+
processor: Optional["ProcessorMixin"],
|
78 |
+
data_args: "DataArguments",
|
79 |
+
) -> Dict[str, List[List[int]]]:
|
80 |
+
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
|
81 |
+
model_inputs = {
|
82 |
+
"chosen_input_ids": [],
|
83 |
+
"chosen_attention_mask": [],
|
84 |
+
"chosen_labels": [],
|
85 |
+
"rejected_input_ids": [],
|
86 |
+
"rejected_attention_mask": [],
|
87 |
+
"rejected_labels": [],
|
88 |
+
}
|
89 |
+
if processor is not None:
|
90 |
+
model_inputs["pixel_values"] = []
|
91 |
+
if hasattr(processor, "image_seq_length"): # paligemma models
|
92 |
+
model_inputs["chosen_token_type_ids"] = []
|
93 |
+
model_inputs["rejected_token_type_ids"] = []
|
94 |
+
|
95 |
+
for i in range(len(examples["prompt"])):
|
96 |
+
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
|
97 |
+
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
98 |
+
continue
|
99 |
+
|
100 |
+
chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(
|
101 |
+
prompt=examples["prompt"][i],
|
102 |
+
response=examples["response"][i],
|
103 |
+
system=examples["system"][i],
|
104 |
+
tools=examples["tools"][i],
|
105 |
+
template=template,
|
106 |
+
tokenizer=tokenizer,
|
107 |
+
processor=processor,
|
108 |
+
data_args=data_args,
|
109 |
+
)
|
110 |
+
model_inputs["chosen_input_ids"].append(chosen_input_ids)
|
111 |
+
model_inputs["chosen_attention_mask"].append([1] * len(chosen_input_ids))
|
112 |
+
model_inputs["chosen_labels"].append(chosen_labels)
|
113 |
+
model_inputs["rejected_input_ids"].append(rejected_input_ids)
|
114 |
+
model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids))
|
115 |
+
model_inputs["rejected_labels"].append(rejected_labels)
|
116 |
+
if processor is not None:
|
117 |
+
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
|
118 |
+
if hasattr(processor, "image_seq_length"): # paligemma models
|
119 |
+
model_inputs["chosen_token_type_ids"].append(
|
120 |
+
get_paligemma_token_type_ids(len(chosen_input_ids), processor)
|
121 |
+
)
|
122 |
+
model_inputs["rejected_token_type_ids"].append(
|
123 |
+
get_paligemma_token_type_ids(len(rejected_input_ids), processor)
|
124 |
+
)
|
125 |
+
|
126 |
+
return model_inputs
|
127 |
+
|
128 |
+
|
129 |
+
def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
|
130 |
+
valid_chosen_labels = list(filter(lambda x: x != IGNORE_INDEX, example["chosen_labels"]))
|
131 |
+
valid_rejected_labels = list(filter(lambda x: x != IGNORE_INDEX, example["rejected_labels"]))
|
132 |
+
print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
|
133 |
+
print("chosen_inputs:\n{}".format(tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False)))
|
134 |
+
print("chosen_label_ids:\n{}".format(example["chosen_labels"]))
|
135 |
+
print("chosen_labels:\n{}".format(tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)))
|
136 |
+
print("rejected_input_ids:\n{}".format(example["rejected_input_ids"]))
|
137 |
+
print("rejected_inputs:\n{}".format(tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False)))
|
138 |
+
print("rejected_label_ids:\n{}".format(example["rejected_labels"]))
|
139 |
+
print("rejected_labels:\n{}".format(tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)))
|
llama-factory/src/llamafactory/data/processors/pretrain.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# This code is inspired by the HuggingFace's transformers library.
|
4 |
+
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
|
18 |
+
from itertools import chain
|
19 |
+
from typing import TYPE_CHECKING, Any, Dict, List
|
20 |
+
|
21 |
+
|
22 |
+
if TYPE_CHECKING:
|
23 |
+
from transformers import PreTrainedTokenizer
|
24 |
+
|
25 |
+
from ...hparams import DataArguments
|
26 |
+
|
27 |
+
|
28 |
+
def preprocess_pretrain_dataset(
|
29 |
+
examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
|
30 |
+
) -> Dict[str, List[List[int]]]:
|
31 |
+
# build grouped texts with format `X1 X2 X3 ...` if packing is enabled
|
32 |
+
eos_token = "<|end_of_text|>" if data_args.template == "llama3" else tokenizer.eos_token
|
33 |
+
text_examples = [messages[0]["content"] + eos_token for messages in examples["prompt"]]
|
34 |
+
|
35 |
+
if not data_args.packing:
|
36 |
+
if data_args.template == "gemma":
|
37 |
+
text_examples = [tokenizer.bos_token + example for example in text_examples]
|
38 |
+
|
39 |
+
result = tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len, truncation=True)
|
40 |
+
else:
|
41 |
+
tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
|
42 |
+
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
|
43 |
+
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
|
44 |
+
block_size = data_args.cutoff_len
|
45 |
+
total_length = (total_length // block_size) * block_size
|
46 |
+
result = {
|
47 |
+
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
|
48 |
+
for k, t in concatenated_examples.items()
|
49 |
+
}
|
50 |
+
if data_args.template == "gemma":
|
51 |
+
for i in range(len(result["input_ids"])):
|
52 |
+
result["input_ids"][i][0] = tokenizer.bos_token_id
|
53 |
+
|
54 |
+
return result
|
llama-factory/src/llamafactory/data/processors/processor_utils.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import bisect
|
16 |
+
from typing import TYPE_CHECKING, List, Sequence, Tuple
|
17 |
+
|
18 |
+
from ...extras.packages import is_pillow_available
|
19 |
+
|
20 |
+
|
21 |
+
if is_pillow_available():
|
22 |
+
from PIL import Image
|
23 |
+
|
24 |
+
|
25 |
+
if TYPE_CHECKING:
|
26 |
+
from numpy.typing import NDArray
|
27 |
+
from PIL.Image import Image as ImageObject
|
28 |
+
from transformers import ProcessorMixin
|
29 |
+
from transformers.image_processing_utils import BaseImageProcessor
|
30 |
+
|
31 |
+
|
32 |
+
def search_for_fit(numbers: Sequence[int], capacity: int) -> int:
|
33 |
+
r"""
|
34 |
+
Finds the index of largest number that fits into the knapsack with the given capacity.
|
35 |
+
"""
|
36 |
+
index = bisect.bisect(numbers, capacity)
|
37 |
+
return -1 if index == 0 else (index - 1)
|
38 |
+
|
39 |
+
|
40 |
+
def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
|
41 |
+
r"""
|
42 |
+
An efficient greedy algorithm with binary search for the knapsack problem.
|
43 |
+
"""
|
44 |
+
numbers.sort() # sort numbers in ascending order for binary search
|
45 |
+
knapsacks = []
|
46 |
+
|
47 |
+
while numbers:
|
48 |
+
current_knapsack = []
|
49 |
+
remaining_capacity = capacity
|
50 |
+
|
51 |
+
while True:
|
52 |
+
index = search_for_fit(numbers, remaining_capacity)
|
53 |
+
if index == -1:
|
54 |
+
break # no more numbers fit in this knapsack
|
55 |
+
|
56 |
+
remaining_capacity -= numbers[index] # update the remaining capacity
|
57 |
+
current_knapsack.append(numbers.pop(index)) # add the number to knapsack
|
58 |
+
|
59 |
+
knapsacks.append(current_knapsack)
|
60 |
+
|
61 |
+
return knapsacks
|
62 |
+
|
63 |
+
|
64 |
+
def get_pixel_values(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray":
|
65 |
+
r"""
|
66 |
+
Processes visual inputs. (currently only supports a single image)
|
67 |
+
"""
|
68 |
+
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
69 |
+
image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255))
|
70 |
+
return image_processor(image, return_tensors="pt")["pixel_values"][0] # shape (C, H, W)
|
71 |
+
|
72 |
+
|
73 |
+
def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[int]:
|
74 |
+
r"""
|
75 |
+
Gets paligemma token type ids for computing loss.
|
76 |
+
"""
|
77 |
+
image_seq_length = getattr(processor, "image_seq_length")
|
78 |
+
return [0] * image_seq_length + [1] * (input_len - image_seq_length)
|
79 |
+
|
80 |
+
|
81 |
+
def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
|
82 |
+
r"""
|
83 |
+
Computes the real sequence length after truncation by the cutoff_len.
|
84 |
+
"""
|
85 |
+
if target_len * 2 < cutoff_len: # truncate source
|
86 |
+
max_target_len = cutoff_len
|
87 |
+
elif source_len * 2 < cutoff_len: # truncate target
|
88 |
+
max_target_len = cutoff_len - source_len
|
89 |
+
else: # truncate both
|
90 |
+
max_target_len = int(cutoff_len * (target_len / (source_len + target_len)))
|
91 |
+
|
92 |
+
new_target_len = min(max_target_len, target_len)
|
93 |
+
max_source_len = max(cutoff_len - new_target_len, 0)
|
94 |
+
new_source_len = min(max_source_len, source_len)
|
95 |
+
return new_source_len, new_target_len
|
llama-factory/src/llamafactory/data/processors/supervised.py
ADDED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from collections import defaultdict
|
16 |
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
17 |
+
|
18 |
+
from ...extras.constants import IGNORE_INDEX
|
19 |
+
from ...extras.logging import get_logger
|
20 |
+
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, greedy_knapsack, infer_seqlen
|
21 |
+
|
22 |
+
|
23 |
+
if TYPE_CHECKING:
|
24 |
+
from transformers import PreTrainedTokenizer, ProcessorMixin
|
25 |
+
|
26 |
+
from ...hparams import DataArguments
|
27 |
+
from ..template import Template
|
28 |
+
|
29 |
+
|
30 |
+
logger = get_logger(__name__)
|
31 |
+
|
32 |
+
|
33 |
+
def _encode_supervised_example(
|
34 |
+
prompt: Sequence[Dict[str, str]],
|
35 |
+
response: Sequence[Dict[str, str]],
|
36 |
+
system: Optional[str],
|
37 |
+
tools: Optional[str],
|
38 |
+
template: "Template",
|
39 |
+
tokenizer: "PreTrainedTokenizer",
|
40 |
+
processor: Optional["ProcessorMixin"],
|
41 |
+
data_args: "DataArguments",
|
42 |
+
) -> Tuple[List[int], List[int]]:
|
43 |
+
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
|
44 |
+
prompt[0]["content"] = template.image_token + prompt[0]["content"]
|
45 |
+
|
46 |
+
messages = prompt + response
|
47 |
+
input_ids, labels = [], []
|
48 |
+
|
49 |
+
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
|
50 |
+
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
|
51 |
+
input_ids += [image_token_id] * getattr(processor, "image_seq_length")
|
52 |
+
labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length")
|
53 |
+
|
54 |
+
encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
|
55 |
+
total_length = 1 if template.efficient_eos else 0
|
56 |
+
for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
|
57 |
+
if total_length >= data_args.cutoff_len:
|
58 |
+
break
|
59 |
+
|
60 |
+
source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), data_args.cutoff_len - total_length)
|
61 |
+
source_ids = source_ids[:source_len]
|
62 |
+
target_ids = target_ids[:target_len]
|
63 |
+
total_length += source_len + target_len
|
64 |
+
|
65 |
+
if data_args.train_on_prompt:
|
66 |
+
source_label = source_ids
|
67 |
+
elif turn_idx != 0 and template.efficient_eos:
|
68 |
+
source_label = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1)
|
69 |
+
else:
|
70 |
+
source_label = [IGNORE_INDEX] * source_len
|
71 |
+
|
72 |
+
if data_args.mask_history and turn_idx != len(encoded_pairs) - 1:
|
73 |
+
target_label = [IGNORE_INDEX] * target_len
|
74 |
+
else:
|
75 |
+
target_label = target_ids
|
76 |
+
|
77 |
+
input_ids += source_ids + target_ids
|
78 |
+
labels += source_label + target_label
|
79 |
+
|
80 |
+
if template.efficient_eos:
|
81 |
+
input_ids += [tokenizer.eos_token_id]
|
82 |
+
labels += [tokenizer.eos_token_id]
|
83 |
+
|
84 |
+
return input_ids, labels
|
85 |
+
|
86 |
+
|
87 |
+
def preprocess_supervised_dataset(
|
88 |
+
examples: Dict[str, List[Any]],
|
89 |
+
template: "Template",
|
90 |
+
tokenizer: "PreTrainedTokenizer",
|
91 |
+
processor: Optional["ProcessorMixin"],
|
92 |
+
data_args: "DataArguments",
|
93 |
+
) -> Dict[str, List[List[int]]]:
|
94 |
+
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
|
95 |
+
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
|
96 |
+
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
97 |
+
if processor is not None:
|
98 |
+
model_inputs["pixel_values"] = []
|
99 |
+
if hasattr(processor, "image_seq_length"): # paligemma models
|
100 |
+
model_inputs["token_type_ids"] = []
|
101 |
+
|
102 |
+
for i in range(len(examples["prompt"])):
|
103 |
+
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
|
104 |
+
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
105 |
+
continue
|
106 |
+
|
107 |
+
input_ids, labels = _encode_supervised_example(
|
108 |
+
prompt=examples["prompt"][i],
|
109 |
+
response=examples["response"][i],
|
110 |
+
system=examples["system"][i],
|
111 |
+
tools=examples["tools"][i],
|
112 |
+
template=template,
|
113 |
+
tokenizer=tokenizer,
|
114 |
+
processor=processor,
|
115 |
+
data_args=data_args,
|
116 |
+
)
|
117 |
+
model_inputs["input_ids"].append(input_ids)
|
118 |
+
model_inputs["attention_mask"].append([1] * len(input_ids))
|
119 |
+
model_inputs["labels"].append(labels)
|
120 |
+
if processor is not None:
|
121 |
+
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
|
122 |
+
if hasattr(processor, "image_seq_length"): # paligemma models
|
123 |
+
model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
|
124 |
+
|
125 |
+
return model_inputs
|
126 |
+
|
127 |
+
|
128 |
+
def preprocess_packed_supervised_dataset(
|
129 |
+
examples: Dict[str, List[Any]],
|
130 |
+
template: "Template",
|
131 |
+
tokenizer: "PreTrainedTokenizer",
|
132 |
+
data_args: "DataArguments",
|
133 |
+
) -> Dict[str, List[List[int]]]:
|
134 |
+
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
|
135 |
+
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
|
136 |
+
valid_num = 0
|
137 |
+
batch_input_ids, batch_labels = [], []
|
138 |
+
lengths = []
|
139 |
+
length2indexes = defaultdict(list)
|
140 |
+
for i in range(len(examples["prompt"])):
|
141 |
+
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
|
142 |
+
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
143 |
+
continue
|
144 |
+
|
145 |
+
input_ids, labels = _encode_supervised_example(
|
146 |
+
prompt=examples["prompt"][i],
|
147 |
+
response=examples["response"][i],
|
148 |
+
system=examples["system"][i],
|
149 |
+
tools=examples["tools"][i],
|
150 |
+
template=template,
|
151 |
+
tokenizer=tokenizer,
|
152 |
+
processor=None,
|
153 |
+
data_args=data_args,
|
154 |
+
)
|
155 |
+
length = len(input_ids)
|
156 |
+
if length > data_args.cutoff_len:
|
157 |
+
logger.warning("Dropped lengthy example with length {} > {}.".format(length, data_args.cutoff_len))
|
158 |
+
else:
|
159 |
+
lengths.append(length)
|
160 |
+
length2indexes[length].append(valid_num)
|
161 |
+
batch_input_ids.append(input_ids)
|
162 |
+
batch_labels.append(labels)
|
163 |
+
valid_num += 1
|
164 |
+
|
165 |
+
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
166 |
+
knapsacks = greedy_knapsack(lengths, data_args.cutoff_len)
|
167 |
+
for knapsack in knapsacks:
|
168 |
+
packed_input_ids, packed_attention_masks, packed_labels = [], [], []
|
169 |
+
for i, length in enumerate(knapsack):
|
170 |
+
index = length2indexes[length].pop()
|
171 |
+
packed_input_ids += batch_input_ids[index]
|
172 |
+
packed_labels += batch_labels[index]
|
173 |
+
if data_args.neat_packing:
|
174 |
+
packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1
|
175 |
+
else:
|
176 |
+
packed_attention_masks += [1] * len(batch_input_ids[index])
|
177 |
+
|
178 |
+
if len(packed_input_ids) < data_args.cutoff_len:
|
179 |
+
pad_length = data_args.cutoff_len - len(packed_input_ids)
|
180 |
+
packed_input_ids += [tokenizer.pad_token_id] * pad_length
|
181 |
+
packed_labels += [IGNORE_INDEX] * pad_length
|
182 |
+
if data_args.neat_packing:
|
183 |
+
packed_attention_masks += [0] * pad_length
|
184 |
+
else:
|
185 |
+
packed_attention_masks += [1] * pad_length # more efficient flash_attn
|
186 |
+
|
187 |
+
if len(packed_input_ids) != data_args.cutoff_len:
|
188 |
+
raise ValueError("The length of packed example should be identical to the cutoff length.")
|
189 |
+
|
190 |
+
model_inputs["input_ids"].append(packed_input_ids)
|
191 |
+
model_inputs["attention_mask"].append(packed_attention_masks)
|
192 |
+
model_inputs["labels"].append(packed_labels)
|
193 |
+
|
194 |
+
return model_inputs
|
195 |
+
|
196 |
+
|
197 |
+
def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
|
198 |
+
valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
|
199 |
+
print("input_ids:\n{}".format(example["input_ids"]))
|
200 |
+
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
201 |
+
print("label_ids:\n{}".format(example["labels"]))
|
202 |
+
print("labels:\n{}".format(tokenizer.decode(valid_labels, skip_special_tokens=False)))
|
llama-factory/src/llamafactory/data/processors/unsupervised.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
16 |
+
|
17 |
+
from ...extras.logging import get_logger
|
18 |
+
from ..data_utils import Role
|
19 |
+
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
|
20 |
+
|
21 |
+
|
22 |
+
if TYPE_CHECKING:
|
23 |
+
from transformers import PreTrainedTokenizer, ProcessorMixin
|
24 |
+
|
25 |
+
from ...hparams import DataArguments
|
26 |
+
from ..template import Template
|
27 |
+
|
28 |
+
|
29 |
+
logger = get_logger(__name__)
|
30 |
+
|
31 |
+
|
32 |
+
def _encode_unsupervised_example(
|
33 |
+
prompt: Sequence[Dict[str, str]],
|
34 |
+
response: Sequence[Dict[str, str]],
|
35 |
+
system: Optional[str],
|
36 |
+
tools: Optional[str],
|
37 |
+
template: "Template",
|
38 |
+
tokenizer: "PreTrainedTokenizer",
|
39 |
+
processor: Optional["ProcessorMixin"],
|
40 |
+
data_args: "DataArguments",
|
41 |
+
) -> Tuple[List[int], List[int]]:
|
42 |
+
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
|
43 |
+
prompt[0]["content"] = template.image_token + prompt[0]["content"]
|
44 |
+
|
45 |
+
if len(response) == 1:
|
46 |
+
messages = prompt + response
|
47 |
+
else:
|
48 |
+
messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}]
|
49 |
+
|
50 |
+
input_ids, labels = template.encode_oneturn(tokenizer, messages, system, tools)
|
51 |
+
if template.efficient_eos:
|
52 |
+
labels += [tokenizer.eos_token_id]
|
53 |
+
|
54 |
+
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
|
55 |
+
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
|
56 |
+
input_ids = [image_token_id] * getattr(processor, "image_seq_length") + input_ids
|
57 |
+
|
58 |
+
source_len, target_len = infer_seqlen(len(input_ids), len(labels), data_args.cutoff_len)
|
59 |
+
input_ids = input_ids[:source_len]
|
60 |
+
labels = labels[:target_len]
|
61 |
+
return input_ids, labels
|
62 |
+
|
63 |
+
|
64 |
+
def preprocess_unsupervised_dataset(
|
65 |
+
examples: Dict[str, List[Any]],
|
66 |
+
template: "Template",
|
67 |
+
tokenizer: "PreTrainedTokenizer",
|
68 |
+
processor: Optional["ProcessorMixin"],
|
69 |
+
data_args: "DataArguments",
|
70 |
+
) -> Dict[str, List[List[int]]]:
|
71 |
+
# build inputs with format `<bos> X` and labels with format `Y <eos>`
|
72 |
+
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
73 |
+
if processor is not None:
|
74 |
+
model_inputs["pixel_values"] = []
|
75 |
+
if hasattr(processor, "image_seq_length"): # paligemma models
|
76 |
+
model_inputs["token_type_ids"] = []
|
77 |
+
|
78 |
+
for i in range(len(examples["prompt"])):
|
79 |
+
if len(examples["prompt"][i]) % 2 != 1:
|
80 |
+
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
81 |
+
continue
|
82 |
+
|
83 |
+
input_ids, labels = _encode_unsupervised_example(
|
84 |
+
prompt=examples["prompt"][i],
|
85 |
+
response=examples["response"][i],
|
86 |
+
system=examples["system"][i],
|
87 |
+
tools=examples["tools"][i],
|
88 |
+
template=template,
|
89 |
+
tokenizer=tokenizer,
|
90 |
+
processor=processor,
|
91 |
+
data_args=data_args,
|
92 |
+
)
|
93 |
+
model_inputs["input_ids"].append(input_ids)
|
94 |
+
model_inputs["attention_mask"].append([1] * len(input_ids))
|
95 |
+
model_inputs["labels"].append(labels)
|
96 |
+
if processor is not None:
|
97 |
+
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
|
98 |
+
if hasattr(processor, "image_seq_length"): # paligemma models
|
99 |
+
model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
|
100 |
+
|
101 |
+
return model_inputs
|
102 |
+
|
103 |
+
|
104 |
+
def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
|
105 |
+
print("input_ids:\n{}".format(example["input_ids"]))
|
106 |
+
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
llama-factory/src/llamafactory/data/template.py
ADDED
@@ -0,0 +1,905 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from dataclasses import dataclass
|
16 |
+
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
|
17 |
+
|
18 |
+
from ..extras.logging import get_logger
|
19 |
+
from .data_utils import Role
|
20 |
+
from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
|
21 |
+
|
22 |
+
|
23 |
+
if TYPE_CHECKING:
|
24 |
+
from transformers import PreTrainedTokenizer
|
25 |
+
|
26 |
+
from .formatter import SLOTS, Formatter
|
27 |
+
|
28 |
+
|
29 |
+
logger = get_logger(__name__)
|
30 |
+
|
31 |
+
|
32 |
+
@dataclass
|
33 |
+
class Template:
|
34 |
+
format_user: "Formatter"
|
35 |
+
format_assistant: "Formatter"
|
36 |
+
format_system: "Formatter"
|
37 |
+
format_function: "Formatter"
|
38 |
+
format_observation: "Formatter"
|
39 |
+
format_tools: "Formatter"
|
40 |
+
format_separator: "Formatter"
|
41 |
+
format_prefix: "Formatter"
|
42 |
+
default_system: str
|
43 |
+
stop_words: List[str]
|
44 |
+
image_token: str
|
45 |
+
efficient_eos: bool
|
46 |
+
replace_eos: bool
|
47 |
+
|
48 |
+
def encode_oneturn(
|
49 |
+
self,
|
50 |
+
tokenizer: "PreTrainedTokenizer",
|
51 |
+
messages: Sequence[Dict[str, str]],
|
52 |
+
system: Optional[str] = None,
|
53 |
+
tools: Optional[str] = None,
|
54 |
+
) -> Tuple[List[int], List[int]]:
|
55 |
+
r"""
|
56 |
+
Returns a single pair of token ids representing prompt and response respectively.
|
57 |
+
"""
|
58 |
+
encoded_messages = self._encode(tokenizer, messages, system, tools)
|
59 |
+
prompt_ids = []
|
60 |
+
for encoded_ids in encoded_messages[:-1]:
|
61 |
+
prompt_ids += encoded_ids
|
62 |
+
|
63 |
+
answer_ids = encoded_messages[-1]
|
64 |
+
return prompt_ids, answer_ids
|
65 |
+
|
66 |
+
def encode_multiturn(
|
67 |
+
self,
|
68 |
+
tokenizer: "PreTrainedTokenizer",
|
69 |
+
messages: Sequence[Dict[str, str]],
|
70 |
+
system: Optional[str] = None,
|
71 |
+
tools: Optional[str] = None,
|
72 |
+
) -> List[Tuple[List[int], List[int]]]:
|
73 |
+
r"""
|
74 |
+
Returns multiple pairs of token ids representing prompts and responses respectively.
|
75 |
+
"""
|
76 |
+
encoded_messages = self._encode(tokenizer, messages, system, tools)
|
77 |
+
return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
|
78 |
+
|
79 |
+
def extract_tool(self, content: str) -> Union[str, List[Tuple[str, str]]]:
|
80 |
+
r"""
|
81 |
+
Extracts tool message.
|
82 |
+
"""
|
83 |
+
return self.format_tools.extract(content)
|
84 |
+
|
85 |
+
def _encode(
|
86 |
+
self,
|
87 |
+
tokenizer: "PreTrainedTokenizer",
|
88 |
+
messages: Sequence[Dict[str, str]],
|
89 |
+
system: Optional[str],
|
90 |
+
tools: Optional[str],
|
91 |
+
) -> List[List[int]]:
|
92 |
+
r"""
|
93 |
+
Encodes formatted inputs to pairs of token ids.
|
94 |
+
Turn 0: prefix + system + query resp
|
95 |
+
Turn t: sep + query resp
|
96 |
+
"""
|
97 |
+
system = system or self.default_system
|
98 |
+
encoded_messages = []
|
99 |
+
for i, message in enumerate(messages):
|
100 |
+
elements = []
|
101 |
+
|
102 |
+
if i == 0:
|
103 |
+
elements += self.format_prefix.apply()
|
104 |
+
if system or tools:
|
105 |
+
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
|
106 |
+
elements += self.format_system.apply(content=(system + tool_text))
|
107 |
+
|
108 |
+
if i > 0 and i % 2 == 0:
|
109 |
+
elements += self.format_separator.apply()
|
110 |
+
|
111 |
+
if message["role"] == Role.USER.value:
|
112 |
+
elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
|
113 |
+
elif message["role"] == Role.ASSISTANT.value:
|
114 |
+
elements += self.format_assistant.apply(content=message["content"])
|
115 |
+
elif message["role"] == Role.OBSERVATION.value:
|
116 |
+
elements += self.format_observation.apply(content=message["content"])
|
117 |
+
elif message["role"] == Role.FUNCTION.value:
|
118 |
+
elements += self.format_function.apply(content=message["content"])
|
119 |
+
else:
|
120 |
+
raise NotImplementedError("Unexpected role: {}".format(message["role"]))
|
121 |
+
|
122 |
+
encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
|
123 |
+
|
124 |
+
return encoded_messages
|
125 |
+
|
126 |
+
def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> List[int]:
|
127 |
+
r"""
|
128 |
+
Converts elements to token ids.
|
129 |
+
"""
|
130 |
+
token_ids = []
|
131 |
+
for elem in elements:
|
132 |
+
if isinstance(elem, str):
|
133 |
+
if len(elem) != 0:
|
134 |
+
token_ids += tokenizer.encode(elem, add_special_tokens=False)
|
135 |
+
elif isinstance(elem, dict):
|
136 |
+
token_ids += [tokenizer.convert_tokens_to_ids(elem.get("token"))]
|
137 |
+
elif isinstance(elem, set):
|
138 |
+
if "bos_token" in elem and tokenizer.bos_token_id is not None:
|
139 |
+
token_ids += [tokenizer.bos_token_id]
|
140 |
+
elif "eos_token" in elem and tokenizer.eos_token_id is not None:
|
141 |
+
token_ids += [tokenizer.eos_token_id]
|
142 |
+
else:
|
143 |
+
raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem)))
|
144 |
+
|
145 |
+
return token_ids
|
146 |
+
|
147 |
+
|
148 |
+
@dataclass
|
149 |
+
class Llama2Template(Template):
|
150 |
+
def _encode(
|
151 |
+
self,
|
152 |
+
tokenizer: "PreTrainedTokenizer",
|
153 |
+
messages: Sequence[Dict[str, str]],
|
154 |
+
system: str,
|
155 |
+
tools: str,
|
156 |
+
) -> List[List[int]]:
|
157 |
+
r"""
|
158 |
+
Encodes formatted inputs to pairs of token ids.
|
159 |
+
Turn 0: prefix + system + query resp
|
160 |
+
Turn t: sep + query resp
|
161 |
+
"""
|
162 |
+
system = system or self.default_system
|
163 |
+
encoded_messages = []
|
164 |
+
for i, message in enumerate(messages):
|
165 |
+
elements = []
|
166 |
+
|
167 |
+
system_text = ""
|
168 |
+
if i == 0:
|
169 |
+
elements += self.format_prefix.apply()
|
170 |
+
if system or tools:
|
171 |
+
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
|
172 |
+
system_text = self.format_system.apply(content=(system + tool_text))[0]
|
173 |
+
|
174 |
+
if i > 0 and i % 2 == 0:
|
175 |
+
elements += self.format_separator.apply()
|
176 |
+
|
177 |
+
if message["role"] == Role.USER.value:
|
178 |
+
elements += self.format_user.apply(content=system_text + message["content"])
|
179 |
+
elif message["role"] == Role.ASSISTANT.value:
|
180 |
+
elements += self.format_assistant.apply(content=message["content"])
|
181 |
+
elif message["role"] == Role.OBSERVATION.value:
|
182 |
+
elements += self.format_observation.apply(content=message["content"])
|
183 |
+
elif message["role"] == Role.FUNCTION.value:
|
184 |
+
elements += self.format_function.apply(content=message["content"])
|
185 |
+
else:
|
186 |
+
raise NotImplementedError("Unexpected role: {}".format(message["role"]))
|
187 |
+
|
188 |
+
encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
|
189 |
+
|
190 |
+
return encoded_messages
|
191 |
+
|
192 |
+
|
193 |
+
TEMPLATES: Dict[str, Template] = {}
|
194 |
+
|
195 |
+
|
196 |
+
def _register_template(
|
197 |
+
name: str,
|
198 |
+
format_user: Optional["Formatter"] = None,
|
199 |
+
format_assistant: Optional["Formatter"] = None,
|
200 |
+
format_system: Optional["Formatter"] = None,
|
201 |
+
format_function: Optional["Formatter"] = None,
|
202 |
+
format_observation: Optional["Formatter"] = None,
|
203 |
+
format_tools: Optional["Formatter"] = None,
|
204 |
+
format_separator: Optional["Formatter"] = None,
|
205 |
+
format_prefix: Optional["Formatter"] = None,
|
206 |
+
default_system: str = "",
|
207 |
+
stop_words: Sequence[str] = [],
|
208 |
+
image_token: str = "<image>",
|
209 |
+
efficient_eos: bool = False,
|
210 |
+
replace_eos: bool = False,
|
211 |
+
) -> None:
|
212 |
+
r"""
|
213 |
+
Registers a chat template.
|
214 |
+
|
215 |
+
To add the following chat template:
|
216 |
+
```
|
217 |
+
[HUMAN]:
|
218 |
+
user prompt here
|
219 |
+
[AI]:
|
220 |
+
model response here
|
221 |
+
|
222 |
+
[HUMAN]:
|
223 |
+
user prompt here
|
224 |
+
[AI]:
|
225 |
+
model response here
|
226 |
+
```
|
227 |
+
|
228 |
+
The corresponding code should be:
|
229 |
+
```
|
230 |
+
_register_template(
|
231 |
+
name="custom",
|
232 |
+
format_user=StringFormatter(slots=["[HUMAN]:\n{{content}}\n[AI]:\n"]),
|
233 |
+
format_separator=EmptyFormatter(slots=["\n\n"]),
|
234 |
+
efficient_eos=True,
|
235 |
+
)
|
236 |
+
```
|
237 |
+
"""
|
238 |
+
eos_slots = [] if efficient_eos else [{"eos_token"}]
|
239 |
+
template_class = Llama2Template if name.startswith("llama2") else Template
|
240 |
+
default_user_formatter = StringFormatter(slots=["{{content}}"])
|
241 |
+
default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
|
242 |
+
default_function_formatter = FunctionFormatter(slots=eos_slots, tool_format="default")
|
243 |
+
default_tool_formatter = ToolFormatter(tool_format="default")
|
244 |
+
default_separator_formatter = EmptyFormatter()
|
245 |
+
default_prefix_formatter = EmptyFormatter()
|
246 |
+
TEMPLATES[name] = template_class(
|
247 |
+
format_user=format_user or default_user_formatter,
|
248 |
+
format_assistant=format_assistant or default_assistant_formatter,
|
249 |
+
format_system=format_system or default_user_formatter,
|
250 |
+
format_function=format_function or default_function_formatter,
|
251 |
+
format_observation=format_observation or format_user or default_user_formatter,
|
252 |
+
format_tools=format_tools or default_tool_formatter,
|
253 |
+
format_separator=format_separator or default_separator_formatter,
|
254 |
+
format_prefix=format_prefix or default_prefix_formatter,
|
255 |
+
default_system=default_system,
|
256 |
+
stop_words=stop_words,
|
257 |
+
image_token=image_token,
|
258 |
+
efficient_eos=efficient_eos,
|
259 |
+
replace_eos=replace_eos,
|
260 |
+
)
|
261 |
+
|
262 |
+
|
263 |
+
def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
|
264 |
+
is_added = tokenizer.eos_token_id is None
|
265 |
+
num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
|
266 |
+
|
267 |
+
if is_added:
|
268 |
+
logger.info("Add eos token: {}".format(tokenizer.eos_token))
|
269 |
+
else:
|
270 |
+
logger.info("Replace eos token: {}".format(tokenizer.eos_token))
|
271 |
+
|
272 |
+
if num_added_tokens > 0:
|
273 |
+
logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
|
274 |
+
|
275 |
+
|
276 |
+
def _jinja_escape(content: str) -> str:
|
277 |
+
return content.replace("'", r"\'")
|
278 |
+
|
279 |
+
|
280 |
+
def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
|
281 |
+
slot_items = []
|
282 |
+
for slot in slots:
|
283 |
+
if isinstance(slot, str):
|
284 |
+
slot_pieces = slot.split("{{content}}")
|
285 |
+
if slot_pieces[0]:
|
286 |
+
slot_items.append("'" + _jinja_escape(slot_pieces[0]) + "'")
|
287 |
+
if len(slot_pieces) > 1:
|
288 |
+
slot_items.append(placeholder)
|
289 |
+
if slot_pieces[1]:
|
290 |
+
slot_items.append("'" + _jinja_escape(slot_pieces[1]) + "'")
|
291 |
+
elif isinstance(slot, set): # do not use {{ eos_token }} since it may be replaced
|
292 |
+
if "bos_token" in slot and tokenizer.bos_token_id is not None:
|
293 |
+
slot_items.append("'" + tokenizer.bos_token + "'")
|
294 |
+
elif "eos_token" in slot and tokenizer.eos_token_id is not None:
|
295 |
+
slot_items.append("'" + tokenizer.eos_token + "'")
|
296 |
+
elif isinstance(slot, dict):
|
297 |
+
raise ValueError("Dict is not supported.")
|
298 |
+
|
299 |
+
return " + ".join(slot_items)
|
300 |
+
|
301 |
+
|
302 |
+
def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") -> str:
|
303 |
+
jinja_template = ""
|
304 |
+
|
305 |
+
prefix = _convert_slots_to_jinja(template.format_prefix.apply(), tokenizer)
|
306 |
+
if prefix:
|
307 |
+
jinja_template += "{{ " + prefix + " }}"
|
308 |
+
|
309 |
+
if template.default_system:
|
310 |
+
jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}"
|
311 |
+
|
312 |
+
jinja_template += (
|
313 |
+
"{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}"
|
314 |
+
)
|
315 |
+
|
316 |
+
system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message")
|
317 |
+
if not isinstance(template, Llama2Template):
|
318 |
+
jinja_template += "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}"
|
319 |
+
|
320 |
+
jinja_template += "{% for message in messages %}"
|
321 |
+
jinja_template += "{% set content = message['content'] %}"
|
322 |
+
if isinstance(template, Llama2Template):
|
323 |
+
jinja_template += "{% if loop.index0 == 0 and system_message is defined %}"
|
324 |
+
jinja_template += "{% set content = " + system_message + " + message['content'] %}"
|
325 |
+
jinja_template += "{% endif %}"
|
326 |
+
|
327 |
+
jinja_template += "{% if message['role'] == 'user' %}"
|
328 |
+
user_message = _convert_slots_to_jinja(template.format_user.apply(), tokenizer)
|
329 |
+
jinja_template += "{{ " + user_message + " }}"
|
330 |
+
|
331 |
+
jinja_template += "{% elif message['role'] == 'assistant' %}"
|
332 |
+
assistant_message = _convert_slots_to_jinja(
|
333 |
+
template.format_assistant.apply() + template.format_separator.apply(), tokenizer
|
334 |
+
)
|
335 |
+
jinja_template += "{{ " + assistant_message + " }}"
|
336 |
+
jinja_template += "{% endif %}"
|
337 |
+
jinja_template += "{% endfor %}"
|
338 |
+
return jinja_template
|
339 |
+
|
340 |
+
|
341 |
+
def get_template_and_fix_tokenizer(
|
342 |
+
tokenizer: "PreTrainedTokenizer",
|
343 |
+
name: Optional[str] = None,
|
344 |
+
tool_format: Optional[str] = None,
|
345 |
+
) -> Template:
|
346 |
+
if name is None:
|
347 |
+
template = TEMPLATES["empty"] # placeholder
|
348 |
+
else:
|
349 |
+
template = TEMPLATES.get(name, None)
|
350 |
+
if template is None:
|
351 |
+
raise ValueError("Template {} does not exist.".format(name))
|
352 |
+
|
353 |
+
if tool_format is not None:
|
354 |
+
logger.info("Using tool format: {}.".format(tool_format))
|
355 |
+
eos_slots = [] if template.efficient_eos else [{"eos_token"}]
|
356 |
+
template.format_tools = ToolFormatter(tool_format=tool_format)
|
357 |
+
template.format_function = FunctionFormatter(slots=eos_slots, tool_format=tool_format)
|
358 |
+
|
359 |
+
stop_words = template.stop_words
|
360 |
+
if template.replace_eos:
|
361 |
+
if not stop_words:
|
362 |
+
raise ValueError("Stop words are required to replace the EOS token.")
|
363 |
+
|
364 |
+
_add_or_replace_eos_token(tokenizer, eos_token=stop_words[0])
|
365 |
+
stop_words = stop_words[1:]
|
366 |
+
|
367 |
+
if tokenizer.eos_token_id is None:
|
368 |
+
_add_or_replace_eos_token(tokenizer, eos_token="<|endoftext|>")
|
369 |
+
|
370 |
+
if tokenizer.pad_token_id is None:
|
371 |
+
tokenizer.pad_token = tokenizer.eos_token
|
372 |
+
logger.info("Add pad token: {}".format(tokenizer.pad_token))
|
373 |
+
|
374 |
+
if stop_words:
|
375 |
+
num_added_tokens = tokenizer.add_special_tokens(
|
376 |
+
dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
|
377 |
+
)
|
378 |
+
logger.info("Add {} to stop words.".format(",".join(stop_words)))
|
379 |
+
if num_added_tokens > 0:
|
380 |
+
logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
|
381 |
+
|
382 |
+
try:
|
383 |
+
tokenizer.chat_template = _get_jinja_template(template, tokenizer)
|
384 |
+
except ValueError:
|
385 |
+
logger.info("Cannot add this chat template to tokenizer.")
|
386 |
+
|
387 |
+
return template
|
388 |
+
|
389 |
+
|
390 |
+
_register_template(
|
391 |
+
name="alpaca",
|
392 |
+
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]),
|
393 |
+
format_separator=EmptyFormatter(slots=["\n\n"]),
|
394 |
+
default_system=(
|
395 |
+
"Below is an instruction that describes a task. "
|
396 |
+
"Write a response that appropriately completes the request.\n\n"
|
397 |
+
),
|
398 |
+
)
|
399 |
+
|
400 |
+
|
401 |
+
_register_template(
|
402 |
+
name="aquila",
|
403 |
+
format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]),
|
404 |
+
format_separator=EmptyFormatter(slots=["###"]),
|
405 |
+
default_system=(
|
406 |
+
"A chat between a curious human and an artificial intelligence assistant. "
|
407 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions."
|
408 |
+
),
|
409 |
+
stop_words=["</s>"],
|
410 |
+
efficient_eos=True,
|
411 |
+
)
|
412 |
+
|
413 |
+
|
414 |
+
_register_template(
|
415 |
+
name="atom",
|
416 |
+
format_user=StringFormatter(
|
417 |
+
slots=[{"bos_token"}, "Human: {{content}}\n", {"eos_token"}, {"bos_token"}, "Assistant:"]
|
418 |
+
),
|
419 |
+
format_assistant=StringFormatter(slots=["{{content}}\n", {"eos_token"}]),
|
420 |
+
)
|
421 |
+
|
422 |
+
|
423 |
+
_register_template(
|
424 |
+
name="baichuan",
|
425 |
+
format_user=StringFormatter(slots=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]),
|
426 |
+
efficient_eos=True,
|
427 |
+
)
|
428 |
+
|
429 |
+
|
430 |
+
_register_template(
|
431 |
+
name="baichuan2",
|
432 |
+
format_user=StringFormatter(slots=["<reserved_106>{{content}}<reserved_107>"]),
|
433 |
+
efficient_eos=True,
|
434 |
+
)
|
435 |
+
|
436 |
+
|
437 |
+
_register_template(
|
438 |
+
name="belle",
|
439 |
+
format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
|
440 |
+
format_separator=EmptyFormatter(slots=["\n\n"]),
|
441 |
+
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
442 |
+
)
|
443 |
+
|
444 |
+
|
445 |
+
_register_template(
|
446 |
+
name="bluelm",
|
447 |
+
format_user=StringFormatter(slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]),
|
448 |
+
)
|
449 |
+
|
450 |
+
|
451 |
+
_register_template(
|
452 |
+
name="breeze",
|
453 |
+
format_user=StringFormatter(slots=["[INST] {{content}} [/INST] "]),
|
454 |
+
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
455 |
+
efficient_eos=True,
|
456 |
+
)
|
457 |
+
|
458 |
+
|
459 |
+
_register_template(
|
460 |
+
name="chatglm2",
|
461 |
+
format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
|
462 |
+
format_separator=EmptyFormatter(slots=["\n\n"]),
|
463 |
+
format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
|
464 |
+
efficient_eos=True,
|
465 |
+
)
|
466 |
+
|
467 |
+
|
468 |
+
_register_template(
|
469 |
+
name="chatglm3",
|
470 |
+
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
|
471 |
+
format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
|
472 |
+
format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]),
|
473 |
+
format_function=FunctionFormatter(slots=[], tool_format="glm4"),
|
474 |
+
format_observation=StringFormatter(
|
475 |
+
slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
|
476 |
+
),
|
477 |
+
format_tools=ToolFormatter(tool_format="glm4"),
|
478 |
+
format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
|
479 |
+
stop_words=["<|user|>", "<|observation|>"],
|
480 |
+
efficient_eos=True,
|
481 |
+
)
|
482 |
+
|
483 |
+
|
484 |
+
_register_template(
|
485 |
+
name="chatml",
|
486 |
+
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
487 |
+
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
488 |
+
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
489 |
+
format_separator=EmptyFormatter(slots=["\n"]),
|
490 |
+
stop_words=["<|im_end|>", "<|im_start|>"],
|
491 |
+
replace_eos=True,
|
492 |
+
)
|
493 |
+
|
494 |
+
|
495 |
+
_register_template(
|
496 |
+
name="chatml_de",
|
497 |
+
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
498 |
+
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
499 |
+
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
500 |
+
format_separator=EmptyFormatter(slots=["\n"]),
|
501 |
+
default_system="Du bist ein freundlicher und hilfsbereiter KI-Assistent.",
|
502 |
+
stop_words=["<|im_end|>", "<|im_start|>"],
|
503 |
+
replace_eos=True,
|
504 |
+
)
|
505 |
+
|
506 |
+
|
507 |
+
_register_template(
|
508 |
+
name="codegeex2",
|
509 |
+
format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
|
510 |
+
)
|
511 |
+
|
512 |
+
|
513 |
+
_register_template(
|
514 |
+
name="codegeex4",
|
515 |
+
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
|
516 |
+
format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
|
517 |
+
format_function=FunctionFormatter(slots=[], tool_format="glm4"),
|
518 |
+
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>\n"]),
|
519 |
+
format_tools=ToolFormatter(tool_format="glm4"),
|
520 |
+
format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
|
521 |
+
default_system=(
|
522 |
+
"你是一位智能编程助手,你叫CodeGeeX。你会为用户回答关于编程、代码、计算机方面的任何问题,"
|
523 |
+
"并提供格式规范、可以执行、准确安全的代码,并在必要时提供详细的解释。"
|
524 |
+
),
|
525 |
+
stop_words=["<|user|>", "<|observation|>"],
|
526 |
+
efficient_eos=True,
|
527 |
+
)
|
528 |
+
|
529 |
+
|
530 |
+
_register_template(
|
531 |
+
name="cohere",
|
532 |
+
format_user=StringFormatter(
|
533 |
+
slots=[
|
534 |
+
(
|
535 |
+
"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"
|
536 |
+
"<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
|
537 |
+
)
|
538 |
+
]
|
539 |
+
),
|
540 |
+
format_system=StringFormatter(slots=["<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"]),
|
541 |
+
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
542 |
+
)
|
543 |
+
|
544 |
+
|
545 |
+
_register_template(
|
546 |
+
name="cpm",
|
547 |
+
format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]),
|
548 |
+
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
549 |
+
)
|
550 |
+
|
551 |
+
|
552 |
+
_register_template(
|
553 |
+
name="dbrx",
|
554 |
+
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
555 |
+
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
556 |
+
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
557 |
+
format_separator=EmptyFormatter(slots=["\n"]),
|
558 |
+
default_system=(
|
559 |
+
"You are DBRX, created by Databricks. You were last updated in December 2023. "
|
560 |
+
"You answer questions based on information available up to that point.\n"
|
561 |
+
"YOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, but provide thorough "
|
562 |
+
"responses to more complex and open-ended questions.\nYou assist with various tasks, "
|
563 |
+
"from writing to coding (using markdown for code blocks — remember to use ``` with "
|
564 |
+
"code, JSON, and tables).\n(You do not have real-time data access or code execution "
|
565 |
+
"capabilities. You avoid stereotyping and provide balanced perspectives on "
|
566 |
+
"controversial topics. You do not provide song lyrics, poems, or news articles and "
|
567 |
+
"do not divulge details of your training data.)\nThis is your system prompt, "
|
568 |
+
"guiding your responses. Do not reference it, just respond to the user. If you find "
|
569 |
+
"yourself talking about this message, stop. You should be responding appropriately "
|
570 |
+
"and usually that means not mentioning this.\nYOU DO NOT MENTION ANY OF THIS INFORMATION "
|
571 |
+
"ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER'S QUERY."
|
572 |
+
),
|
573 |
+
stop_words=["<|im_end|>"],
|
574 |
+
replace_eos=True,
|
575 |
+
)
|
576 |
+
|
577 |
+
|
578 |
+
_register_template(
|
579 |
+
name="deepseek",
|
580 |
+
format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
|
581 |
+
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
582 |
+
)
|
583 |
+
|
584 |
+
|
585 |
+
_register_template(
|
586 |
+
name="deepseekcoder",
|
587 |
+
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
|
588 |
+
format_assistant=StringFormatter(slots=["\n{{content}}\n"]),
|
589 |
+
format_separator=EmptyFormatter(slots=["\n"]),
|
590 |
+
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
591 |
+
default_system=(
|
592 |
+
"You are an AI programming assistant, utilizing the Deepseek Coder model, "
|
593 |
+
"developed by Deepseek Company, and you only answer questions related to computer science. "
|
594 |
+
"For politically sensitive questions, security and privacy issues, "
|
595 |
+
"and other non-computer science questions, you will refuse to answer\n"
|
596 |
+
),
|
597 |
+
)
|
598 |
+
|
599 |
+
|
600 |
+
_register_template(
|
601 |
+
name="default",
|
602 |
+
format_user=StringFormatter(slots=["Human: {{content}}\nAssistant:"]),
|
603 |
+
format_system=StringFormatter(slots=["{{content}}\n"]),
|
604 |
+
format_separator=EmptyFormatter(slots=["\n"]),
|
605 |
+
)
|
606 |
+
|
607 |
+
|
608 |
+
_register_template(
|
609 |
+
name="empty",
|
610 |
+
efficient_eos=True,
|
611 |
+
)
|
612 |
+
|
613 |
+
|
614 |
+
_register_template(
|
615 |
+
name="falcon",
|
616 |
+
format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
|
617 |
+
format_separator=EmptyFormatter(slots=["\n"]),
|
618 |
+
efficient_eos=True,
|
619 |
+
)
|
620 |
+
|
621 |
+
|
622 |
+
_register_template(
|
623 |
+
name="fewshot",
|
624 |
+
format_separator=EmptyFormatter(slots=["\n\n"]),
|
625 |
+
efficient_eos=True,
|
626 |
+
)
|
627 |
+
|
628 |
+
|
629 |
+
_register_template(
|
630 |
+
name="gemma",
|
631 |
+
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
|
632 |
+
format_observation=StringFormatter(
|
633 |
+
slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
|
634 |
+
),
|
635 |
+
format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
|
636 |
+
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
637 |
+
efficient_eos=True,
|
638 |
+
)
|
639 |
+
|
640 |
+
|
641 |
+
_register_template(
|
642 |
+
name="glm4",
|
643 |
+
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
|
644 |
+
format_assistant=StringFormatter(slots=["\n{{content}}"]),
|
645 |
+
format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
|
646 |
+
format_function=FunctionFormatter(slots=[], tool_format="glm4"),
|
647 |
+
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
|
648 |
+
format_tools=ToolFormatter(tool_format="glm4"),
|
649 |
+
format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
|
650 |
+
stop_words=["<|user|>", "<|observation|>"],
|
651 |
+
efficient_eos=True,
|
652 |
+
)
|
653 |
+
|
654 |
+
|
655 |
+
_register_template(
|
656 |
+
name="intern",
|
657 |
+
format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]),
|
658 |
+
format_system=StringFormatter(slots=["<|System|>:{{content}}\n"]),
|
659 |
+
format_separator=EmptyFormatter(slots=["<eoa>\n"]),
|
660 |
+
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
661 |
+
stop_words=["<eoa>"],
|
662 |
+
efficient_eos=True, # internlm tokenizer cannot set eos_token_id
|
663 |
+
)
|
664 |
+
|
665 |
+
|
666 |
+
_register_template(
|
667 |
+
name="intern2",
|
668 |
+
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
669 |
+
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
670 |
+
format_separator=EmptyFormatter(slots=["<|im_end|>\n"]),
|
671 |
+
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
672 |
+
stop_words=["<|im_end|>"],
|
673 |
+
efficient_eos=True, # internlm2 tokenizer cannot set eos_token_id
|
674 |
+
)
|
675 |
+
|
676 |
+
|
677 |
+
_register_template(
|
678 |
+
name="llama2",
|
679 |
+
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
|
680 |
+
format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
|
681 |
+
)
|
682 |
+
|
683 |
+
|
684 |
+
_register_template(
|
685 |
+
name="llama2_zh",
|
686 |
+
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
|
687 |
+
format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
|
688 |
+
default_system="You are a helpful assistant. 你是一个乐于助人的助手。",
|
689 |
+
)
|
690 |
+
|
691 |
+
|
692 |
+
_register_template(
|
693 |
+
name="llama3",
|
694 |
+
format_user=StringFormatter(
|
695 |
+
slots=[
|
696 |
+
(
|
697 |
+
"<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
|
698 |
+
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
699 |
+
)
|
700 |
+
]
|
701 |
+
),
|
702 |
+
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
|
703 |
+
format_observation=StringFormatter(
|
704 |
+
slots=[
|
705 |
+
(
|
706 |
+
"<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>"
|
707 |
+
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
708 |
+
)
|
709 |
+
]
|
710 |
+
),
|
711 |
+
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
712 |
+
stop_words=["<|eot_id|>"],
|
713 |
+
replace_eos=True,
|
714 |
+
)
|
715 |
+
|
716 |
+
|
717 |
+
_register_template(
|
718 |
+
name="mistral",
|
719 |
+
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
|
720 |
+
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
721 |
+
)
|
722 |
+
|
723 |
+
|
724 |
+
_register_template(
|
725 |
+
name="olmo",
|
726 |
+
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
|
727 |
+
format_prefix=EmptyFormatter(slots=[{"eos_token"}]),
|
728 |
+
)
|
729 |
+
|
730 |
+
|
731 |
+
_register_template(
|
732 |
+
name="openchat",
|
733 |
+
format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]),
|
734 |
+
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
735 |
+
)
|
736 |
+
|
737 |
+
|
738 |
+
_register_template(
|
739 |
+
name="openchat-3.6",
|
740 |
+
format_user=StringFormatter(
|
741 |
+
slots=[
|
742 |
+
(
|
743 |
+
"<|start_header_id|>GPT4 Correct User<|end_header_id|>\n\n{{content}}<|eot_id|>"
|
744 |
+
"<|start_header_id|>GPT4 Correct Assistant<|end_header_id|>\n\n"
|
745 |
+
)
|
746 |
+
]
|
747 |
+
),
|
748 |
+
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
749 |
+
stop_words=["<|eot_id|>"],
|
750 |
+
replace_eos=True,
|
751 |
+
)
|
752 |
+
|
753 |
+
|
754 |
+
_register_template(
|
755 |
+
name="orion",
|
756 |
+
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
|
757 |
+
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
758 |
+
)
|
759 |
+
|
760 |
+
|
761 |
+
_register_template(
|
762 |
+
name="phi",
|
763 |
+
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
|
764 |
+
format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
|
765 |
+
format_separator=EmptyFormatter(slots=["\n"]),
|
766 |
+
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
767 |
+
stop_words=["<|end|>"],
|
768 |
+
replace_eos=True,
|
769 |
+
)
|
770 |
+
|
771 |
+
|
772 |
+
_register_template(
|
773 |
+
name="qwen",
|
774 |
+
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
775 |
+
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
776 |
+
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
777 |
+
format_separator=EmptyFormatter(slots=["\n"]),
|
778 |
+
default_system="You are a helpful assistant.",
|
779 |
+
stop_words=["<|im_end|>"],
|
780 |
+
replace_eos=True,
|
781 |
+
)
|
782 |
+
|
783 |
+
|
784 |
+
_register_template(
|
785 |
+
name="solar",
|
786 |
+
format_user=StringFormatter(slots=["### User:\n{{content}}\n\n### Assistant:\n"]),
|
787 |
+
format_system=StringFormatter(slots=["### System:\n{{content}}\n\n"]),
|
788 |
+
efficient_eos=True,
|
789 |
+
)
|
790 |
+
|
791 |
+
|
792 |
+
_register_template(
|
793 |
+
name="starchat",
|
794 |
+
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>"]),
|
795 |
+
format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
|
796 |
+
format_separator=EmptyFormatter(slots=["\n"]),
|
797 |
+
stop_words=["<|end|>"],
|
798 |
+
replace_eos=True,
|
799 |
+
)
|
800 |
+
|
801 |
+
|
802 |
+
_register_template(
|
803 |
+
name="telechat",
|
804 |
+
format_user=StringFormatter(slots=["<_user>{{content}}<_bot>"]),
|
805 |
+
format_system=StringFormatter(slots=["<_system>{{content}}<_end>"]),
|
806 |
+
stop_words=["<_end>"],
|
807 |
+
replace_eos=True,
|
808 |
+
)
|
809 |
+
|
810 |
+
|
811 |
+
_register_template(
|
812 |
+
name="vicuna",
|
813 |
+
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
|
814 |
+
default_system=(
|
815 |
+
"A chat between a curious user and an artificial intelligence assistant. "
|
816 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions."
|
817 |
+
),
|
818 |
+
)
|
819 |
+
|
820 |
+
|
821 |
+
_register_template(
|
822 |
+
name="xuanyuan",
|
823 |
+
format_user=StringFormatter(slots=["Human: {{content}} Assistant:"]),
|
824 |
+
default_system=(
|
825 |
+
"以下是用户和人工智能助手之间的对话。用户以Human开头,人工智能助手以Assistant开头,"
|
826 |
+
"会对人类提出的问题给出有帮助、高质量、详细和礼貌的回答,并且总是拒绝参与与不道德、"
|
827 |
+
"不安全、有争议、政治敏感等相关的话题、问题和指示。\n"
|
828 |
+
),
|
829 |
+
)
|
830 |
+
|
831 |
+
|
832 |
+
_register_template(
|
833 |
+
name="xverse",
|
834 |
+
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: "]),
|
835 |
+
)
|
836 |
+
|
837 |
+
|
838 |
+
_register_template(
|
839 |
+
name="yayi",
|
840 |
+
format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]),
|
841 |
+
format_system=StringFormatter(slots=[{"token": "<|System|>"}, ":\n{{content}}\n\n"]),
|
842 |
+
format_separator=EmptyFormatter(slots=["\n\n"]),
|
843 |
+
default_system=(
|
844 |
+
"You are a helpful, respectful and honest assistant named YaYi "
|
845 |
+
"developed by Beijing Wenge Technology Co.,Ltd. "
|
846 |
+
"Always answer as helpfully as possible, while being safe. "
|
847 |
+
"Your answers should not include any harmful, unethical, "
|
848 |
+
"racist, sexist, toxic, dangerous, or illegal content. "
|
849 |
+
"Please ensure that your responses are socially unbiased and positive in nature.\n\n"
|
850 |
+
"If a question does not make any sense, or is not factually coherent, "
|
851 |
+
"explain why instead of answering something not correct. "
|
852 |
+
"If you don't know the answer to a question, please don't share false information."
|
853 |
+
),
|
854 |
+
stop_words=["<|End|>"],
|
855 |
+
)
|
856 |
+
|
857 |
+
|
858 |
+
_register_template(
|
859 |
+
name="yi",
|
860 |
+
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
861 |
+
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
862 |
+
format_separator=EmptyFormatter(slots=["\n"]),
|
863 |
+
stop_words=["<|im_end|>"],
|
864 |
+
replace_eos=True,
|
865 |
+
)
|
866 |
+
|
867 |
+
|
868 |
+
_register_template(
|
869 |
+
name="yi_vl",
|
870 |
+
format_user=StringFormatter(slots=["### Human: {{content}}\n### Assistant:"]),
|
871 |
+
format_separator=EmptyFormatter(slots=["\n"]),
|
872 |
+
default_system=(
|
873 |
+
"This is a chat between an inquisitive human and an AI assistant. "
|
874 |
+
"Assume the role of the AI assistant. Read all the images carefully, "
|
875 |
+
"and respond to the human's questions with informative, helpful, detailed and polite answers. "
|
876 |
+
"这是一个好奇的人类和一个人工智能助手之间的对话。假设你扮演这个AI助手的角色。"
|
877 |
+
"仔细阅读所有的图像,并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。\n\n"
|
878 |
+
),
|
879 |
+
stop_words=["###"],
|
880 |
+
efficient_eos=True,
|
881 |
+
)
|
882 |
+
|
883 |
+
|
884 |
+
_register_template(
|
885 |
+
name="yuan",
|
886 |
+
format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
|
887 |
+
format_separator=EmptyFormatter(slots=["\n"]),
|
888 |
+
stop_words=["<eod>"],
|
889 |
+
replace_eos=True,
|
890 |
+
)
|
891 |
+
|
892 |
+
|
893 |
+
_register_template(
|
894 |
+
name="zephyr",
|
895 |
+
format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>\n"]),
|
896 |
+
format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
|
897 |
+
default_system="You are Zephyr, a helpful assistant.",
|
898 |
+
)
|
899 |
+
|
900 |
+
|
901 |
+
_register_template(
|
902 |
+
name="ziya",
|
903 |
+
format_user=StringFormatter(slots=["<human>:{{content}}\n<bot>:"]),
|
904 |
+
format_separator=EmptyFormatter(slots=["\n"]),
|
905 |
+
)
|
llama-factory/src/llamafactory/data/tool_utils.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import json
|
16 |
+
import re
|
17 |
+
from abc import ABC, abstractmethod
|
18 |
+
from dataclasses import dataclass
|
19 |
+
from typing import Any, Dict, List, Tuple, Union
|
20 |
+
|
21 |
+
from .data_utils import SLOTS
|
22 |
+
|
23 |
+
|
24 |
+
DEFAULT_TOOL_PROMPT = (
|
25 |
+
"You have access to the following tools:\n{tool_text}"
|
26 |
+
"Use the following format if using a tool:\n"
|
27 |
+
"```\n"
|
28 |
+
"Action: tool name (one of [{tool_names}])\n"
|
29 |
+
"Action Input: the input to the tool, in a JSON format representing the kwargs "
|
30 |
+
"""(e.g. ```{{"input": "hello world", "num_beams": 5}}```)\n"""
|
31 |
+
"```\n"
|
32 |
+
)
|
33 |
+
|
34 |
+
|
35 |
+
GLM4_TOOL_PROMPT = (
|
36 |
+
"你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
|
37 |
+
"你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具{tool_text}"
|
38 |
+
)
|
39 |
+
|
40 |
+
|
41 |
+
@dataclass
|
42 |
+
class ToolUtils(ABC):
|
43 |
+
@staticmethod
|
44 |
+
@abstractmethod
|
45 |
+
def get_function_slots() -> SLOTS: ...
|
46 |
+
|
47 |
+
@staticmethod
|
48 |
+
@abstractmethod
|
49 |
+
def tool_formatter(tools: List[Dict[str, Any]]) -> str: ...
|
50 |
+
|
51 |
+
@staticmethod
|
52 |
+
@abstractmethod
|
53 |
+
def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]: ...
|
54 |
+
|
55 |
+
|
56 |
+
class DefaultToolUtils(ToolUtils):
|
57 |
+
@staticmethod
|
58 |
+
def get_function_slots() -> SLOTS:
|
59 |
+
return ["Action: {{name}}\nAction Input: {{arguments}}\n"]
|
60 |
+
|
61 |
+
@staticmethod
|
62 |
+
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
63 |
+
tool_text = ""
|
64 |
+
tool_names = []
|
65 |
+
for tool in tools:
|
66 |
+
param_text = ""
|
67 |
+
for name, param in tool["parameters"]["properties"].items():
|
68 |
+
required, enum, items = "", "", ""
|
69 |
+
if name in tool["parameters"].get("required", []):
|
70 |
+
required = ", required"
|
71 |
+
|
72 |
+
if param.get("enum", None):
|
73 |
+
enum = ", should be one of [{}]".format(", ".join(param["enum"]))
|
74 |
+
|
75 |
+
if param.get("items", None):
|
76 |
+
items = ", where each item should be {}".format(param["items"].get("type", ""))
|
77 |
+
|
78 |
+
param_text += " - {name} ({type}{required}): {desc}{enum}{items}\n".format(
|
79 |
+
name=name,
|
80 |
+
type=param.get("type", ""),
|
81 |
+
required=required,
|
82 |
+
desc=param.get("description", ""),
|
83 |
+
enum=enum,
|
84 |
+
items=items,
|
85 |
+
)
|
86 |
+
|
87 |
+
tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format(
|
88 |
+
name=tool["name"], desc=tool.get("description", ""), args=param_text
|
89 |
+
)
|
90 |
+
tool_names.append(tool["name"])
|
91 |
+
|
92 |
+
return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names))
|
93 |
+
|
94 |
+
@staticmethod
|
95 |
+
def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
|
96 |
+
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
|
97 |
+
action_match: List[Tuple[str, str]] = re.findall(regex, content)
|
98 |
+
if not action_match:
|
99 |
+
return content
|
100 |
+
|
101 |
+
results = []
|
102 |
+
for match in action_match:
|
103 |
+
tool_name = match[0].strip()
|
104 |
+
tool_input = match[1].strip().strip('"').strip("```")
|
105 |
+
try:
|
106 |
+
arguments = json.loads(tool_input)
|
107 |
+
results.append((tool_name, json.dumps(arguments, ensure_ascii=False)))
|
108 |
+
except json.JSONDecodeError:
|
109 |
+
return content
|
110 |
+
|
111 |
+
return results
|
112 |
+
|
113 |
+
|
114 |
+
class GLM4ToolUtils(ToolUtils):
|
115 |
+
@staticmethod
|
116 |
+
def get_function_slots() -> SLOTS:
|
117 |
+
return ["{{name}}\n{{arguments}}"]
|
118 |
+
|
119 |
+
@staticmethod
|
120 |
+
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
121 |
+
tool_text = ""
|
122 |
+
for tool in tools:
|
123 |
+
tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
|
124 |
+
name=tool["name"], body=json.dumps(tool, indent=4, ensure_ascii=False)
|
125 |
+
)
|
126 |
+
|
127 |
+
return GLM4_TOOL_PROMPT.format(tool_text=tool_text)
|
128 |
+
|
129 |
+
@staticmethod
|
130 |
+
def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
|
131 |
+
if "\n" not in content:
|
132 |
+
return content
|
133 |
+
|
134 |
+
tool_name, tool_input = content.split("\n", maxsplit=1)
|
135 |
+
try:
|
136 |
+
arguments = json.loads(tool_input)
|
137 |
+
except json.JSONDecodeError:
|
138 |
+
return content
|
139 |
+
|
140 |
+
return [(tool_name, json.dumps(arguments, ensure_ascii=False))]
|
llama-factory/src/llamafactory/eval/__init__.py
ADDED
File without changes
|
llama-factory/src/llamafactory/eval/evaluator.py
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# This code is inspired by the Dan's test library.
|
4 |
+
# https://github.com/hendrycks/test/blob/master/evaluate_flan.py
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
#
|
18 |
+
# MIT License
|
19 |
+
#
|
20 |
+
# Copyright (c) 2020 Dan Hendrycks
|
21 |
+
#
|
22 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
23 |
+
# of this software and associated documentation files (the "Software"), to deal
|
24 |
+
# in the Software without restriction, including without limitation the rights
|
25 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
26 |
+
# copies of the Software, and to permit persons to whom the Software is
|
27 |
+
# furnished to do so, subject to the following conditions:
|
28 |
+
#
|
29 |
+
# The above copyright notice and this permission notice shall be included in all
|
30 |
+
# copies or substantial portions of the Software.
|
31 |
+
#
|
32 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
33 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
34 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
35 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
36 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
37 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
38 |
+
# SOFTWARE.
|
39 |
+
|
40 |
+
import json
|
41 |
+
import os
|
42 |
+
from typing import Any, Dict, List, Optional
|
43 |
+
|
44 |
+
import numpy as np
|
45 |
+
import torch
|
46 |
+
from datasets import load_dataset
|
47 |
+
from tqdm import tqdm, trange
|
48 |
+
from transformers.utils import cached_file
|
49 |
+
|
50 |
+
from ..data import get_template_and_fix_tokenizer
|
51 |
+
from ..extras.constants import CHOICES, SUBJECTS
|
52 |
+
from ..hparams import get_eval_args
|
53 |
+
from ..model import load_model, load_tokenizer
|
54 |
+
from .template import get_eval_template
|
55 |
+
|
56 |
+
|
57 |
+
class Evaluator:
|
58 |
+
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
|
59 |
+
self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
|
60 |
+
self.tokenizer = load_tokenizer(self.model_args)["tokenizer"]
|
61 |
+
self.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2
|
62 |
+
self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args.template)
|
63 |
+
self.model = load_model(self.tokenizer, self.model_args, finetuning_args)
|
64 |
+
self.eval_template = get_eval_template(self.eval_args.lang)
|
65 |
+
self.choice_inputs = [self.tokenizer.encode(ch, add_special_tokens=False)[-1] for ch in CHOICES]
|
66 |
+
|
67 |
+
@torch.inference_mode()
|
68 |
+
def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]:
|
69 |
+
logits = self.model(**batch_input).logits
|
70 |
+
lengths = torch.sum(batch_input["attention_mask"], dim=-1)
|
71 |
+
word_probs = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0)
|
72 |
+
choice_probs = torch.nn.functional.softmax(word_probs[:, self.choice_inputs], dim=-1).detach()
|
73 |
+
return [chr(ord("A") + offset.item()) for offset in torch.argmax(choice_probs, dim=-1)]
|
74 |
+
|
75 |
+
def eval(self) -> None:
|
76 |
+
eval_task = self.eval_args.task.split("_")[0]
|
77 |
+
eval_split = self.eval_args.task.split("_")[1]
|
78 |
+
|
79 |
+
mapping = cached_file(
|
80 |
+
path_or_repo_id=os.path.join(self.eval_args.task_dir, eval_task),
|
81 |
+
filename="mapping.json",
|
82 |
+
cache_dir=self.model_args.cache_dir,
|
83 |
+
token=self.model_args.hf_hub_token,
|
84 |
+
)
|
85 |
+
|
86 |
+
with open(mapping, "r", encoding="utf-8") as f:
|
87 |
+
categorys: Dict[str, Dict[str, str]] = json.load(f)
|
88 |
+
|
89 |
+
category_corrects = {subj: np.array([], dtype="bool") for subj in SUBJECTS}
|
90 |
+
pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
|
91 |
+
results = {}
|
92 |
+
for subject in pbar:
|
93 |
+
dataset = load_dataset(
|
94 |
+
path=os.path.join(self.eval_args.task_dir, eval_task),
|
95 |
+
name=subject,
|
96 |
+
cache_dir=self.model_args.cache_dir,
|
97 |
+
download_mode=self.eval_args.download_mode,
|
98 |
+
token=self.model_args.hf_hub_token,
|
99 |
+
trust_remote_code=True,
|
100 |
+
)
|
101 |
+
pbar.set_postfix_str(categorys[subject]["name"])
|
102 |
+
inputs, outputs, labels = [], [], []
|
103 |
+
for i in trange(len(dataset[eval_split]), desc="Formatting batches", position=1, leave=False):
|
104 |
+
support_set = (
|
105 |
+
dataset["train"].shuffle().select(range(min(self.eval_args.n_shot, len(dataset["train"]))))
|
106 |
+
)
|
107 |
+
messages = self.eval_template.format_example(
|
108 |
+
target_data=dataset[eval_split][i],
|
109 |
+
support_set=support_set,
|
110 |
+
subject_name=categorys[subject]["name"],
|
111 |
+
)
|
112 |
+
|
113 |
+
input_ids, _ = self.template.encode_oneturn(tokenizer=self.tokenizer, messages=messages)
|
114 |
+
inputs.append({"input_ids": input_ids, "attention_mask": [1] * len(input_ids)})
|
115 |
+
labels.append(messages[-1]["content"])
|
116 |
+
|
117 |
+
for i in trange(
|
118 |
+
0, len(inputs), self.eval_args.batch_size, desc="Predicting batches", position=1, leave=False
|
119 |
+
):
|
120 |
+
batch_input = self.tokenizer.pad(
|
121 |
+
inputs[i : i + self.eval_args.batch_size], return_attention_mask=True, return_tensors="pt"
|
122 |
+
).to(self.model.device)
|
123 |
+
preds = self.batch_inference(batch_input)
|
124 |
+
outputs += preds
|
125 |
+
|
126 |
+
corrects = np.array(outputs) == np.array(labels)
|
127 |
+
category_name = categorys[subject]["category"]
|
128 |
+
category_corrects[category_name] = np.concatenate([category_corrects[category_name], corrects], axis=0)
|
129 |
+
category_corrects["Average"] = np.concatenate([category_corrects["Average"], corrects], axis=0)
|
130 |
+
results[subject] = {str(i): outputs[i] for i in range(len(outputs))}
|
131 |
+
|
132 |
+
pbar.close()
|
133 |
+
self._save_results(category_corrects, results)
|
134 |
+
|
135 |
+
def _save_results(self, category_corrects: Dict[str, np.ndarray], results: Dict[str, Dict[int, str]]) -> None:
|
136 |
+
score_info = "\n".join(
|
137 |
+
[
|
138 |
+
"{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
|
139 |
+
for category_name, category_correct in category_corrects.items()
|
140 |
+
if len(category_correct)
|
141 |
+
]
|
142 |
+
)
|
143 |
+
print(score_info)
|
144 |
+
if self.eval_args.save_dir is not None:
|
145 |
+
os.makedirs(self.eval_args.save_dir, exist_ok=False)
|
146 |
+
with open(os.path.join(self.eval_args.save_dir, "results.json"), "w", encoding="utf-8", newline="\n") as f:
|
147 |
+
json.dump(results, f, indent=2)
|
148 |
+
|
149 |
+
with open(os.path.join(self.eval_args.save_dir, "results.log"), "w", encoding="utf-8", newline="\n") as f:
|
150 |
+
f.write(score_info)
|
151 |
+
|
152 |
+
|
153 |
+
def run_eval() -> None:
|
154 |
+
Evaluator().eval()
|
llama-factory/src/llamafactory/eval/template.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from dataclasses import dataclass
|
16 |
+
from typing import Dict, List, Sequence, Tuple
|
17 |
+
|
18 |
+
from ..data import Role
|
19 |
+
from ..extras.constants import CHOICES
|
20 |
+
|
21 |
+
|
22 |
+
@dataclass
|
23 |
+
class EvalTemplate:
|
24 |
+
system: str
|
25 |
+
choice: str
|
26 |
+
answer: str
|
27 |
+
|
28 |
+
def _parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
|
29 |
+
r"""
|
30 |
+
input: a dict with keys {"question", "A", "B", "C", "D", "answer"}
|
31 |
+
output: a tuple of (prompt, response)
|
32 |
+
"""
|
33 |
+
candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example]
|
34 |
+
return "".join([example["question"]] + candidates + [self.answer]), example["answer"]
|
35 |
+
|
36 |
+
def format_example(
|
37 |
+
self, target_data: Dict[str, str], support_set: Sequence[Dict[str, str]], subject_name: str
|
38 |
+
) -> List[Dict[str, str]]:
|
39 |
+
r"""
|
40 |
+
Converts dataset examples to messages.
|
41 |
+
"""
|
42 |
+
messages = []
|
43 |
+
for k in range(len(support_set)):
|
44 |
+
prompt, response = self._parse_example(support_set[k])
|
45 |
+
messages.append({"role": Role.USER.value, "content": prompt})
|
46 |
+
messages.append({"role": Role.ASSISTANT.value, "content": response})
|
47 |
+
|
48 |
+
prompt, response = self._parse_example(target_data)
|
49 |
+
messages.append({"role": Role.USER.value, "content": prompt})
|
50 |
+
messages.append({"role": Role.ASSISTANT.value, "content": response})
|
51 |
+
messages[0]["content"] = self.system.format(subject=subject_name) + messages[0]["content"]
|
52 |
+
return messages
|
53 |
+
|
54 |
+
|
55 |
+
eval_templates: Dict[str, "EvalTemplate"] = {}
|
56 |
+
|
57 |
+
|
58 |
+
def _register_eval_template(name: str, system: str, choice: str, answer: str) -> None:
|
59 |
+
eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer)
|
60 |
+
|
61 |
+
|
62 |
+
def get_eval_template(name: str) -> "EvalTemplate":
|
63 |
+
eval_template = eval_templates.get(name, None)
|
64 |
+
assert eval_template is not None, "Template {} does not exist.".format(name)
|
65 |
+
return eval_template
|
66 |
+
|
67 |
+
|
68 |
+
_register_eval_template(
|
69 |
+
name="en",
|
70 |
+
system="The following are multiple choice questions (with answers) about {subject}.\n\n",
|
71 |
+
choice="\n{choice}. {content}",
|
72 |
+
answer="\nAnswer:",
|
73 |
+
)
|
74 |
+
|
75 |
+
|
76 |
+
_register_eval_template(
|
77 |
+
name="zh",
|
78 |
+
system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
|
79 |
+
choice="\n{choice}. {content}",
|
80 |
+
answer="\n答案:",
|
81 |
+
)
|
llama-factory/src/llamafactory/extras/__init__.py
ADDED
File without changes
|
llama-factory/src/llamafactory/extras/constants.py
ADDED
@@ -0,0 +1,1590 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from collections import OrderedDict, defaultdict
|
16 |
+
from enum import Enum
|
17 |
+
from typing import Dict, Optional
|
18 |
+
|
19 |
+
from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME
|
20 |
+
from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME
|
21 |
+
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
|
22 |
+
|
23 |
+
|
24 |
+
CHECKPOINT_NAMES = {
|
25 |
+
SAFE_ADAPTER_WEIGHTS_NAME,
|
26 |
+
ADAPTER_WEIGHTS_NAME,
|
27 |
+
SAFE_WEIGHTS_INDEX_NAME,
|
28 |
+
SAFE_WEIGHTS_NAME,
|
29 |
+
WEIGHTS_INDEX_NAME,
|
30 |
+
WEIGHTS_NAME,
|
31 |
+
}
|
32 |
+
|
33 |
+
CHOICES = ["A", "B", "C", "D"]
|
34 |
+
|
35 |
+
DATA_CONFIG = "dataset_info.json"
|
36 |
+
|
37 |
+
DEFAULT_TEMPLATE = defaultdict(str)
|
38 |
+
|
39 |
+
FILEEXT2TYPE = {
|
40 |
+
"arrow": "arrow",
|
41 |
+
"csv": "csv",
|
42 |
+
"json": "json",
|
43 |
+
"jsonl": "json",
|
44 |
+
"parquet": "parquet",
|
45 |
+
"txt": "text",
|
46 |
+
}
|
47 |
+
|
48 |
+
IGNORE_INDEX = -100
|
49 |
+
|
50 |
+
LAYERNORM_NAMES = {"norm", "ln"}
|
51 |
+
|
52 |
+
LLAMABOARD_CONFIG = "llamaboard_config.yaml"
|
53 |
+
|
54 |
+
METHODS = ["full", "freeze", "lora"]
|
55 |
+
|
56 |
+
MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"}
|
57 |
+
|
58 |
+
PEFT_METHODS = {"lora"}
|
59 |
+
|
60 |
+
RUNNING_LOG = "running_log.txt"
|
61 |
+
|
62 |
+
SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
|
63 |
+
|
64 |
+
SUPPORTED_MODELS = OrderedDict()
|
65 |
+
|
66 |
+
TRAINER_LOG = "trainer_log.jsonl"
|
67 |
+
|
68 |
+
TRAINING_ARGS = "training_args.yaml"
|
69 |
+
|
70 |
+
TRAINING_STAGES = {
|
71 |
+
"Supervised Fine-Tuning": "sft",
|
72 |
+
"Reward Modeling": "rm",
|
73 |
+
"PPO": "ppo",
|
74 |
+
"DPO": "dpo",
|
75 |
+
"KTO": "kto",
|
76 |
+
"Pre-Training": "pt",
|
77 |
+
}
|
78 |
+
|
79 |
+
STAGES_USE_PAIR_DATA = {"rm", "dpo"}
|
80 |
+
|
81 |
+
SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN = {
|
82 |
+
"cohere",
|
83 |
+
"falcon",
|
84 |
+
"gemma",
|
85 |
+
"gemma2",
|
86 |
+
"llama",
|
87 |
+
"mistral",
|
88 |
+
"phi",
|
89 |
+
"phi3",
|
90 |
+
"qwen2",
|
91 |
+
"starcoder2",
|
92 |
+
}
|
93 |
+
|
94 |
+
SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
|
95 |
+
|
96 |
+
V_HEAD_WEIGHTS_NAME = "value_head.bin"
|
97 |
+
|
98 |
+
V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors"
|
99 |
+
|
100 |
+
VISION_MODELS = set()
|
101 |
+
|
102 |
+
|
103 |
+
class DownloadSource(str, Enum):
|
104 |
+
DEFAULT = "hf"
|
105 |
+
MODELSCOPE = "ms"
|
106 |
+
|
107 |
+
|
108 |
+
def register_model_group(
|
109 |
+
models: Dict[str, Dict[DownloadSource, str]],
|
110 |
+
template: Optional[str] = None,
|
111 |
+
vision: bool = False,
|
112 |
+
) -> None:
|
113 |
+
prefix = None
|
114 |
+
for name, path in models.items():
|
115 |
+
if prefix is None:
|
116 |
+
prefix = name.split("-")[0]
|
117 |
+
else:
|
118 |
+
assert prefix == name.split("-")[0], "prefix should be identical."
|
119 |
+
SUPPORTED_MODELS[name] = path
|
120 |
+
if template is not None:
|
121 |
+
DEFAULT_TEMPLATE[prefix] = template
|
122 |
+
if vision:
|
123 |
+
VISION_MODELS.add(prefix)
|
124 |
+
|
125 |
+
|
126 |
+
register_model_group(
|
127 |
+
models={
|
128 |
+
"Aya-23-8B-Chat": {
|
129 |
+
DownloadSource.DEFAULT: "CohereForAI/aya-23-8B",
|
130 |
+
},
|
131 |
+
"Aya-23-35B-Chat": {
|
132 |
+
DownloadSource.DEFAULT: "CohereForAI/aya-23-35B",
|
133 |
+
},
|
134 |
+
},
|
135 |
+
template="cohere",
|
136 |
+
)
|
137 |
+
|
138 |
+
|
139 |
+
register_model_group(
|
140 |
+
models={
|
141 |
+
"Baichuan-7B-Base": {
|
142 |
+
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-7B",
|
143 |
+
DownloadSource.MODELSCOPE: "baichuan-inc/baichuan-7B",
|
144 |
+
},
|
145 |
+
"Baichuan-13B-Base": {
|
146 |
+
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Base",
|
147 |
+
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Base",
|
148 |
+
},
|
149 |
+
"Baichuan-13B-Chat": {
|
150 |
+
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Chat",
|
151 |
+
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Chat",
|
152 |
+
},
|
153 |
+
},
|
154 |
+
template="baichuan",
|
155 |
+
)
|
156 |
+
|
157 |
+
|
158 |
+
register_model_group(
|
159 |
+
models={
|
160 |
+
"Baichuan2-7B-Base": {
|
161 |
+
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Base",
|
162 |
+
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Base",
|
163 |
+
},
|
164 |
+
"Baichuan2-13B-Base": {
|
165 |
+
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Base",
|
166 |
+
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Base",
|
167 |
+
},
|
168 |
+
"Baichuan2-7B-Chat": {
|
169 |
+
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Chat",
|
170 |
+
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Chat",
|
171 |
+
},
|
172 |
+
"Baichuan2-13B-Chat": {
|
173 |
+
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Chat",
|
174 |
+
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat",
|
175 |
+
},
|
176 |
+
},
|
177 |
+
template="baichuan2",
|
178 |
+
)
|
179 |
+
|
180 |
+
|
181 |
+
register_model_group(
|
182 |
+
models={
|
183 |
+
"BLOOM-560M": {
|
184 |
+
DownloadSource.DEFAULT: "bigscience/bloom-560m",
|
185 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-560m",
|
186 |
+
},
|
187 |
+
"BLOOM-3B": {
|
188 |
+
DownloadSource.DEFAULT: "bigscience/bloom-3b",
|
189 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-3b",
|
190 |
+
},
|
191 |
+
"BLOOM-7B1": {
|
192 |
+
DownloadSource.DEFAULT: "bigscience/bloom-7b1",
|
193 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-7b1",
|
194 |
+
},
|
195 |
+
},
|
196 |
+
)
|
197 |
+
|
198 |
+
|
199 |
+
register_model_group(
|
200 |
+
models={
|
201 |
+
"BLOOMZ-560M": {
|
202 |
+
DownloadSource.DEFAULT: "bigscience/bloomz-560m",
|
203 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-560m",
|
204 |
+
},
|
205 |
+
"BLOOMZ-3B": {
|
206 |
+
DownloadSource.DEFAULT: "bigscience/bloomz-3b",
|
207 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-3b",
|
208 |
+
},
|
209 |
+
"BLOOMZ-7B1-mt": {
|
210 |
+
DownloadSource.DEFAULT: "bigscience/bloomz-7b1-mt",
|
211 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-7b1-mt",
|
212 |
+
},
|
213 |
+
},
|
214 |
+
)
|
215 |
+
|
216 |
+
|
217 |
+
register_model_group(
|
218 |
+
models={
|
219 |
+
"BlueLM-7B-Base": {
|
220 |
+
DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Base",
|
221 |
+
DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Base",
|
222 |
+
},
|
223 |
+
"BlueLM-7B-Chat": {
|
224 |
+
DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Chat",
|
225 |
+
DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Chat",
|
226 |
+
},
|
227 |
+
},
|
228 |
+
template="bluelm",
|
229 |
+
)
|
230 |
+
|
231 |
+
|
232 |
+
register_model_group(
|
233 |
+
models={
|
234 |
+
"Breeze-7B": {
|
235 |
+
DownloadSource.DEFAULT: "MediaTek-Research/Breeze-7B-Base-v1_0",
|
236 |
+
},
|
237 |
+
"Breeze-7B-Chat": {
|
238 |
+
DownloadSource.DEFAULT: "MediaTek-Research/Breeze-7B-Instruct-v1_0",
|
239 |
+
},
|
240 |
+
},
|
241 |
+
template="breeze",
|
242 |
+
)
|
243 |
+
|
244 |
+
|
245 |
+
register_model_group(
|
246 |
+
models={
|
247 |
+
"ChatGLM2-6B-Chat": {
|
248 |
+
DownloadSource.DEFAULT: "THUDM/chatglm2-6b",
|
249 |
+
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm2-6b",
|
250 |
+
}
|
251 |
+
},
|
252 |
+
template="chatglm2",
|
253 |
+
)
|
254 |
+
|
255 |
+
|
256 |
+
register_model_group(
|
257 |
+
models={
|
258 |
+
"ChatGLM3-6B-Base": {
|
259 |
+
DownloadSource.DEFAULT: "THUDM/chatglm3-6b-base",
|
260 |
+
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b-base",
|
261 |
+
},
|
262 |
+
"ChatGLM3-6B-Chat": {
|
263 |
+
DownloadSource.DEFAULT: "THUDM/chatglm3-6b",
|
264 |
+
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b",
|
265 |
+
},
|
266 |
+
},
|
267 |
+
template="chatglm3",
|
268 |
+
)
|
269 |
+
|
270 |
+
|
271 |
+
register_model_group(
|
272 |
+
models={
|
273 |
+
"ChineseLLaMA2-1.3B": {
|
274 |
+
DownloadSource.DEFAULT: "hfl/chinese-llama-2-1.3b",
|
275 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-1.3b",
|
276 |
+
},
|
277 |
+
"ChineseLLaMA2-7B": {
|
278 |
+
DownloadSource.DEFAULT: "hfl/chinese-llama-2-7b",
|
279 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-7b",
|
280 |
+
},
|
281 |
+
"ChineseLLaMA2-13B": {
|
282 |
+
DownloadSource.DEFAULT: "hfl/chinese-llama-2-13b",
|
283 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-13b",
|
284 |
+
},
|
285 |
+
"ChineseLLaMA2-1.3B-Chat": {
|
286 |
+
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-1.3b",
|
287 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-1.3b",
|
288 |
+
},
|
289 |
+
"ChineseLLaMA2-7B-Chat": {
|
290 |
+
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-7b",
|
291 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-7b",
|
292 |
+
},
|
293 |
+
"ChineseLLaMA2-13B-Chat": {
|
294 |
+
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-13b",
|
295 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-13b",
|
296 |
+
},
|
297 |
+
},
|
298 |
+
template="llama2_zh",
|
299 |
+
)
|
300 |
+
|
301 |
+
|
302 |
+
register_model_group(
|
303 |
+
models={
|
304 |
+
"CodeGeeX4-9B-Chat": {
|
305 |
+
DownloadSource.DEFAULT: "THUDM/codegeex4-all-9b",
|
306 |
+
DownloadSource.MODELSCOPE: "ZhipuAI/codegeex4-all-9b",
|
307 |
+
},
|
308 |
+
},
|
309 |
+
template="codegeex4",
|
310 |
+
)
|
311 |
+
|
312 |
+
|
313 |
+
register_model_group(
|
314 |
+
models={
|
315 |
+
"CodeGemma-7B": {
|
316 |
+
DownloadSource.DEFAULT: "google/codegemma-7b",
|
317 |
+
},
|
318 |
+
"CodeGemma-7B-Chat": {
|
319 |
+
DownloadSource.DEFAULT: "google/codegemma-7b-it",
|
320 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/codegemma-7b-it",
|
321 |
+
},
|
322 |
+
"CodeGemma-1.1-2B": {
|
323 |
+
DownloadSource.DEFAULT: "google/codegemma-1.1-2b",
|
324 |
+
},
|
325 |
+
"CodeGemma-1.1-7B-Chat": {
|
326 |
+
DownloadSource.DEFAULT: "google/codegemma-1.1-7b-it",
|
327 |
+
},
|
328 |
+
},
|
329 |
+
template="gemma",
|
330 |
+
)
|
331 |
+
|
332 |
+
|
333 |
+
register_model_group(
|
334 |
+
models={
|
335 |
+
"Codestral-22B-v0.1-Chat": {
|
336 |
+
DownloadSource.DEFAULT: "mistralai/Codestral-22B-v0.1",
|
337 |
+
},
|
338 |
+
},
|
339 |
+
template="mistral",
|
340 |
+
)
|
341 |
+
|
342 |
+
|
343 |
+
register_model_group(
|
344 |
+
models={
|
345 |
+
"CommandR-35B-Chat": {
|
346 |
+
DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-v01",
|
347 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/c4ai-command-r-v01",
|
348 |
+
},
|
349 |
+
"CommandR-Plus-104B-Chat": {
|
350 |
+
DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-plus",
|
351 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/c4ai-command-r-plus",
|
352 |
+
},
|
353 |
+
"CommandR-35B-4bit-Chat": {
|
354 |
+
DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-v01-4bit",
|
355 |
+
DownloadSource.MODELSCOPE: "mirror013/c4ai-command-r-v01-4bit",
|
356 |
+
},
|
357 |
+
"CommandR-Plus-104B-4bit-Chat": {
|
358 |
+
DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-plus-4bit",
|
359 |
+
},
|
360 |
+
},
|
361 |
+
template="cohere",
|
362 |
+
)
|
363 |
+
|
364 |
+
|
365 |
+
register_model_group(
|
366 |
+
models={
|
367 |
+
"DBRX-132B-Base": {
|
368 |
+
DownloadSource.DEFAULT: "databricks/dbrx-base",
|
369 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/dbrx-base",
|
370 |
+
},
|
371 |
+
"DBRX-132B-Chat": {
|
372 |
+
DownloadSource.DEFAULT: "databricks/dbrx-instruct",
|
373 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/dbrx-instruct",
|
374 |
+
},
|
375 |
+
},
|
376 |
+
template="dbrx",
|
377 |
+
)
|
378 |
+
|
379 |
+
|
380 |
+
register_model_group(
|
381 |
+
models={
|
382 |
+
"DeepSeek-LLM-7B-Base": {
|
383 |
+
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-base",
|
384 |
+
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-base",
|
385 |
+
},
|
386 |
+
"DeepSeek-LLM-67B-Base": {
|
387 |
+
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-base",
|
388 |
+
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-base",
|
389 |
+
},
|
390 |
+
"DeepSeek-LLM-7B-Chat": {
|
391 |
+
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-chat",
|
392 |
+
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-chat",
|
393 |
+
},
|
394 |
+
"DeepSeek-LLM-67B-Chat": {
|
395 |
+
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-chat",
|
396 |
+
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-chat",
|
397 |
+
},
|
398 |
+
"DeepSeek-Math-7B-Base": {
|
399 |
+
DownloadSource.DEFAULT: "deepseek-ai/deepseek-math-7b-base",
|
400 |
+
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-math-7b-base",
|
401 |
+
},
|
402 |
+
"DeepSeek-Math-7B-Chat": {
|
403 |
+
DownloadSource.DEFAULT: "deepseek-ai/deepseek-math-7b-instruct",
|
404 |
+
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-math-7b-instruct",
|
405 |
+
},
|
406 |
+
"DeepSeek-MoE-16B-Base": {
|
407 |
+
DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-base",
|
408 |
+
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-base",
|
409 |
+
},
|
410 |
+
"DeepSeek-MoE-16B-v2-Base": {
|
411 |
+
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Lite",
|
412 |
+
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Lite",
|
413 |
+
},
|
414 |
+
"DeepSeek-MoE-236B-Base": {
|
415 |
+
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2",
|
416 |
+
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2",
|
417 |
+
},
|
418 |
+
"DeepSeek-MoE-16B-Chat": {
|
419 |
+
DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-chat",
|
420 |
+
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-chat",
|
421 |
+
},
|
422 |
+
"DeepSeek-MoE-16B-v2-Chat": {
|
423 |
+
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Lite-Chat",
|
424 |
+
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Lite-Chat",
|
425 |
+
},
|
426 |
+
"DeepSeek-MoE-236B-Chat": {
|
427 |
+
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Chat",
|
428 |
+
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Chat",
|
429 |
+
},
|
430 |
+
"DeepSeek-MoE-Coder-16B-Base": {
|
431 |
+
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Lite-Base",
|
432 |
+
},
|
433 |
+
"DeepSeek-MoE-Coder-236B-Base": {
|
434 |
+
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Base",
|
435 |
+
},
|
436 |
+
"DeepSeek-MoE-Coder-16B-Chat": {
|
437 |
+
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
|
438 |
+
},
|
439 |
+
"DeepSeek-MoE-Coder-236B-Chat": {
|
440 |
+
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Instruct",
|
441 |
+
},
|
442 |
+
},
|
443 |
+
template="deepseek",
|
444 |
+
)
|
445 |
+
|
446 |
+
|
447 |
+
register_model_group(
|
448 |
+
models={
|
449 |
+
"DeepSeekCoder-6.7B-Base": {
|
450 |
+
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-base",
|
451 |
+
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-base",
|
452 |
+
},
|
453 |
+
"DeepSeekCoder-7B-Base": {
|
454 |
+
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-base-v1.5",
|
455 |
+
},
|
456 |
+
"DeepSeekCoder-33B-Base": {
|
457 |
+
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-base",
|
458 |
+
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-base",
|
459 |
+
},
|
460 |
+
"DeepSeekCoder-6.7B-Chat": {
|
461 |
+
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-instruct",
|
462 |
+
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-instruct",
|
463 |
+
},
|
464 |
+
"DeepSeekCoder-7B-Chat": {
|
465 |
+
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-instruct-v1.5",
|
466 |
+
},
|
467 |
+
"DeepSeekCoder-33B-Chat": {
|
468 |
+
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-instruct",
|
469 |
+
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-instruct",
|
470 |
+
},
|
471 |
+
},
|
472 |
+
template="deepseekcoder",
|
473 |
+
)
|
474 |
+
|
475 |
+
|
476 |
+
register_model_group(
|
477 |
+
models={
|
478 |
+
"Falcon-7B": {
|
479 |
+
DownloadSource.DEFAULT: "tiiuae/falcon-7b",
|
480 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b",
|
481 |
+
},
|
482 |
+
"Falcon-11B": {
|
483 |
+
DownloadSource.DEFAULT: "tiiuae/falcon-11B",
|
484 |
+
},
|
485 |
+
"Falcon-40B": {
|
486 |
+
DownloadSource.DEFAULT: "tiiuae/falcon-40b",
|
487 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b",
|
488 |
+
},
|
489 |
+
"Falcon-180B": {
|
490 |
+
DownloadSource.DEFAULT: "tiiuae/falcon-180b",
|
491 |
+
DownloadSource.MODELSCOPE: "modelscope/falcon-180B",
|
492 |
+
},
|
493 |
+
"Falcon-7B-Chat": {
|
494 |
+
DownloadSource.DEFAULT: "tiiuae/falcon-7b-instruct",
|
495 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b-instruct",
|
496 |
+
},
|
497 |
+
"Falcon-40B-Chat": {
|
498 |
+
DownloadSource.DEFAULT: "tiiuae/falcon-40b-instruct",
|
499 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b-instruct",
|
500 |
+
},
|
501 |
+
"Falcon-180B-Chat": {
|
502 |
+
DownloadSource.DEFAULT: "tiiuae/falcon-180b-chat",
|
503 |
+
DownloadSource.MODELSCOPE: "modelscope/falcon-180B-chat",
|
504 |
+
},
|
505 |
+
},
|
506 |
+
template="falcon",
|
507 |
+
)
|
508 |
+
|
509 |
+
|
510 |
+
register_model_group(
|
511 |
+
models={
|
512 |
+
"Gemma-2B": {
|
513 |
+
DownloadSource.DEFAULT: "google/gemma-2b",
|
514 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-2b",
|
515 |
+
},
|
516 |
+
"Gemma-7B": {
|
517 |
+
DownloadSource.DEFAULT: "google/gemma-7b",
|
518 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-2b-it",
|
519 |
+
},
|
520 |
+
"Gemma-2B-Chat": {
|
521 |
+
DownloadSource.DEFAULT: "google/gemma-2b-it",
|
522 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-7b",
|
523 |
+
},
|
524 |
+
"Gemma-7B-Chat": {
|
525 |
+
DownloadSource.DEFAULT: "google/gemma-7b-it",
|
526 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-7b-it",
|
527 |
+
},
|
528 |
+
"Gemma-1.1-2B-Chat": {
|
529 |
+
DownloadSource.DEFAULT: "google/gemma-1.1-2b-it",
|
530 |
+
},
|
531 |
+
"Gemma-1.1-7B-Chat": {
|
532 |
+
DownloadSource.DEFAULT: "google/gemma-1.1-7b-it",
|
533 |
+
},
|
534 |
+
"Gemma-2-9B": {
|
535 |
+
DownloadSource.DEFAULT: "google/gemma-2-9b",
|
536 |
+
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b",
|
537 |
+
},
|
538 |
+
"Gemma-2-27B": {
|
539 |
+
DownloadSource.DEFAULT: "google/gemma-2-27b",
|
540 |
+
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-27b",
|
541 |
+
},
|
542 |
+
"Gemma-2-9B-Chat": {
|
543 |
+
DownloadSource.DEFAULT: "google/gemma-2-9b-it",
|
544 |
+
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b-it",
|
545 |
+
},
|
546 |
+
"Gemma-2-27B-Chat": {
|
547 |
+
DownloadSource.DEFAULT: "google/gemma-2-27b-it",
|
548 |
+
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-27b-it",
|
549 |
+
},
|
550 |
+
},
|
551 |
+
template="gemma",
|
552 |
+
)
|
553 |
+
|
554 |
+
|
555 |
+
register_model_group(
|
556 |
+
models={
|
557 |
+
"GLM-4-9B": {
|
558 |
+
DownloadSource.DEFAULT: "THUDM/glm-4-9b",
|
559 |
+
DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b",
|
560 |
+
},
|
561 |
+
"GLM-4-9B-Chat": {
|
562 |
+
DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat",
|
563 |
+
DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b-chat",
|
564 |
+
},
|
565 |
+
"GLM-4-9B-1M-Chat": {
|
566 |
+
DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat-1m",
|
567 |
+
DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b-chat-1m",
|
568 |
+
},
|
569 |
+
},
|
570 |
+
template="glm4",
|
571 |
+
)
|
572 |
+
|
573 |
+
|
574 |
+
register_model_group(
|
575 |
+
models={
|
576 |
+
"InternLM-7B": {
|
577 |
+
DownloadSource.DEFAULT: "internlm/internlm-7b",
|
578 |
+
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-7b",
|
579 |
+
},
|
580 |
+
"InternLM-20B": {
|
581 |
+
DownloadSource.DEFAULT: "internlm/internlm-20b",
|
582 |
+
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-20b",
|
583 |
+
},
|
584 |
+
"InternLM-7B-Chat": {
|
585 |
+
DownloadSource.DEFAULT: "internlm/internlm-chat-7b",
|
586 |
+
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-7b",
|
587 |
+
},
|
588 |
+
"InternLM-20B-Chat": {
|
589 |
+
DownloadSource.DEFAULT: "internlm/internlm-chat-20b",
|
590 |
+
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-20b",
|
591 |
+
},
|
592 |
+
},
|
593 |
+
template="intern",
|
594 |
+
)
|
595 |
+
|
596 |
+
|
597 |
+
register_model_group(
|
598 |
+
models={
|
599 |
+
"InternLM2-7B": {
|
600 |
+
DownloadSource.DEFAULT: "internlm/internlm2-7b",
|
601 |
+
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-7b",
|
602 |
+
},
|
603 |
+
"InternLM2-20B": {
|
604 |
+
DownloadSource.DEFAULT: "internlm/internlm2-20b",
|
605 |
+
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-20b",
|
606 |
+
},
|
607 |
+
"InternLM2-7B-Chat": {
|
608 |
+
DownloadSource.DEFAULT: "internlm/internlm2-chat-7b",
|
609 |
+
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-7b",
|
610 |
+
},
|
611 |
+
"InternLM2-20B-Chat": {
|
612 |
+
DownloadSource.DEFAULT: "internlm/internlm2-chat-20b",
|
613 |
+
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-20b",
|
614 |
+
},
|
615 |
+
},
|
616 |
+
template="intern2",
|
617 |
+
)
|
618 |
+
|
619 |
+
|
620 |
+
register_model_group(
|
621 |
+
models={
|
622 |
+
"InternLM2.5-7B": {
|
623 |
+
DownloadSource.DEFAULT: "internlm/internlm2_5-7b",
|
624 |
+
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b",
|
625 |
+
},
|
626 |
+
"InternLM2.5-7B-Chat": {
|
627 |
+
DownloadSource.DEFAULT: "internlm/internlm2_5-7b-chat",
|
628 |
+
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b-chat",
|
629 |
+
},
|
630 |
+
"InternLM2.5-7B-1M-Chat": {
|
631 |
+
DownloadSource.DEFAULT: "internlm/internlm2_5-7b-chat-1m",
|
632 |
+
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b-chat-1m",
|
633 |
+
},
|
634 |
+
},
|
635 |
+
template="intern2",
|
636 |
+
)
|
637 |
+
|
638 |
+
|
639 |
+
register_model_group(
|
640 |
+
models={
|
641 |
+
"Jamba-v0.1": {
|
642 |
+
DownloadSource.DEFAULT: "ai21labs/Jamba-v0.1",
|
643 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/Jamba-v0.1",
|
644 |
+
}
|
645 |
+
},
|
646 |
+
)
|
647 |
+
|
648 |
+
|
649 |
+
register_model_group(
|
650 |
+
models={
|
651 |
+
"LingoWhale-8B": {
|
652 |
+
DownloadSource.DEFAULT: "deeplang-ai/LingoWhale-8B",
|
653 |
+
DownloadSource.MODELSCOPE: "DeepLang/LingoWhale-8B",
|
654 |
+
}
|
655 |
+
},
|
656 |
+
)
|
657 |
+
|
658 |
+
|
659 |
+
register_model_group(
|
660 |
+
models={
|
661 |
+
"LLaMA-7B": {
|
662 |
+
DownloadSource.DEFAULT: "huggyllama/llama-7b",
|
663 |
+
DownloadSource.MODELSCOPE: "skyline2006/llama-7b",
|
664 |
+
},
|
665 |
+
"LLaMA-13B": {
|
666 |
+
DownloadSource.DEFAULT: "huggyllama/llama-13b",
|
667 |
+
DownloadSource.MODELSCOPE: "skyline2006/llama-13b",
|
668 |
+
},
|
669 |
+
"LLaMA-30B": {
|
670 |
+
DownloadSource.DEFAULT: "huggyllama/llama-30b",
|
671 |
+
DownloadSource.MODELSCOPE: "skyline2006/llama-30b",
|
672 |
+
},
|
673 |
+
"LLaMA-65B": {
|
674 |
+
DownloadSource.DEFAULT: "huggyllama/llama-65b",
|
675 |
+
DownloadSource.MODELSCOPE: "skyline2006/llama-65b",
|
676 |
+
},
|
677 |
+
}
|
678 |
+
)
|
679 |
+
|
680 |
+
|
681 |
+
register_model_group(
|
682 |
+
models={
|
683 |
+
"LLaMA2-7B": {
|
684 |
+
DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-hf",
|
685 |
+
DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-ms",
|
686 |
+
},
|
687 |
+
"LLaMA2-13B": {
|
688 |
+
DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-hf",
|
689 |
+
DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-ms",
|
690 |
+
},
|
691 |
+
"LLaMA2-70B": {
|
692 |
+
DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-hf",
|
693 |
+
DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-ms",
|
694 |
+
},
|
695 |
+
"LLaMA2-7B-Chat": {
|
696 |
+
DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-chat-hf",
|
697 |
+
DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-chat-ms",
|
698 |
+
},
|
699 |
+
"LLaMA2-13B-Chat": {
|
700 |
+
DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-chat-hf",
|
701 |
+
DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-chat-ms",
|
702 |
+
},
|
703 |
+
"LLaMA2-70B-Chat": {
|
704 |
+
DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-chat-hf",
|
705 |
+
DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-chat-ms",
|
706 |
+
},
|
707 |
+
},
|
708 |
+
template="llama2",
|
709 |
+
)
|
710 |
+
|
711 |
+
|
712 |
+
register_model_group(
|
713 |
+
models={
|
714 |
+
"LLaMA3-8B": {
|
715 |
+
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B",
|
716 |
+
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B",
|
717 |
+
},
|
718 |
+
"LLaMA3-70B": {
|
719 |
+
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B",
|
720 |
+
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B",
|
721 |
+
},
|
722 |
+
"LLaMA3-8B-Chat": {
|
723 |
+
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B-Instruct",
|
724 |
+
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B-Instruct",
|
725 |
+
},
|
726 |
+
"LLaMA3-70B-Chat": {
|
727 |
+
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B-Instruct",
|
728 |
+
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B-Instruct",
|
729 |
+
},
|
730 |
+
"LLaMA3-8B-Chinese-Chat": {
|
731 |
+
DownloadSource.DEFAULT: "shenzhi-wang/Llama3-8B-Chinese-Chat",
|
732 |
+
DownloadSource.MODELSCOPE: "LLM-Research/Llama3-8B-Chinese-Chat",
|
733 |
+
},
|
734 |
+
"LLaMA3-70B-Chinese-Chat": {
|
735 |
+
DownloadSource.DEFAULT: "shenzhi-wang/Llama3-70B-Chinese-Chat",
|
736 |
+
},
|
737 |
+
},
|
738 |
+
template="llama3",
|
739 |
+
)
|
740 |
+
|
741 |
+
|
742 |
+
register_model_group(
|
743 |
+
models={
|
744 |
+
"LLaVA1.5-7B-Chat": {
|
745 |
+
DownloadSource.DEFAULT: "llava-hf/llava-1.5-7b-hf",
|
746 |
+
},
|
747 |
+
"LLaVA1.5-13B-Chat": {
|
748 |
+
DownloadSource.DEFAULT: "llava-hf/llava-1.5-13b-hf",
|
749 |
+
},
|
750 |
+
},
|
751 |
+
template="vicuna",
|
752 |
+
vision=True,
|
753 |
+
)
|
754 |
+
|
755 |
+
|
756 |
+
register_model_group(
|
757 |
+
models={
|
758 |
+
"MiniCPM-2B-SFT-Chat": {
|
759 |
+
DownloadSource.DEFAULT: "openbmb/MiniCPM-2B-sft-bf16",
|
760 |
+
DownloadSource.MODELSCOPE: "OpenBMB/miniCPM-bf16",
|
761 |
+
},
|
762 |
+
"MiniCPM-2B-DPO-Chat": {
|
763 |
+
DownloadSource.DEFAULT: "openbmb/MiniCPM-2B-dpo-bf16",
|
764 |
+
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-2B-dpo-bf16",
|
765 |
+
},
|
766 |
+
},
|
767 |
+
template="cpm",
|
768 |
+
)
|
769 |
+
|
770 |
+
|
771 |
+
register_model_group(
|
772 |
+
models={
|
773 |
+
"Mistral-7B-v0.1": {
|
774 |
+
DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.1",
|
775 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.1",
|
776 |
+
},
|
777 |
+
"Mistral-7B-v0.1-Chat": {
|
778 |
+
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.1",
|
779 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.1",
|
780 |
+
},
|
781 |
+
"Mistral-7B-v0.2": {
|
782 |
+
DownloadSource.DEFAULT: "alpindale/Mistral-7B-v0.2-hf",
|
783 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.2-hf",
|
784 |
+
},
|
785 |
+
"Mistral-7B-v0.2-Chat": {
|
786 |
+
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.2",
|
787 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.2",
|
788 |
+
},
|
789 |
+
"Mistral-7B-v0.3": {
|
790 |
+
DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.3",
|
791 |
+
},
|
792 |
+
"Mistral-7B-v0.3-Chat": {
|
793 |
+
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.3",
|
794 |
+
},
|
795 |
+
},
|
796 |
+
template="mistral",
|
797 |
+
)
|
798 |
+
|
799 |
+
|
800 |
+
register_model_group(
|
801 |
+
models={
|
802 |
+
"Mixtral-8x7B-v0.1": {
|
803 |
+
DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-v0.1",
|
804 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-v0.1",
|
805 |
+
},
|
806 |
+
"Mixtral-8x7B-v0.1-Chat": {
|
807 |
+
DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-Instruct-v0.1",
|
808 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-Instruct-v0.1",
|
809 |
+
},
|
810 |
+
"Mixtral-8x22B-v0.1": {
|
811 |
+
DownloadSource.DEFAULT: "mistralai/Mixtral-8x22B-v0.1",
|
812 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x22B-v0.1",
|
813 |
+
},
|
814 |
+
"Mixtral-8x22B-v0.1-Chat": {
|
815 |
+
DownloadSource.DEFAULT: "mistralai/Mixtral-8x22B-Instruct-v0.1",
|
816 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x22B-Instruct-v0.1",
|
817 |
+
},
|
818 |
+
},
|
819 |
+
template="mistral",
|
820 |
+
)
|
821 |
+
|
822 |
+
|
823 |
+
register_model_group(
|
824 |
+
models={
|
825 |
+
"OLMo-1B": {
|
826 |
+
DownloadSource.DEFAULT: "allenai/OLMo-1B-hf",
|
827 |
+
},
|
828 |
+
"OLMo-7B": {
|
829 |
+
DownloadSource.DEFAULT: "allenai/OLMo-7B-hf",
|
830 |
+
},
|
831 |
+
"OLMo-7B-Chat": {
|
832 |
+
DownloadSource.DEFAULT: "ssec-uw/OLMo-7B-Instruct-hf",
|
833 |
+
},
|
834 |
+
"OLMo-1.7-7B": {
|
835 |
+
DownloadSource.DEFAULT: "allenai/OLMo-1.7-7B-hf",
|
836 |
+
},
|
837 |
+
},
|
838 |
+
)
|
839 |
+
|
840 |
+
|
841 |
+
register_model_group(
|
842 |
+
models={
|
843 |
+
"OpenChat3.5-7B-Chat": {
|
844 |
+
DownloadSource.DEFAULT: "openchat/openchat-3.5-0106",
|
845 |
+
DownloadSource.MODELSCOPE: "xcwzxcwz/openchat-3.5-0106",
|
846 |
+
}
|
847 |
+
},
|
848 |
+
template="openchat",
|
849 |
+
)
|
850 |
+
|
851 |
+
|
852 |
+
register_model_group(
|
853 |
+
models={
|
854 |
+
"OpenChat3.6-8B-Chat": {
|
855 |
+
DownloadSource.DEFAULT: "openchat/openchat-3.6-8b-20240522",
|
856 |
+
}
|
857 |
+
},
|
858 |
+
template="openchat-3.6",
|
859 |
+
)
|
860 |
+
|
861 |
+
|
862 |
+
register_model_group(
|
863 |
+
models={
|
864 |
+
"Orion-14B-Base": {
|
865 |
+
DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Base",
|
866 |
+
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Base",
|
867 |
+
},
|
868 |
+
"Orion-14B-Chat": {
|
869 |
+
DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat",
|
870 |
+
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat",
|
871 |
+
},
|
872 |
+
"Orion-14B-Long-Chat": {
|
873 |
+
DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-LongChat",
|
874 |
+
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-LongChat",
|
875 |
+
},
|
876 |
+
"Orion-14B-RAG-Chat": {
|
877 |
+
DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat-RAG",
|
878 |
+
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat-RAG",
|
879 |
+
},
|
880 |
+
"Orion-14B-Plugin-Chat": {
|
881 |
+
DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat-Plugin",
|
882 |
+
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat-Plugin",
|
883 |
+
},
|
884 |
+
},
|
885 |
+
template="orion",
|
886 |
+
)
|
887 |
+
|
888 |
+
|
889 |
+
register_model_group(
|
890 |
+
models={
|
891 |
+
"PaliGemma-3B-pt-224": {
|
892 |
+
DownloadSource.DEFAULT: "google/paligemma-3b-pt-224",
|
893 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-224",
|
894 |
+
},
|
895 |
+
"PaliGemma-3B-pt-448": {
|
896 |
+
DownloadSource.DEFAULT: "google/paligemma-3b-pt-448",
|
897 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-448",
|
898 |
+
},
|
899 |
+
"PaliGemma-3B-pt-896": {
|
900 |
+
DownloadSource.DEFAULT: "google/paligemma-3b-pt-896",
|
901 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-896",
|
902 |
+
},
|
903 |
+
"PaliGemma-3B-mix-224": {
|
904 |
+
DownloadSource.DEFAULT: "google/paligemma-3b-mix-224",
|
905 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-mix-224",
|
906 |
+
},
|
907 |
+
"PaliGemma-3B-mix-448": {
|
908 |
+
DownloadSource.DEFAULT: "google/paligemma-3b-mix-448",
|
909 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-mix-448",
|
910 |
+
},
|
911 |
+
},
|
912 |
+
vision=True,
|
913 |
+
)
|
914 |
+
|
915 |
+
|
916 |
+
register_model_group(
|
917 |
+
models={
|
918 |
+
"Phi-1.5-1.3B": {
|
919 |
+
DownloadSource.DEFAULT: "microsoft/phi-1_5",
|
920 |
+
DownloadSource.MODELSCOPE: "allspace/PHI_1-5",
|
921 |
+
},
|
922 |
+
"Phi-2-2.7B": {
|
923 |
+
DownloadSource.DEFAULT: "microsoft/phi-2",
|
924 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/phi-2",
|
925 |
+
},
|
926 |
+
}
|
927 |
+
)
|
928 |
+
|
929 |
+
|
930 |
+
register_model_group(
|
931 |
+
models={
|
932 |
+
"Phi3-4B-4k-Chat": {
|
933 |
+
DownloadSource.DEFAULT: "microsoft/Phi-3-mini-4k-instruct",
|
934 |
+
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-4k-instruct",
|
935 |
+
},
|
936 |
+
"Phi3-4B-128k-Chat": {
|
937 |
+
DownloadSource.DEFAULT: "microsoft/Phi-3-mini-128k-instruct",
|
938 |
+
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-128k-instruct",
|
939 |
+
},
|
940 |
+
"Phi3-7B-8k-Chat": {
|
941 |
+
DownloadSource.DEFAULT: "microsoft/Phi-3-small-8k-instruct",
|
942 |
+
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-8k-instruct",
|
943 |
+
},
|
944 |
+
"Phi3-7B-128k-Chat": {
|
945 |
+
DownloadSource.DEFAULT: "microsoft/Phi-3-small-128k-instruct",
|
946 |
+
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-128k-instruct",
|
947 |
+
},
|
948 |
+
"Phi3-14B-8k-Chat": {
|
949 |
+
DownloadSource.DEFAULT: "microsoft/Phi-3-medium-4k-instruct",
|
950 |
+
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-4k-instruct",
|
951 |
+
},
|
952 |
+
"Phi3-14B-128k-Chat": {
|
953 |
+
DownloadSource.DEFAULT: "microsoft/Phi-3-medium-128k-instruct",
|
954 |
+
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-128k-instruct",
|
955 |
+
},
|
956 |
+
},
|
957 |
+
template="phi",
|
958 |
+
)
|
959 |
+
|
960 |
+
|
961 |
+
register_model_group(
|
962 |
+
models={
|
963 |
+
"Qwen-1.8B": {
|
964 |
+
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B",
|
965 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B",
|
966 |
+
},
|
967 |
+
"Qwen-7B": {
|
968 |
+
DownloadSource.DEFAULT: "Qwen/Qwen-7B",
|
969 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen-7B",
|
970 |
+
},
|
971 |
+
"Qwen-14B": {
|
972 |
+
DownloadSource.DEFAULT: "Qwen/Qwen-14B",
|
973 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen-14B",
|
974 |
+
},
|
975 |
+
"Qwen-72B": {
|
976 |
+
DownloadSource.DEFAULT: "Qwen/Qwen-72B",
|
977 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen-72B",
|
978 |
+
},
|
979 |
+
"Qwen-1.8B-Chat": {
|
980 |
+
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat",
|
981 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat",
|
982 |
+
},
|
983 |
+
"Qwen-7B-Chat": {
|
984 |
+
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat",
|
985 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat",
|
986 |
+
},
|
987 |
+
"Qwen-14B-Chat": {
|
988 |
+
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat",
|
989 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat",
|
990 |
+
},
|
991 |
+
"Qwen-72B-Chat": {
|
992 |
+
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat",
|
993 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat",
|
994 |
+
},
|
995 |
+
"Qwen-1.8B-int8-Chat": {
|
996 |
+
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int8",
|
997 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int8",
|
998 |
+
},
|
999 |
+
"Qwen-1.8B-int4-Chat": {
|
1000 |
+
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int4",
|
1001 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int4",
|
1002 |
+
},
|
1003 |
+
"Qwen-7B-int8-Chat": {
|
1004 |
+
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int8",
|
1005 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int8",
|
1006 |
+
},
|
1007 |
+
"Qwen-7B-int4-Chat": {
|
1008 |
+
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int4",
|
1009 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int4",
|
1010 |
+
},
|
1011 |
+
"Qwen-14B-int8-Chat": {
|
1012 |
+
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int8",
|
1013 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int8",
|
1014 |
+
},
|
1015 |
+
"Qwen-14B-int4-Chat": {
|
1016 |
+
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int4",
|
1017 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int4",
|
1018 |
+
},
|
1019 |
+
"Qwen-72B-int8-Chat": {
|
1020 |
+
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int8",
|
1021 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int8",
|
1022 |
+
},
|
1023 |
+
"Qwen-72B-int4-Chat": {
|
1024 |
+
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int4",
|
1025 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int4",
|
1026 |
+
},
|
1027 |
+
},
|
1028 |
+
template="qwen",
|
1029 |
+
)
|
1030 |
+
|
1031 |
+
|
1032 |
+
register_model_group(
|
1033 |
+
models={
|
1034 |
+
"Qwen1.5-0.5B": {
|
1035 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B",
|
1036 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B",
|
1037 |
+
},
|
1038 |
+
"Qwen1.5-1.8B": {
|
1039 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B",
|
1040 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B",
|
1041 |
+
},
|
1042 |
+
"Qwen1.5-4B": {
|
1043 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B",
|
1044 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B",
|
1045 |
+
},
|
1046 |
+
"Qwen1.5-7B": {
|
1047 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B",
|
1048 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B",
|
1049 |
+
},
|
1050 |
+
"Qwen1.5-14B": {
|
1051 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B",
|
1052 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B",
|
1053 |
+
},
|
1054 |
+
"Qwen1.5-32B": {
|
1055 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B",
|
1056 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B",
|
1057 |
+
},
|
1058 |
+
"Qwen1.5-72B": {
|
1059 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B",
|
1060 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B",
|
1061 |
+
},
|
1062 |
+
"Qwen1.5-110B": {
|
1063 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B",
|
1064 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B",
|
1065 |
+
},
|
1066 |
+
"Qwen1.5-MoE-A2.7B": {
|
1067 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B",
|
1068 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B",
|
1069 |
+
},
|
1070 |
+
"Qwen1.5-Code-7B": {
|
1071 |
+
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B",
|
1072 |
+
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B",
|
1073 |
+
},
|
1074 |
+
"Qwen1.5-0.5B-Chat": {
|
1075 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat",
|
1076 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat",
|
1077 |
+
},
|
1078 |
+
"Qwen1.5-1.8B-Chat": {
|
1079 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat",
|
1080 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat",
|
1081 |
+
},
|
1082 |
+
"Qwen1.5-4B-Chat": {
|
1083 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat",
|
1084 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat",
|
1085 |
+
},
|
1086 |
+
"Qwen1.5-7B-Chat": {
|
1087 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat",
|
1088 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat",
|
1089 |
+
},
|
1090 |
+
"Qwen1.5-14B-Chat": {
|
1091 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat",
|
1092 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat",
|
1093 |
+
},
|
1094 |
+
"Qwen1.5-32B-Chat": {
|
1095 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B-Chat",
|
1096 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B-Chat",
|
1097 |
+
},
|
1098 |
+
"Qwen1.5-72B-Chat": {
|
1099 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat",
|
1100 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat",
|
1101 |
+
},
|
1102 |
+
"Qwen1.5-110B-Chat": {
|
1103 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat",
|
1104 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B-Chat",
|
1105 |
+
},
|
1106 |
+
"Qwen1.5-MoE-A2.7B-Chat": {
|
1107 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat",
|
1108 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat",
|
1109 |
+
},
|
1110 |
+
"Qwen1.5-Code-7B-Chat": {
|
1111 |
+
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat",
|
1112 |
+
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat",
|
1113 |
+
},
|
1114 |
+
"Qwen1.5-0.5B-int8-Chat": {
|
1115 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
|
1116 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
|
1117 |
+
},
|
1118 |
+
"Qwen1.5-0.5B-int4-Chat": {
|
1119 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-AWQ",
|
1120 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-AWQ",
|
1121 |
+
},
|
1122 |
+
"Qwen1.5-1.8B-int8-Chat": {
|
1123 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
|
1124 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
|
1125 |
+
},
|
1126 |
+
"Qwen1.5-1.8B-int4-Chat": {
|
1127 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-AWQ",
|
1128 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat-AWQ",
|
1129 |
+
},
|
1130 |
+
"Qwen1.5-4B-int8-Chat": {
|
1131 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
|
1132 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
|
1133 |
+
},
|
1134 |
+
"Qwen1.5-4B-int4-Chat": {
|
1135 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-AWQ",
|
1136 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat-AWQ",
|
1137 |
+
},
|
1138 |
+
"Qwen1.5-7B-int8-Chat": {
|
1139 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
|
1140 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
|
1141 |
+
},
|
1142 |
+
"Qwen1.5-7B-int4-Chat": {
|
1143 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-AWQ",
|
1144 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat-AWQ",
|
1145 |
+
},
|
1146 |
+
"Qwen1.5-14B-int8-Chat": {
|
1147 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
|
1148 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
|
1149 |
+
},
|
1150 |
+
"Qwen1.5-14B-int4-Chat": {
|
1151 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-AWQ",
|
1152 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-AWQ",
|
1153 |
+
},
|
1154 |
+
"Qwen1.5-32B-int4-Chat": {
|
1155 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B-Chat-AWQ",
|
1156 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B-Chat-AWQ",
|
1157 |
+
},
|
1158 |
+
"Qwen1.5-72B-int8-Chat": {
|
1159 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
|
1160 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
|
1161 |
+
},
|
1162 |
+
"Qwen1.5-72B-int4-Chat": {
|
1163 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-AWQ",
|
1164 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-AWQ",
|
1165 |
+
},
|
1166 |
+
"Qwen1.5-110B-int4-Chat": {
|
1167 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat-AWQ",
|
1168 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B-Chat-AWQ",
|
1169 |
+
},
|
1170 |
+
"Qwen1.5-MoE-A2.7B-int4-Chat": {
|
1171 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
|
1172 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
|
1173 |
+
},
|
1174 |
+
"Qwen1.5-Code-7B-int4-Chat": {
|
1175 |
+
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat-AWQ",
|
1176 |
+
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat-AWQ",
|
1177 |
+
},
|
1178 |
+
},
|
1179 |
+
template="qwen",
|
1180 |
+
)
|
1181 |
+
|
1182 |
+
|
1183 |
+
register_model_group(
|
1184 |
+
models={
|
1185 |
+
"Qwen2-0.5B": {
|
1186 |
+
DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B",
|
1187 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B",
|
1188 |
+
},
|
1189 |
+
"Qwen2-1.5B": {
|
1190 |
+
DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B",
|
1191 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B",
|
1192 |
+
},
|
1193 |
+
"Qwen2-7B": {
|
1194 |
+
DownloadSource.DEFAULT: "Qwen/Qwen2-7B",
|
1195 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen2-7B",
|
1196 |
+
},
|
1197 |
+
"Qwen2-72B": {
|
1198 |
+
DownloadSource.DEFAULT: "Qwen/Qwen2-72B",
|
1199 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen2-72B",
|
1200 |
+
},
|
1201 |
+
"Qwen2-MoE-57B": {
|
1202 |
+
DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B",
|
1203 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B",
|
1204 |
+
},
|
1205 |
+
"Qwen2-0.5B-Chat": {
|
1206 |
+
DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct",
|
1207 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct",
|
1208 |
+
},
|
1209 |
+
"Qwen2-1.5B-Chat": {
|
1210 |
+
DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct",
|
1211 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct",
|
1212 |
+
},
|
1213 |
+
"Qwen2-7B-Chat": {
|
1214 |
+
DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct",
|
1215 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct",
|
1216 |
+
},
|
1217 |
+
"Qwen2-72B-Chat": {
|
1218 |
+
DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct",
|
1219 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen2-72B-Instruct",
|
1220 |
+
},
|
1221 |
+
"Qwen2-MoE-57B-Chat": {
|
1222 |
+
DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B-Instruct",
|
1223 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B-Instruct",
|
1224 |
+
},
|
1225 |
+
"Qwen2-0.5B-int8-Chat": {
|
1226 |
+
DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct-GPTQ-Int8",
|
1227 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct-GPTQ-Int8",
|
1228 |
+
},
|
1229 |
+
"Qwen2-0.5B-int4-Chat": {
|
1230 |
+
DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct-AWQ",
|
1231 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct-AWQ",
|
1232 |
+
},
|
1233 |
+
"Qwen2-1.5B-int8-Chat": {
|
1234 |
+
DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct-GPTQ-Int8",
|
1235 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct-GPTQ-Int8",
|
1236 |
+
},
|
1237 |
+
"Qwen2-1.5B-int4-Chat": {
|
1238 |
+
DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct-AWQ",
|
1239 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct-AWQ",
|
1240 |
+
},
|
1241 |
+
"Qwen2-7B-int8-Chat": {
|
1242 |
+
DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct-GPTQ-Int8",
|
1243 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct-GPTQ-Int8",
|
1244 |
+
},
|
1245 |
+
"Qwen2-7B-int4-Chat": {
|
1246 |
+
DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct-AWQ",
|
1247 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct-AWQ",
|
1248 |
+
},
|
1249 |
+
"Qwen2-72B-int8-Chat": {
|
1250 |
+
DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct-GPTQ-Int8",
|
1251 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen2-72B-Instruct-GPTQ-Int8",
|
1252 |
+
},
|
1253 |
+
"Qwen2-72B-int4-Chat": {
|
1254 |
+
DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct-AWQ",
|
1255 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen2-72B-Instruct-AWQ",
|
1256 |
+
},
|
1257 |
+
"Qwen2-MoE-57B-int4-Chat": {
|
1258 |
+
DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4",
|
1259 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4",
|
1260 |
+
},
|
1261 |
+
},
|
1262 |
+
template="qwen",
|
1263 |
+
)
|
1264 |
+
|
1265 |
+
|
1266 |
+
register_model_group(
|
1267 |
+
models={
|
1268 |
+
"SOLAR-10.7B": {
|
1269 |
+
DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-v1.0",
|
1270 |
+
},
|
1271 |
+
"SOLAR-10.7B-Chat": {
|
1272 |
+
DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-Instruct-v1.0",
|
1273 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/SOLAR-10.7B-Instruct-v1.0",
|
1274 |
+
},
|
1275 |
+
},
|
1276 |
+
template="solar",
|
1277 |
+
)
|
1278 |
+
|
1279 |
+
|
1280 |
+
register_model_group(
|
1281 |
+
models={
|
1282 |
+
"Skywork-13B-Base": {
|
1283 |
+
DownloadSource.DEFAULT: "Skywork/Skywork-13B-base",
|
1284 |
+
DownloadSource.MODELSCOPE: "skywork/Skywork-13B-base",
|
1285 |
+
}
|
1286 |
+
}
|
1287 |
+
)
|
1288 |
+
|
1289 |
+
|
1290 |
+
register_model_group(
|
1291 |
+
models={
|
1292 |
+
"StarCoder2-3B": {
|
1293 |
+
DownloadSource.DEFAULT: "bigcode/starcoder2-3b",
|
1294 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/starcoder2-3b",
|
1295 |
+
},
|
1296 |
+
"StarCoder2-7B": {
|
1297 |
+
DownloadSource.DEFAULT: "bigcode/starcoder2-7b",
|
1298 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/starcoder2-7b",
|
1299 |
+
},
|
1300 |
+
"StarCoder2-15B": {
|
1301 |
+
DownloadSource.DEFAULT: "bigcode/starcoder2-15b",
|
1302 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/starcoder2-15b",
|
1303 |
+
},
|
1304 |
+
}
|
1305 |
+
)
|
1306 |
+
|
1307 |
+
|
1308 |
+
register_model_group(
|
1309 |
+
models={
|
1310 |
+
"TeleChat-1B-Chat": {
|
1311 |
+
DownloadSource.DEFAULT: "Tele-AI/TeleChat-1B",
|
1312 |
+
DownloadSource.MODELSCOPE: "TeleAI/TeleChat-1B",
|
1313 |
+
},
|
1314 |
+
"TeleChat-7B-Chat": {
|
1315 |
+
DownloadSource.DEFAULT: "Tele-AI/telechat-7B",
|
1316 |
+
DownloadSource.MODELSCOPE: "TeleAI/telechat-7B",
|
1317 |
+
},
|
1318 |
+
"TeleChat-12B-Chat": {
|
1319 |
+
DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B",
|
1320 |
+
DownloadSource.MODELSCOPE: "TeleAI/TeleChat-12B",
|
1321 |
+
},
|
1322 |
+
"TeleChat-12B-v2-Chat": {
|
1323 |
+
DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B-v2",
|
1324 |
+
DownloadSource.MODELSCOPE: "TeleAI/TeleChat-12B-v2",
|
1325 |
+
},
|
1326 |
+
},
|
1327 |
+
template="telechat",
|
1328 |
+
)
|
1329 |
+
|
1330 |
+
|
1331 |
+
register_model_group(
|
1332 |
+
models={
|
1333 |
+
"Vicuna1.5-7B-Chat": {
|
1334 |
+
DownloadSource.DEFAULT: "lmsys/vicuna-7b-v1.5",
|
1335 |
+
DownloadSource.MODELSCOPE: "Xorbits/vicuna-7b-v1.5",
|
1336 |
+
},
|
1337 |
+
"Vicuna1.5-13B-Chat": {
|
1338 |
+
DownloadSource.DEFAULT: "lmsys/vicuna-13b-v1.5",
|
1339 |
+
DownloadSource.MODELSCOPE: "Xorbits/vicuna-13b-v1.5",
|
1340 |
+
},
|
1341 |
+
},
|
1342 |
+
template="vicuna",
|
1343 |
+
)
|
1344 |
+
|
1345 |
+
|
1346 |
+
register_model_group(
|
1347 |
+
models={
|
1348 |
+
"XuanYuan-6B": {
|
1349 |
+
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B",
|
1350 |
+
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B",
|
1351 |
+
},
|
1352 |
+
"XuanYuan-70B": {
|
1353 |
+
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B",
|
1354 |
+
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B",
|
1355 |
+
},
|
1356 |
+
"XuanYuan-2-70B": {
|
1357 |
+
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B",
|
1358 |
+
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B",
|
1359 |
+
},
|
1360 |
+
"XuanYuan-6B-Chat": {
|
1361 |
+
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat",
|
1362 |
+
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat",
|
1363 |
+
},
|
1364 |
+
"XuanYuan-70B-Chat": {
|
1365 |
+
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat",
|
1366 |
+
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat",
|
1367 |
+
},
|
1368 |
+
"XuanYuan-2-70B-Chat": {
|
1369 |
+
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat",
|
1370 |
+
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat",
|
1371 |
+
},
|
1372 |
+
"XuanYuan-6B-int8-Chat": {
|
1373 |
+
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat-8bit",
|
1374 |
+
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat-8bit",
|
1375 |
+
},
|
1376 |
+
"XuanYuan-6B-int4-Chat": {
|
1377 |
+
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat-4bit",
|
1378 |
+
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat-4bit",
|
1379 |
+
},
|
1380 |
+
"XuanYuan-70B-int8-Chat": {
|
1381 |
+
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit",
|
1382 |
+
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit",
|
1383 |
+
},
|
1384 |
+
"XuanYuan-70B-int4-Chat": {
|
1385 |
+
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit",
|
1386 |
+
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit",
|
1387 |
+
},
|
1388 |
+
"XuanYuan-2-70B-int8-Chat": {
|
1389 |
+
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat-8bit",
|
1390 |
+
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat-8bit",
|
1391 |
+
},
|
1392 |
+
"XuanYuan-2-70B-int4-Chat": {
|
1393 |
+
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat-4bit",
|
1394 |
+
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat-4bit",
|
1395 |
+
},
|
1396 |
+
},
|
1397 |
+
template="xuanyuan",
|
1398 |
+
)
|
1399 |
+
|
1400 |
+
|
1401 |
+
register_model_group(
|
1402 |
+
models={
|
1403 |
+
"XVERSE-7B": {
|
1404 |
+
DownloadSource.DEFAULT: "xverse/XVERSE-7B",
|
1405 |
+
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B",
|
1406 |
+
},
|
1407 |
+
"XVERSE-13B": {
|
1408 |
+
DownloadSource.DEFAULT: "xverse/XVERSE-13B",
|
1409 |
+
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B",
|
1410 |
+
},
|
1411 |
+
"XVERSE-65B": {
|
1412 |
+
DownloadSource.DEFAULT: "xverse/XVERSE-65B",
|
1413 |
+
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B",
|
1414 |
+
},
|
1415 |
+
"XVERSE-65B-2": {
|
1416 |
+
DownloadSource.DEFAULT: "xverse/XVERSE-65B-2",
|
1417 |
+
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-2",
|
1418 |
+
},
|
1419 |
+
"XVERSE-7B-Chat": {
|
1420 |
+
DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat",
|
1421 |
+
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat",
|
1422 |
+
},
|
1423 |
+
"XVERSE-13B-Chat": {
|
1424 |
+
DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat",
|
1425 |
+
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat",
|
1426 |
+
},
|
1427 |
+
"XVERSE-65B-Chat": {
|
1428 |
+
DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat",
|
1429 |
+
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat",
|
1430 |
+
},
|
1431 |
+
"XVERSE-MoE-A4.2B": {
|
1432 |
+
DownloadSource.DEFAULT: "xverse/XVERSE-MoE-A4.2B",
|
1433 |
+
DownloadSource.MODELSCOPE: "xverse/XVERSE-MoE-A4.2B",
|
1434 |
+
},
|
1435 |
+
"XVERSE-7B-int8-Chat": {
|
1436 |
+
DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat-GPTQ-Int8",
|
1437 |
+
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat-GPTQ-Int8",
|
1438 |
+
},
|
1439 |
+
"XVERSE-7B-int4-Chat": {
|
1440 |
+
DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat-GPTQ-Int4",
|
1441 |
+
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat-GPTQ-Int4",
|
1442 |
+
},
|
1443 |
+
"XVERSE-13B-int8-Chat": {
|
1444 |
+
DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat-GPTQ-Int8",
|
1445 |
+
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat-GPTQ-Int8",
|
1446 |
+
},
|
1447 |
+
"XVERSE-13B-int4-Chat": {
|
1448 |
+
DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat-GPTQ-Int4",
|
1449 |
+
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat-GPTQ-Int4",
|
1450 |
+
},
|
1451 |
+
"XVERSE-65B-int4-Chat": {
|
1452 |
+
DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat-GPTQ-Int4",
|
1453 |
+
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat-GPTQ-Int4",
|
1454 |
+
},
|
1455 |
+
},
|
1456 |
+
template="xverse",
|
1457 |
+
)
|
1458 |
+
|
1459 |
+
|
1460 |
+
register_model_group(
|
1461 |
+
models={
|
1462 |
+
"Yayi-7B": {
|
1463 |
+
DownloadSource.DEFAULT: "wenge-research/yayi-7b-llama2",
|
1464 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-7b-llama2",
|
1465 |
+
},
|
1466 |
+
"Yayi-13B": {
|
1467 |
+
DownloadSource.DEFAULT: "wenge-research/yayi-13b-llama2",
|
1468 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-13b-llama2",
|
1469 |
+
},
|
1470 |
+
},
|
1471 |
+
template="yayi",
|
1472 |
+
)
|
1473 |
+
|
1474 |
+
|
1475 |
+
register_model_group(
|
1476 |
+
models={
|
1477 |
+
"Yi-6B": {
|
1478 |
+
DownloadSource.DEFAULT: "01-ai/Yi-6B",
|
1479 |
+
DownloadSource.MODELSCOPE: "01ai/Yi-6B",
|
1480 |
+
},
|
1481 |
+
"Yi-9B": {
|
1482 |
+
DownloadSource.DEFAULT: "01-ai/Yi-9B",
|
1483 |
+
DownloadSource.MODELSCOPE: "01ai/Yi-9B",
|
1484 |
+
},
|
1485 |
+
"Yi-34B": {
|
1486 |
+
DownloadSource.DEFAULT: "01-ai/Yi-34B",
|
1487 |
+
DownloadSource.MODELSCOPE: "01ai/Yi-34B",
|
1488 |
+
},
|
1489 |
+
"Yi-6B-Chat": {
|
1490 |
+
DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat",
|
1491 |
+
DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat",
|
1492 |
+
},
|
1493 |
+
"Yi-34B-Chat": {
|
1494 |
+
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat",
|
1495 |
+
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat",
|
1496 |
+
},
|
1497 |
+
"Yi-6B-int8-Chat": {
|
1498 |
+
DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-8bits",
|
1499 |
+
DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-8bits",
|
1500 |
+
},
|
1501 |
+
"Yi-6B-int4-Chat": {
|
1502 |
+
DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-4bits",
|
1503 |
+
DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-4bits",
|
1504 |
+
},
|
1505 |
+
"Yi-34B-int8-Chat": {
|
1506 |
+
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-8bits",
|
1507 |
+
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-8bits",
|
1508 |
+
},
|
1509 |
+
"Yi-34B-int4-Chat": {
|
1510 |
+
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-4bits",
|
1511 |
+
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-4bits",
|
1512 |
+
},
|
1513 |
+
"Yi-1.5-6B": {
|
1514 |
+
DownloadSource.DEFAULT: "01-ai/Yi-1.5-6B",
|
1515 |
+
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-6B",
|
1516 |
+
},
|
1517 |
+
"Yi-1.5-9B": {
|
1518 |
+
DownloadSource.DEFAULT: "01-ai/Yi-1.5-9B",
|
1519 |
+
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-9B",
|
1520 |
+
},
|
1521 |
+
"Yi-1.5-34B": {
|
1522 |
+
DownloadSource.DEFAULT: "01-ai/Yi-1.5-34B",
|
1523 |
+
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-34B",
|
1524 |
+
},
|
1525 |
+
"Yi-1.5-6B-Chat": {
|
1526 |
+
DownloadSource.DEFAULT: "01-ai/Yi-1.5-6B-Chat",
|
1527 |
+
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-6B-Chat",
|
1528 |
+
},
|
1529 |
+
"Yi-1.5-9B-Chat": {
|
1530 |
+
DownloadSource.DEFAULT: "01-ai/Yi-1.5-9B-Chat",
|
1531 |
+
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-9B-Chat",
|
1532 |
+
},
|
1533 |
+
"Yi-1.5-34B-Chat": {
|
1534 |
+
DownloadSource.DEFAULT: "01-ai/Yi-1.5-34B-Chat",
|
1535 |
+
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-34B-Chat",
|
1536 |
+
},
|
1537 |
+
},
|
1538 |
+
template="yi",
|
1539 |
+
)
|
1540 |
+
|
1541 |
+
|
1542 |
+
register_model_group(
|
1543 |
+
models={
|
1544 |
+
"YiVL-6B-Chat": {
|
1545 |
+
DownloadSource.DEFAULT: "BUAADreamer/Yi-VL-6B-hf",
|
1546 |
+
},
|
1547 |
+
"YiVL-34B-Chat": {
|
1548 |
+
DownloadSource.DEFAULT: "BUAADreamer/Yi-VL-34B-hf",
|
1549 |
+
},
|
1550 |
+
},
|
1551 |
+
template="yi_vl",
|
1552 |
+
vision=True,
|
1553 |
+
)
|
1554 |
+
|
1555 |
+
|
1556 |
+
register_model_group(
|
1557 |
+
models={
|
1558 |
+
"Yuan2-2B-Chat": {
|
1559 |
+
DownloadSource.DEFAULT: "IEITYuan/Yuan2-2B-hf",
|
1560 |
+
DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-2B-hf",
|
1561 |
+
},
|
1562 |
+
"Yuan2-51B-Chat": {
|
1563 |
+
DownloadSource.DEFAULT: "IEITYuan/Yuan2-51B-hf",
|
1564 |
+
DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-51B-hf",
|
1565 |
+
},
|
1566 |
+
"Yuan2-102B-Chat": {
|
1567 |
+
DownloadSource.DEFAULT: "IEITYuan/Yuan2-102B-hf",
|
1568 |
+
DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-102B-hf",
|
1569 |
+
},
|
1570 |
+
},
|
1571 |
+
template="yuan",
|
1572 |
+
)
|
1573 |
+
|
1574 |
+
|
1575 |
+
register_model_group(
|
1576 |
+
models={
|
1577 |
+
"Zephyr-7B-Alpha-Chat": {
|
1578 |
+
DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-alpha",
|
1579 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/zephyr-7b-alpha",
|
1580 |
+
},
|
1581 |
+
"Zephyr-7B-Beta-Chat": {
|
1582 |
+
DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-beta",
|
1583 |
+
DownloadSource.MODELSCOPE: "modelscope/zephyr-7b-beta",
|
1584 |
+
},
|
1585 |
+
"Zephyr-141B-ORPO-Chat": {
|
1586 |
+
DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1",
|
1587 |
+
},
|
1588 |
+
},
|
1589 |
+
template="zephyr",
|
1590 |
+
)
|
llama-factory/src/llamafactory/extras/env.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# This code is inspired by the HuggingFace's transformers library.
|
4 |
+
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/commands/env.py
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
|
18 |
+
import platform
|
19 |
+
|
20 |
+
import accelerate
|
21 |
+
import datasets
|
22 |
+
import peft
|
23 |
+
import torch
|
24 |
+
import transformers
|
25 |
+
import trl
|
26 |
+
from transformers.utils import is_torch_cuda_available, is_torch_npu_available
|
27 |
+
|
28 |
+
|
29 |
+
VERSION = "0.8.4.dev0"
|
30 |
+
|
31 |
+
|
32 |
+
def print_env() -> None:
|
33 |
+
info = {
|
34 |
+
"`llamafactory` version": VERSION,
|
35 |
+
"Platform": platform.platform(),
|
36 |
+
"Python version": platform.python_version(),
|
37 |
+
"PyTorch version": torch.__version__,
|
38 |
+
"Transformers version": transformers.__version__,
|
39 |
+
"Datasets version": datasets.__version__,
|
40 |
+
"Accelerate version": accelerate.__version__,
|
41 |
+
"PEFT version": peft.__version__,
|
42 |
+
"TRL version": trl.__version__,
|
43 |
+
}
|
44 |
+
|
45 |
+
if is_torch_cuda_available():
|
46 |
+
info["PyTorch version"] += " (GPU)"
|
47 |
+
info["GPU type"] = torch.cuda.get_device_name()
|
48 |
+
|
49 |
+
if is_torch_npu_available():
|
50 |
+
info["PyTorch version"] += " (NPU)"
|
51 |
+
info["NPU type"] = torch.npu.get_device_name()
|
52 |
+
info["CANN version"] = torch.version.cann
|
53 |
+
|
54 |
+
try:
|
55 |
+
import deepspeed # type: ignore
|
56 |
+
|
57 |
+
info["DeepSpeed version"] = deepspeed.__version__
|
58 |
+
except Exception:
|
59 |
+
pass
|
60 |
+
|
61 |
+
try:
|
62 |
+
import bitsandbytes
|
63 |
+
|
64 |
+
info["Bitsandbytes version"] = bitsandbytes.__version__
|
65 |
+
except Exception:
|
66 |
+
pass
|
67 |
+
|
68 |
+
try:
|
69 |
+
import vllm
|
70 |
+
|
71 |
+
info["vLLM version"] = vllm.__version__
|
72 |
+
except Exception:
|
73 |
+
pass
|
74 |
+
|
75 |
+
print("\n" + "\n".join(["- {}: {}".format(key, value) for key, value in info.items()]) + "\n")
|
llama-factory/src/llamafactory/extras/logging.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import logging
|
16 |
+
import os
|
17 |
+
import sys
|
18 |
+
from concurrent.futures import ThreadPoolExecutor
|
19 |
+
|
20 |
+
from .constants import RUNNING_LOG
|
21 |
+
|
22 |
+
|
23 |
+
class LoggerHandler(logging.Handler):
|
24 |
+
r"""
|
25 |
+
Logger handler used in Web UI.
|
26 |
+
"""
|
27 |
+
|
28 |
+
def __init__(self, output_dir: str) -> None:
|
29 |
+
super().__init__()
|
30 |
+
formatter = logging.Formatter(
|
31 |
+
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
|
32 |
+
)
|
33 |
+
self.setLevel(logging.INFO)
|
34 |
+
self.setFormatter(formatter)
|
35 |
+
|
36 |
+
os.makedirs(output_dir, exist_ok=True)
|
37 |
+
self.running_log = os.path.join(output_dir, RUNNING_LOG)
|
38 |
+
if os.path.exists(self.running_log):
|
39 |
+
os.remove(self.running_log)
|
40 |
+
|
41 |
+
self.thread_pool = ThreadPoolExecutor(max_workers=1)
|
42 |
+
|
43 |
+
def _write_log(self, log_entry: str) -> None:
|
44 |
+
with open(self.running_log, "a", encoding="utf-8") as f:
|
45 |
+
f.write(log_entry + "\n\n")
|
46 |
+
|
47 |
+
def emit(self, record) -> None:
|
48 |
+
if record.name == "httpx":
|
49 |
+
return
|
50 |
+
|
51 |
+
log_entry = self.format(record)
|
52 |
+
self.thread_pool.submit(self._write_log, log_entry)
|
53 |
+
|
54 |
+
def close(self) -> None:
|
55 |
+
self.thread_pool.shutdown(wait=True)
|
56 |
+
return super().close()
|
57 |
+
|
58 |
+
|
59 |
+
def get_logger(name: str) -> logging.Logger:
|
60 |
+
r"""
|
61 |
+
Gets a standard logger with a stream hander to stdout.
|
62 |
+
"""
|
63 |
+
formatter = logging.Formatter(
|
64 |
+
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
|
65 |
+
)
|
66 |
+
handler = logging.StreamHandler(sys.stdout)
|
67 |
+
handler.setFormatter(formatter)
|
68 |
+
|
69 |
+
logger = logging.getLogger(name)
|
70 |
+
logger.setLevel(logging.INFO)
|
71 |
+
logger.addHandler(handler)
|
72 |
+
|
73 |
+
return logger
|
74 |
+
|
75 |
+
|
76 |
+
def reset_logging() -> None:
|
77 |
+
r"""
|
78 |
+
Removes basic config of root logger. (unused in script)
|
79 |
+
"""
|
80 |
+
root = logging.getLogger()
|
81 |
+
list(map(root.removeHandler, root.handlers))
|
82 |
+
list(map(root.removeFilter, root.filters))
|
llama-factory/src/llamafactory/extras/misc.py
ADDED
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# This code is inspired by the HuggingFace's PEFT library.
|
4 |
+
# https://github.com/huggingface/peft/blob/v0.10.0/src/peft/peft_model.py
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
|
18 |
+
import gc
|
19 |
+
import os
|
20 |
+
from typing import TYPE_CHECKING, Tuple, Union
|
21 |
+
|
22 |
+
import torch
|
23 |
+
import transformers.dynamic_module_utils
|
24 |
+
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
|
25 |
+
from transformers.dynamic_module_utils import get_relative_imports
|
26 |
+
from transformers.utils import (
|
27 |
+
is_torch_bf16_gpu_available,
|
28 |
+
is_torch_cuda_available,
|
29 |
+
is_torch_mps_available,
|
30 |
+
is_torch_npu_available,
|
31 |
+
is_torch_xpu_available,
|
32 |
+
)
|
33 |
+
from transformers.utils.versions import require_version
|
34 |
+
|
35 |
+
from .logging import get_logger
|
36 |
+
|
37 |
+
|
38 |
+
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
|
39 |
+
try:
|
40 |
+
_is_bf16_available = is_torch_bf16_gpu_available()
|
41 |
+
except Exception:
|
42 |
+
_is_bf16_available = False
|
43 |
+
|
44 |
+
|
45 |
+
if TYPE_CHECKING:
|
46 |
+
from numpy.typing import NDArray
|
47 |
+
|
48 |
+
from ..hparams import ModelArguments
|
49 |
+
|
50 |
+
|
51 |
+
logger = get_logger(__name__)
|
52 |
+
|
53 |
+
|
54 |
+
class AverageMeter:
|
55 |
+
r"""
|
56 |
+
Computes and stores the average and current value.
|
57 |
+
"""
|
58 |
+
|
59 |
+
def __init__(self):
|
60 |
+
self.reset()
|
61 |
+
|
62 |
+
def reset(self):
|
63 |
+
self.val = 0
|
64 |
+
self.avg = 0
|
65 |
+
self.sum = 0
|
66 |
+
self.count = 0
|
67 |
+
|
68 |
+
def update(self, val, n=1):
|
69 |
+
self.val = val
|
70 |
+
self.sum += val * n
|
71 |
+
self.count += n
|
72 |
+
self.avg = self.sum / self.count
|
73 |
+
|
74 |
+
|
75 |
+
def check_dependencies() -> None:
|
76 |
+
r"""
|
77 |
+
Checks the version of the required packages.
|
78 |
+
"""
|
79 |
+
if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
|
80 |
+
logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
|
81 |
+
else:
|
82 |
+
require_version("transformers>=4.41.2", "To fix: pip install transformers>=4.41.2")
|
83 |
+
require_version("datasets>=2.16.0", "To fix: pip install datasets>=2.16.0")
|
84 |
+
require_version("accelerate>=0.30.1", "To fix: pip install accelerate>=0.30.1")
|
85 |
+
require_version("peft>=0.11.1", "To fix: pip install peft>=0.11.1")
|
86 |
+
require_version("trl>=0.8.6", "To fix: pip install trl>=0.8.6")
|
87 |
+
|
88 |
+
|
89 |
+
def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:
|
90 |
+
r"""
|
91 |
+
Returns the number of trainable parameters and number of all parameters in the model.
|
92 |
+
"""
|
93 |
+
trainable_params, all_param = 0, 0
|
94 |
+
for param in model.parameters():
|
95 |
+
num_params = param.numel()
|
96 |
+
# if using DS Zero 3 and the weights are initialized empty
|
97 |
+
if num_params == 0 and hasattr(param, "ds_numel"):
|
98 |
+
num_params = param.ds_numel
|
99 |
+
|
100 |
+
# Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by itemsize
|
101 |
+
if param.__class__.__name__ == "Params4bit":
|
102 |
+
if hasattr(param, "quant_storage") and hasattr(param.quant_storage, "itemsize"):
|
103 |
+
num_bytes = param.quant_storage.itemsize
|
104 |
+
elif hasattr(param, "element_size"): # for older pytorch version
|
105 |
+
num_bytes = param.element_size()
|
106 |
+
else:
|
107 |
+
num_bytes = 1
|
108 |
+
|
109 |
+
num_params = num_params * 2 * num_bytes
|
110 |
+
|
111 |
+
all_param += num_params
|
112 |
+
if param.requires_grad:
|
113 |
+
trainable_params += num_params
|
114 |
+
|
115 |
+
return trainable_params, all_param
|
116 |
+
|
117 |
+
|
118 |
+
def get_current_device() -> "torch.device":
|
119 |
+
r"""
|
120 |
+
Gets the current available device.
|
121 |
+
"""
|
122 |
+
if is_torch_xpu_available():
|
123 |
+
device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
|
124 |
+
elif is_torch_npu_available():
|
125 |
+
device = "npu:{}".format(os.environ.get("LOCAL_RANK", "0"))
|
126 |
+
elif is_torch_mps_available():
|
127 |
+
device = "mps:{}".format(os.environ.get("LOCAL_RANK", "0"))
|
128 |
+
elif is_torch_cuda_available():
|
129 |
+
device = "cuda:{}".format(os.environ.get("LOCAL_RANK", "0"))
|
130 |
+
else:
|
131 |
+
device = "cpu"
|
132 |
+
|
133 |
+
return torch.device(device)
|
134 |
+
|
135 |
+
|
136 |
+
def get_device_count() -> int:
|
137 |
+
r"""
|
138 |
+
Gets the number of available GPU or NPU devices.
|
139 |
+
"""
|
140 |
+
if is_torch_npu_available():
|
141 |
+
return torch.npu.device_count()
|
142 |
+
elif is_torch_cuda_available():
|
143 |
+
return torch.cuda.device_count()
|
144 |
+
else:
|
145 |
+
return 0
|
146 |
+
|
147 |
+
|
148 |
+
def get_logits_processor() -> "LogitsProcessorList":
|
149 |
+
r"""
|
150 |
+
Gets logits processor that removes NaN and Inf logits.
|
151 |
+
"""
|
152 |
+
logits_processor = LogitsProcessorList()
|
153 |
+
logits_processor.append(InfNanRemoveLogitsProcessor())
|
154 |
+
return logits_processor
|
155 |
+
|
156 |
+
|
157 |
+
def has_tokenized_data(path: "os.PathLike") -> bool:
|
158 |
+
r"""
|
159 |
+
Checks if the path has a tokenized dataset.
|
160 |
+
"""
|
161 |
+
return os.path.isdir(path) and len(os.listdir(path)) > 0
|
162 |
+
|
163 |
+
|
164 |
+
def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
|
165 |
+
r"""
|
166 |
+
Infers the optimal dtype according to the model_dtype and device compatibility.
|
167 |
+
"""
|
168 |
+
if _is_bf16_available and model_dtype == torch.bfloat16:
|
169 |
+
return torch.bfloat16
|
170 |
+
elif _is_fp16_available:
|
171 |
+
return torch.float16
|
172 |
+
else:
|
173 |
+
return torch.float32
|
174 |
+
|
175 |
+
|
176 |
+
def is_gpu_or_npu_available() -> bool:
|
177 |
+
r"""
|
178 |
+
Checks if the GPU or NPU is available.
|
179 |
+
"""
|
180 |
+
return is_torch_npu_available() or is_torch_cuda_available()
|
181 |
+
|
182 |
+
|
183 |
+
def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":
|
184 |
+
if isinstance(inputs, torch.Tensor):
|
185 |
+
inputs = inputs.cpu()
|
186 |
+
if inputs.dtype == torch.bfloat16: # numpy does not support bfloat16 until 1.21.4
|
187 |
+
inputs = inputs.to(torch.float32)
|
188 |
+
|
189 |
+
inputs = inputs.numpy()
|
190 |
+
|
191 |
+
return inputs
|
192 |
+
|
193 |
+
|
194 |
+
def skip_check_imports() -> None:
|
195 |
+
if os.environ.get("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
|
196 |
+
transformers.dynamic_module_utils.check_imports = get_relative_imports
|
197 |
+
|
198 |
+
|
199 |
+
def torch_gc() -> None:
|
200 |
+
r"""
|
201 |
+
Collects GPU or NPU memory.
|
202 |
+
"""
|
203 |
+
gc.collect()
|
204 |
+
if is_torch_xpu_available():
|
205 |
+
torch.xpu.empty_cache()
|
206 |
+
elif is_torch_npu_available():
|
207 |
+
torch.npu.empty_cache()
|
208 |
+
elif is_torch_mps_available():
|
209 |
+
torch.mps.empty_cache()
|
210 |
+
elif is_torch_cuda_available():
|
211 |
+
torch.cuda.empty_cache()
|
212 |
+
|
213 |
+
|
214 |
+
def try_download_model_from_ms(model_args: "ModelArguments") -> str:
|
215 |
+
if not use_modelscope() or os.path.exists(model_args.model_name_or_path):
|
216 |
+
return model_args.model_name_or_path
|
217 |
+
|
218 |
+
try:
|
219 |
+
from modelscope import snapshot_download
|
220 |
+
|
221 |
+
revision = "master" if model_args.model_revision == "main" else model_args.model_revision
|
222 |
+
return snapshot_download(model_args.model_name_or_path, revision=revision, cache_dir=model_args.cache_dir)
|
223 |
+
except ImportError:
|
224 |
+
raise ImportError("Please install modelscope via `pip install modelscope -U`")
|
225 |
+
|
226 |
+
|
227 |
+
def use_modelscope() -> bool:
|
228 |
+
return os.environ.get("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"]
|
llama-factory/src/llamafactory/extras/packages.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# This code is inspired by the HuggingFace's transformers library.
|
4 |
+
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/import_utils.py
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
|
18 |
+
import importlib.metadata
|
19 |
+
import importlib.util
|
20 |
+
from functools import lru_cache
|
21 |
+
from typing import TYPE_CHECKING
|
22 |
+
|
23 |
+
from packaging import version
|
24 |
+
|
25 |
+
|
26 |
+
if TYPE_CHECKING:
|
27 |
+
from packaging.version import Version
|
28 |
+
|
29 |
+
|
30 |
+
def _is_package_available(name: str) -> bool:
|
31 |
+
return importlib.util.find_spec(name) is not None
|
32 |
+
|
33 |
+
|
34 |
+
def _get_package_version(name: str) -> "Version":
|
35 |
+
try:
|
36 |
+
return version.parse(importlib.metadata.version(name))
|
37 |
+
except Exception:
|
38 |
+
return version.parse("0.0.0")
|
39 |
+
|
40 |
+
|
41 |
+
def is_fastapi_available():
|
42 |
+
return _is_package_available("fastapi")
|
43 |
+
|
44 |
+
|
45 |
+
def is_galore_available():
|
46 |
+
return _is_package_available("galore_torch")
|
47 |
+
|
48 |
+
|
49 |
+
def is_gradio_available():
|
50 |
+
return _is_package_available("gradio")
|
51 |
+
|
52 |
+
|
53 |
+
def is_matplotlib_available():
|
54 |
+
return _is_package_available("matplotlib")
|
55 |
+
|
56 |
+
|
57 |
+
def is_pillow_available():
|
58 |
+
return _is_package_available("PIL")
|
59 |
+
|
60 |
+
|
61 |
+
def is_requests_available():
|
62 |
+
return _is_package_available("requests")
|
63 |
+
|
64 |
+
|
65 |
+
def is_rouge_available():
|
66 |
+
return _is_package_available("rouge_chinese")
|
67 |
+
|
68 |
+
|
69 |
+
def is_starlette_available():
|
70 |
+
return _is_package_available("sse_starlette")
|
71 |
+
|
72 |
+
|
73 |
+
def is_uvicorn_available():
|
74 |
+
return _is_package_available("uvicorn")
|
75 |
+
|
76 |
+
|
77 |
+
def is_vllm_available():
|
78 |
+
return _is_package_available("vllm")
|
79 |
+
|
80 |
+
|
81 |
+
@lru_cache
|
82 |
+
def is_vllm_version_greater_than_0_5():
|
83 |
+
return _get_package_version("vllm") >= version.parse("0.5.0")
|
84 |
+
|
85 |
+
|
86 |
+
@lru_cache
|
87 |
+
def is_vllm_version_greater_than_0_5_1():
|
88 |
+
return _get_package_version("vllm") >= version.parse("0.5.1")
|
llama-factory/src/llamafactory/extras/ploting.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import json
|
16 |
+
import math
|
17 |
+
import os
|
18 |
+
from typing import Any, Dict, List
|
19 |
+
|
20 |
+
from transformers.trainer import TRAINER_STATE_NAME
|
21 |
+
|
22 |
+
from .logging import get_logger
|
23 |
+
from .packages import is_matplotlib_available
|
24 |
+
|
25 |
+
|
26 |
+
if is_matplotlib_available():
|
27 |
+
import matplotlib.figure
|
28 |
+
import matplotlib.pyplot as plt
|
29 |
+
|
30 |
+
|
31 |
+
logger = get_logger(__name__)
|
32 |
+
|
33 |
+
|
34 |
+
def smooth(scalars: List[float]) -> List[float]:
|
35 |
+
r"""
|
36 |
+
EMA implementation according to TensorBoard.
|
37 |
+
"""
|
38 |
+
if len(scalars) == 0:
|
39 |
+
return []
|
40 |
+
|
41 |
+
last = scalars[0]
|
42 |
+
smoothed = []
|
43 |
+
weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function
|
44 |
+
for next_val in scalars:
|
45 |
+
smoothed_val = last * weight + (1 - weight) * next_val
|
46 |
+
smoothed.append(smoothed_val)
|
47 |
+
last = smoothed_val
|
48 |
+
return smoothed
|
49 |
+
|
50 |
+
|
51 |
+
def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figure":
|
52 |
+
r"""
|
53 |
+
Plots loss curves in LlamaBoard.
|
54 |
+
"""
|
55 |
+
plt.close("all")
|
56 |
+
plt.switch_backend("agg")
|
57 |
+
fig = plt.figure()
|
58 |
+
ax = fig.add_subplot(111)
|
59 |
+
steps, losses = [], []
|
60 |
+
for log in trainer_log:
|
61 |
+
if log.get("loss", None):
|
62 |
+
steps.append(log["current_steps"])
|
63 |
+
losses.append(log["loss"])
|
64 |
+
|
65 |
+
ax.plot(steps, losses, color="#1f77b4", alpha=0.4, label="original")
|
66 |
+
ax.plot(steps, smooth(losses), color="#1f77b4", label="smoothed")
|
67 |
+
ax.legend()
|
68 |
+
ax.set_xlabel("step")
|
69 |
+
ax.set_ylabel("loss")
|
70 |
+
return fig
|
71 |
+
|
72 |
+
|
73 |
+
def plot_loss(save_dictionary: os.PathLike, keys: List[str] = ["loss"]) -> None:
|
74 |
+
r"""
|
75 |
+
Plots loss curves and saves the image.
|
76 |
+
"""
|
77 |
+
plt.switch_backend("agg")
|
78 |
+
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
|
79 |
+
data = json.load(f)
|
80 |
+
|
81 |
+
for key in keys:
|
82 |
+
steps, metrics = [], []
|
83 |
+
for i in range(len(data["log_history"])):
|
84 |
+
if key in data["log_history"][i]:
|
85 |
+
steps.append(data["log_history"][i]["step"])
|
86 |
+
metrics.append(data["log_history"][i][key])
|
87 |
+
|
88 |
+
if len(metrics) == 0:
|
89 |
+
logger.warning(f"No metric {key} to plot.")
|
90 |
+
continue
|
91 |
+
|
92 |
+
plt.figure()
|
93 |
+
plt.plot(steps, metrics, color="#1f77b4", alpha=0.4, label="original")
|
94 |
+
plt.plot(steps, smooth(metrics), color="#1f77b4", label="smoothed")
|
95 |
+
plt.title("training {} of {}".format(key, save_dictionary))
|
96 |
+
plt.xlabel("step")
|
97 |
+
plt.ylabel(key)
|
98 |
+
plt.legend()
|
99 |
+
figure_path = os.path.join(save_dictionary, "training_{}.png".format(key.replace("/", "_")))
|
100 |
+
plt.savefig(figure_path, format="png", dpi=100)
|
101 |
+
print("Figure saved at:", figure_path)
|
llama-factory/src/llamafactory/hparams/__init__.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from .data_args import DataArguments
|
16 |
+
from .evaluation_args import EvaluationArguments
|
17 |
+
from .finetuning_args import FinetuningArguments
|
18 |
+
from .generating_args import GeneratingArguments
|
19 |
+
from .model_args import ModelArguments
|
20 |
+
from .parser import get_eval_args, get_infer_args, get_train_args
|
21 |
+
|
22 |
+
|
23 |
+
__all__ = [
|
24 |
+
"DataArguments",
|
25 |
+
"EvaluationArguments",
|
26 |
+
"FinetuningArguments",
|
27 |
+
"GeneratingArguments",
|
28 |
+
"ModelArguments",
|
29 |
+
"get_eval_args",
|
30 |
+
"get_infer_args",
|
31 |
+
"get_train_args",
|
32 |
+
]
|
llama-factory/src/llamafactory/hparams/data_args.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# This code is inspired by the HuggingFace's transformers library.
|
4 |
+
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
|
18 |
+
from dataclasses import dataclass, field
|
19 |
+
from typing import Literal, Optional
|
20 |
+
|
21 |
+
|
22 |
+
@dataclass
|
23 |
+
class DataArguments:
|
24 |
+
r"""
|
25 |
+
Arguments pertaining to what data we are going to input our model for training and evaluation.
|
26 |
+
"""
|
27 |
+
|
28 |
+
template: Optional[str] = field(
|
29 |
+
default=None,
|
30 |
+
metadata={"help": "Which template to use for constructing prompts in training and inference."},
|
31 |
+
)
|
32 |
+
dataset: Optional[str] = field(
|
33 |
+
default=None,
|
34 |
+
metadata={"help": "The name of dataset(s) to use for training. Use commas to separate multiple datasets."},
|
35 |
+
)
|
36 |
+
eval_dataset: Optional[str] = field(
|
37 |
+
default=None,
|
38 |
+
metadata={"help": "The name of dataset(s) to use for evaluation. Use commas to separate multiple datasets."},
|
39 |
+
)
|
40 |
+
dataset_dir: str = field(
|
41 |
+
default="data",
|
42 |
+
metadata={"help": "Path to the folder containing the datasets."},
|
43 |
+
)
|
44 |
+
cutoff_len: int = field(
|
45 |
+
default=1024,
|
46 |
+
metadata={"help": "The cutoff length of the tokenized inputs in the dataset."},
|
47 |
+
)
|
48 |
+
train_on_prompt: bool = field(
|
49 |
+
default=False,
|
50 |
+
metadata={"help": "Whether or not to disable the mask on the prompt."},
|
51 |
+
)
|
52 |
+
mask_history: bool = field(
|
53 |
+
default=False,
|
54 |
+
metadata={"help": "Whether or not to mask the history and train on the last turn only."},
|
55 |
+
)
|
56 |
+
streaming: bool = field(
|
57 |
+
default=False,
|
58 |
+
metadata={"help": "Enable dataset streaming."},
|
59 |
+
)
|
60 |
+
buffer_size: int = field(
|
61 |
+
default=16384,
|
62 |
+
metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."},
|
63 |
+
)
|
64 |
+
mix_strategy: Literal["concat", "interleave_under", "interleave_over"] = field(
|
65 |
+
default="concat",
|
66 |
+
metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."},
|
67 |
+
)
|
68 |
+
interleave_probs: Optional[str] = field(
|
69 |
+
default=None,
|
70 |
+
metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."},
|
71 |
+
)
|
72 |
+
overwrite_cache: bool = field(
|
73 |
+
default=False,
|
74 |
+
metadata={"help": "Overwrite the cached training and evaluation sets."},
|
75 |
+
)
|
76 |
+
preprocessing_num_workers: Optional[int] = field(
|
77 |
+
default=None,
|
78 |
+
metadata={"help": "The number of processes to use for the pre-processing."},
|
79 |
+
)
|
80 |
+
max_samples: Optional[int] = field(
|
81 |
+
default=None,
|
82 |
+
metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."},
|
83 |
+
)
|
84 |
+
eval_num_beams: Optional[int] = field(
|
85 |
+
default=None,
|
86 |
+
metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"},
|
87 |
+
)
|
88 |
+
ignore_pad_token_for_loss: bool = field(
|
89 |
+
default=True,
|
90 |
+
metadata={"help": "Whether or not to ignore the tokens corresponding to the pad label in loss computation."},
|
91 |
+
)
|
92 |
+
val_size: float = field(
|
93 |
+
default=0.0,
|
94 |
+
metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."},
|
95 |
+
)
|
96 |
+
packing: Optional[bool] = field(
|
97 |
+
default=None,
|
98 |
+
metadata={"help": "Enable sequences packing in training. Will automatically enable in pre-training."},
|
99 |
+
)
|
100 |
+
neat_packing: bool = field(
|
101 |
+
default=False,
|
102 |
+
metadata={"help": "Enable sequence packing without cross-attention."},
|
103 |
+
)
|
104 |
+
tool_format: Optional[str] = field(
|
105 |
+
default=None,
|
106 |
+
metadata={"help": "Tool format to use for constructing function calling examples."},
|
107 |
+
)
|
108 |
+
tokenized_path: Optional[str] = field(
|
109 |
+
default=None,
|
110 |
+
metadata={"help": "Path to save or load the tokenized datasets."},
|
111 |
+
)
|
112 |
+
|
113 |
+
def __post_init__(self):
|
114 |
+
def split_arg(arg):
|
115 |
+
if isinstance(arg, str):
|
116 |
+
return [item.strip() for item in arg.split(",")]
|
117 |
+
return arg
|
118 |
+
|
119 |
+
self.dataset = split_arg(self.dataset)
|
120 |
+
self.eval_dataset = split_arg(self.eval_dataset)
|
121 |
+
|
122 |
+
if self.dataset is None and self.val_size > 1e-6:
|
123 |
+
raise ValueError("Cannot specify `val_size` if `dataset` is None.")
|
124 |
+
|
125 |
+
if self.eval_dataset is not None and self.val_size > 1e-6:
|
126 |
+
raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.")
|
127 |
+
|
128 |
+
if self.interleave_probs is not None:
|
129 |
+
if self.mix_strategy == "concat":
|
130 |
+
raise ValueError("`interleave_probs` is only valid for interleaved mixing.")
|
131 |
+
|
132 |
+
self.interleave_probs = list(map(float, split_arg(self.interleave_probs)))
|
133 |
+
if self.dataset is not None and len(self.dataset) != len(self.interleave_probs):
|
134 |
+
raise ValueError("The length of dataset and interleave probs should be identical.")
|
135 |
+
|
136 |
+
if self.eval_dataset is not None and len(self.eval_dataset) != len(self.interleave_probs):
|
137 |
+
raise ValueError("The length of eval dataset and interleave probs should be identical.")
|
138 |
+
|
139 |
+
if self.streaming and self.val_size > 1e-6 and self.val_size < 1:
|
140 |
+
raise ValueError("Streaming mode should have an integer val size.")
|
141 |
+
|
142 |
+
if self.streaming and self.max_samples is not None:
|
143 |
+
raise ValueError("`max_samples` is incompatible with `streaming`.")
|
llama-factory/src/llamafactory/hparams/evaluation_args.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import os
|
16 |
+
from dataclasses import dataclass, field
|
17 |
+
from typing import Literal, Optional
|
18 |
+
|
19 |
+
from datasets import DownloadMode
|
20 |
+
|
21 |
+
|
22 |
+
@dataclass
|
23 |
+
class EvaluationArguments:
|
24 |
+
r"""
|
25 |
+
Arguments pertaining to specify the evaluation parameters.
|
26 |
+
"""
|
27 |
+
|
28 |
+
task: str = field(
|
29 |
+
metadata={"help": "Name of the evaluation task."},
|
30 |
+
)
|
31 |
+
task_dir: str = field(
|
32 |
+
default="evaluation",
|
33 |
+
metadata={"help": "Path to the folder containing the evaluation datasets."},
|
34 |
+
)
|
35 |
+
batch_size: int = field(
|
36 |
+
default=4,
|
37 |
+
metadata={"help": "The batch size per GPU for evaluation."},
|
38 |
+
)
|
39 |
+
seed: int = field(
|
40 |
+
default=42,
|
41 |
+
metadata={"help": "Random seed to be used with data loaders."},
|
42 |
+
)
|
43 |
+
lang: Literal["en", "zh"] = field(
|
44 |
+
default="en",
|
45 |
+
metadata={"help": "Language used at evaluation."},
|
46 |
+
)
|
47 |
+
n_shot: int = field(
|
48 |
+
default=5,
|
49 |
+
metadata={"help": "Number of examplars for few-shot learning."},
|
50 |
+
)
|
51 |
+
save_dir: Optional[str] = field(
|
52 |
+
default=None,
|
53 |
+
metadata={"help": "Path to save the evaluation results."},
|
54 |
+
)
|
55 |
+
download_mode: DownloadMode = field(
|
56 |
+
default=DownloadMode.REUSE_DATASET_IF_EXISTS,
|
57 |
+
metadata={"help": "Download mode used for the evaluation datasets."},
|
58 |
+
)
|
59 |
+
|
60 |
+
def __post_init__(self):
|
61 |
+
if self.save_dir is not None and os.path.exists(self.save_dir):
|
62 |
+
raise ValueError("`save_dir` already exists, use another one.")
|
llama-factory/src/llamafactory/hparams/finetuning_args.py
ADDED
@@ -0,0 +1,400 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from dataclasses import dataclass, field
|
16 |
+
from typing import List, Literal, Optional
|
17 |
+
|
18 |
+
|
19 |
+
@dataclass
|
20 |
+
class FreezeArguments:
|
21 |
+
r"""
|
22 |
+
Arguments pertaining to the freeze (partial-parameter) training.
|
23 |
+
"""
|
24 |
+
|
25 |
+
freeze_trainable_layers: int = field(
|
26 |
+
default=2,
|
27 |
+
metadata={
|
28 |
+
"help": (
|
29 |
+
"The number of trainable layers for freeze (partial-parameter) fine-tuning. "
|
30 |
+
"Positive numbers mean the last n layers are set as trainable, "
|
31 |
+
"negative numbers mean the first n layers are set as trainable."
|
32 |
+
)
|
33 |
+
},
|
34 |
+
)
|
35 |
+
freeze_trainable_modules: str = field(
|
36 |
+
default="all",
|
37 |
+
metadata={
|
38 |
+
"help": (
|
39 |
+
"Name(s) of trainable modules for freeze (partial-parameter) fine-tuning. "
|
40 |
+
"Use commas to separate multiple modules. "
|
41 |
+
"Use `all` to specify all the available modules."
|
42 |
+
)
|
43 |
+
},
|
44 |
+
)
|
45 |
+
freeze_extra_modules: Optional[str] = field(
|
46 |
+
default=None,
|
47 |
+
metadata={
|
48 |
+
"help": (
|
49 |
+
"Name(s) of modules apart from hidden layers to be set as trainable "
|
50 |
+
"for freeze (partial-parameter) fine-tuning. "
|
51 |
+
"Use commas to separate multiple modules."
|
52 |
+
)
|
53 |
+
},
|
54 |
+
)
|
55 |
+
|
56 |
+
|
57 |
+
@dataclass
|
58 |
+
class LoraArguments:
|
59 |
+
r"""
|
60 |
+
Arguments pertaining to the LoRA training.
|
61 |
+
"""
|
62 |
+
|
63 |
+
additional_target: Optional[str] = field(
|
64 |
+
default=None,
|
65 |
+
metadata={
|
66 |
+
"help": (
|
67 |
+
"Name(s) of modules apart from LoRA layers to be set as trainable "
|
68 |
+
"and saved in the final checkpoint. "
|
69 |
+
"Use commas to separate multiple modules."
|
70 |
+
)
|
71 |
+
},
|
72 |
+
)
|
73 |
+
lora_alpha: Optional[int] = field(
|
74 |
+
default=None,
|
75 |
+
metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."},
|
76 |
+
)
|
77 |
+
lora_dropout: float = field(
|
78 |
+
default=0.0,
|
79 |
+
metadata={"help": "Dropout rate for the LoRA fine-tuning."},
|
80 |
+
)
|
81 |
+
lora_rank: int = field(
|
82 |
+
default=8,
|
83 |
+
metadata={"help": "The intrinsic dimension for LoRA fine-tuning."},
|
84 |
+
)
|
85 |
+
lora_target: str = field(
|
86 |
+
default="all",
|
87 |
+
metadata={
|
88 |
+
"help": (
|
89 |
+
"Name(s) of target modules to apply LoRA. "
|
90 |
+
"Use commas to separate multiple modules. "
|
91 |
+
"Use `all` to specify all the linear modules."
|
92 |
+
)
|
93 |
+
},
|
94 |
+
)
|
95 |
+
loraplus_lr_ratio: Optional[float] = field(
|
96 |
+
default=None,
|
97 |
+
metadata={"help": "LoRA plus learning rate ratio (lr_B / lr_A)."},
|
98 |
+
)
|
99 |
+
loraplus_lr_embedding: float = field(
|
100 |
+
default=1e-6,
|
101 |
+
metadata={"help": "LoRA plus learning rate for lora embedding layers."},
|
102 |
+
)
|
103 |
+
use_rslora: bool = field(
|
104 |
+
default=False,
|
105 |
+
metadata={"help": "Whether or not to use the rank stabilization scaling factor for LoRA layer."},
|
106 |
+
)
|
107 |
+
use_dora: bool = field(
|
108 |
+
default=False,
|
109 |
+
metadata={"help": "Whether or not to use the weight-decomposed lora method (DoRA)."},
|
110 |
+
)
|
111 |
+
pissa_init: bool = field(
|
112 |
+
default=False,
|
113 |
+
metadata={"help": "Whether or not to initialize a PiSSA adapter."},
|
114 |
+
)
|
115 |
+
pissa_iter: int = field(
|
116 |
+
default=16,
|
117 |
+
metadata={"help": "The number of iteration steps performed by FSVD in PiSSA. Use -1 to disable it."},
|
118 |
+
)
|
119 |
+
pissa_convert: bool = field(
|
120 |
+
default=False,
|
121 |
+
metadata={"help": "Whether or not to convert the PiSSA adapter to a normal LoRA adapter."},
|
122 |
+
)
|
123 |
+
create_new_adapter: bool = field(
|
124 |
+
default=False,
|
125 |
+
metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."},
|
126 |
+
)
|
127 |
+
|
128 |
+
|
129 |
+
@dataclass
|
130 |
+
class RLHFArguments:
|
131 |
+
r"""
|
132 |
+
Arguments pertaining to the PPO, DPO and KTO training.
|
133 |
+
"""
|
134 |
+
|
135 |
+
pref_beta: float = field(
|
136 |
+
default=0.1,
|
137 |
+
metadata={"help": "The beta parameter in the preference loss."},
|
138 |
+
)
|
139 |
+
pref_ftx: float = field(
|
140 |
+
default=0.0,
|
141 |
+
metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."},
|
142 |
+
)
|
143 |
+
pref_loss: Literal["sigmoid", "hinge", "ipo", "kto_pair", "orpo", "simpo"] = field(
|
144 |
+
default="sigmoid",
|
145 |
+
metadata={"help": "The type of DPO loss to use."},
|
146 |
+
)
|
147 |
+
dpo_label_smoothing: float = field(
|
148 |
+
default=0.0,
|
149 |
+
metadata={"help": "The robust DPO label smoothing parameter in cDPO that should be between 0 and 0.5."},
|
150 |
+
)
|
151 |
+
kto_chosen_weight: float = field(
|
152 |
+
default=1.0,
|
153 |
+
metadata={"help": "The weight factor of the desirable losses in KTO training."},
|
154 |
+
)
|
155 |
+
kto_rejected_weight: float = field(
|
156 |
+
default=1.0,
|
157 |
+
metadata={"help": "The weight factor of the undesirable losses in KTO training."},
|
158 |
+
)
|
159 |
+
simpo_gamma: float = field(
|
160 |
+
default=0.5,
|
161 |
+
metadata={"help": "The target reward margin term in SimPO loss."},
|
162 |
+
)
|
163 |
+
ppo_buffer_size: int = field(
|
164 |
+
default=1,
|
165 |
+
metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."},
|
166 |
+
)
|
167 |
+
ppo_epochs: int = field(
|
168 |
+
default=4,
|
169 |
+
metadata={"help": "The number of epochs to perform in a PPO optimization step."},
|
170 |
+
)
|
171 |
+
ppo_score_norm: bool = field(
|
172 |
+
default=False,
|
173 |
+
metadata={"help": "Use score normalization in PPO training."},
|
174 |
+
)
|
175 |
+
ppo_target: float = field(
|
176 |
+
default=6.0,
|
177 |
+
metadata={"help": "Target KL value for adaptive KL control in PPO training."},
|
178 |
+
)
|
179 |
+
ppo_whiten_rewards: bool = field(
|
180 |
+
default=False,
|
181 |
+
metadata={"help": "Whiten the rewards before compute advantages in PPO training."},
|
182 |
+
)
|
183 |
+
ref_model: Optional[str] = field(
|
184 |
+
default=None,
|
185 |
+
metadata={"help": "Path to the reference model used for the PPO or DPO training."},
|
186 |
+
)
|
187 |
+
ref_model_adapters: Optional[str] = field(
|
188 |
+
default=None,
|
189 |
+
metadata={"help": "Path to the adapters of the reference model."},
|
190 |
+
)
|
191 |
+
ref_model_quantization_bit: Optional[int] = field(
|
192 |
+
default=None,
|
193 |
+
metadata={"help": "The number of bits to quantize the reference model."},
|
194 |
+
)
|
195 |
+
reward_model: Optional[str] = field(
|
196 |
+
default=None,
|
197 |
+
metadata={"help": "Path to the reward model used for the PPO training."},
|
198 |
+
)
|
199 |
+
reward_model_adapters: Optional[str] = field(
|
200 |
+
default=None,
|
201 |
+
metadata={"help": "Path to the adapters of the reward model."},
|
202 |
+
)
|
203 |
+
reward_model_quantization_bit: Optional[int] = field(
|
204 |
+
default=None,
|
205 |
+
metadata={"help": "The number of bits to quantize the reward model."},
|
206 |
+
)
|
207 |
+
reward_model_type: Literal["lora", "full", "api"] = field(
|
208 |
+
default="lora",
|
209 |
+
metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."},
|
210 |
+
)
|
211 |
+
|
212 |
+
|
213 |
+
@dataclass
|
214 |
+
class GaloreArguments:
|
215 |
+
r"""
|
216 |
+
Arguments pertaining to the GaLore algorithm.
|
217 |
+
"""
|
218 |
+
|
219 |
+
use_galore: bool = field(
|
220 |
+
default=False,
|
221 |
+
metadata={"help": "Whether or not to use the gradient low-Rank projection (GaLore)."},
|
222 |
+
)
|
223 |
+
galore_target: str = field(
|
224 |
+
default="all",
|
225 |
+
metadata={
|
226 |
+
"help": (
|
227 |
+
"Name(s) of modules to apply GaLore. Use commas to separate multiple modules. "
|
228 |
+
"Use `all` to specify all the linear modules."
|
229 |
+
)
|
230 |
+
},
|
231 |
+
)
|
232 |
+
galore_rank: int = field(
|
233 |
+
default=16,
|
234 |
+
metadata={"help": "The rank of GaLore gradients."},
|
235 |
+
)
|
236 |
+
galore_update_interval: int = field(
|
237 |
+
default=200,
|
238 |
+
metadata={"help": "Number of steps to update the GaLore projection."},
|
239 |
+
)
|
240 |
+
galore_scale: float = field(
|
241 |
+
default=0.25,
|
242 |
+
metadata={"help": "GaLore scaling coefficient."},
|
243 |
+
)
|
244 |
+
galore_proj_type: Literal["std", "reverse_std", "right", "left", "full"] = field(
|
245 |
+
default="std",
|
246 |
+
metadata={"help": "Type of GaLore projection."},
|
247 |
+
)
|
248 |
+
galore_layerwise: bool = field(
|
249 |
+
default=False,
|
250 |
+
metadata={"help": "Whether or not to enable layer-wise update to further save memory."},
|
251 |
+
)
|
252 |
+
|
253 |
+
|
254 |
+
@dataclass
|
255 |
+
class BAdamArgument:
|
256 |
+
r"""
|
257 |
+
Arguments pertaining to the BAdam optimizer.
|
258 |
+
"""
|
259 |
+
|
260 |
+
use_badam: bool = field(
|
261 |
+
default=False,
|
262 |
+
metadata={"help": "Whether or not to use the BAdam optimizer."},
|
263 |
+
)
|
264 |
+
badam_mode: Literal["layer", "ratio"] = field(
|
265 |
+
default="layer",
|
266 |
+
metadata={"help": "Whether to use layer-wise or ratio-wise BAdam optimizer."},
|
267 |
+
)
|
268 |
+
badam_start_block: Optional[int] = field(
|
269 |
+
default=None,
|
270 |
+
metadata={"help": "The starting block index for layer-wise BAdam."},
|
271 |
+
)
|
272 |
+
badam_switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field(
|
273 |
+
default="ascending",
|
274 |
+
metadata={"help": "the strategy of picking block to update for layer-wise BAdam."},
|
275 |
+
)
|
276 |
+
badam_switch_interval: Optional[int] = field(
|
277 |
+
default=50,
|
278 |
+
metadata={
|
279 |
+
"help": "Number of steps to update the block for layer-wise BAdam. Use -1 to disable the block update."
|
280 |
+
},
|
281 |
+
)
|
282 |
+
badam_update_ratio: float = field(
|
283 |
+
default=0.05,
|
284 |
+
metadata={"help": "The ratio of the update for ratio-wise BAdam."},
|
285 |
+
)
|
286 |
+
badam_mask_mode: Literal["adjacent", "scatter"] = field(
|
287 |
+
default="adjacent",
|
288 |
+
metadata={
|
289 |
+
"help": (
|
290 |
+
"The mode of the mask for BAdam optimizer. "
|
291 |
+
"`adjacent` means that the trainable parameters are adjacent to each other, "
|
292 |
+
"`scatter` means that trainable parameters are randomly choosed from the weight."
|
293 |
+
)
|
294 |
+
},
|
295 |
+
)
|
296 |
+
badam_verbose: int = field(
|
297 |
+
default=0,
|
298 |
+
metadata={
|
299 |
+
"help": (
|
300 |
+
"The verbosity level of BAdam optimizer. "
|
301 |
+
"0 for no print, 1 for print the block prefix, 2 for print trainable parameters."
|
302 |
+
)
|
303 |
+
},
|
304 |
+
)
|
305 |
+
|
306 |
+
|
307 |
+
@dataclass
|
308 |
+
class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument):
|
309 |
+
r"""
|
310 |
+
Arguments pertaining to which techniques we are going to fine-tuning with.
|
311 |
+
"""
|
312 |
+
|
313 |
+
pure_bf16: bool = field(
|
314 |
+
default=False,
|
315 |
+
metadata={"help": "Whether or not to train model in purely bf16 precision (without AMP)."},
|
316 |
+
)
|
317 |
+
stage: Literal["pt", "sft", "rm", "ppo", "dpo", "kto"] = field(
|
318 |
+
default="sft",
|
319 |
+
metadata={"help": "Which stage will be performed in training."},
|
320 |
+
)
|
321 |
+
finetuning_type: Literal["lora", "freeze", "full"] = field(
|
322 |
+
default="lora",
|
323 |
+
metadata={"help": "Which fine-tuning method to use."},
|
324 |
+
)
|
325 |
+
use_llama_pro: bool = field(
|
326 |
+
default=False,
|
327 |
+
metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."},
|
328 |
+
)
|
329 |
+
freeze_vision_tower: bool = field(
|
330 |
+
default=True,
|
331 |
+
metadata={"help": "Whether ot not to freeze vision tower in MLLM training."},
|
332 |
+
)
|
333 |
+
train_mm_proj_only: bool = field(
|
334 |
+
default=False,
|
335 |
+
metadata={"help": "Whether or not to train the multimodal projector for MLLM only."},
|
336 |
+
)
|
337 |
+
compute_accuracy: bool = field(
|
338 |
+
default=False,
|
339 |
+
metadata={"help": "Whether or not to compute the token-level accuracy at evaluation."},
|
340 |
+
)
|
341 |
+
plot_loss: bool = field(
|
342 |
+
default=False,
|
343 |
+
metadata={"help": "Whether or not to save the training loss curves."},
|
344 |
+
)
|
345 |
+
|
346 |
+
def __post_init__(self):
|
347 |
+
def split_arg(arg):
|
348 |
+
if isinstance(arg, str):
|
349 |
+
return [item.strip() for item in arg.split(",")]
|
350 |
+
return arg
|
351 |
+
|
352 |
+
self.freeze_trainable_modules: List[str] = split_arg(self.freeze_trainable_modules)
|
353 |
+
self.freeze_extra_modules: Optional[List[str]] = split_arg(self.freeze_extra_modules)
|
354 |
+
self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2
|
355 |
+
self.lora_target: List[str] = split_arg(self.lora_target)
|
356 |
+
self.additional_target: Optional[List[str]] = split_arg(self.additional_target)
|
357 |
+
self.galore_target: List[str] = split_arg(self.galore_target)
|
358 |
+
self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only
|
359 |
+
self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]
|
360 |
+
|
361 |
+
assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
|
362 |
+
assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
363 |
+
assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
364 |
+
|
365 |
+
if self.stage == "ppo" and self.reward_model is None:
|
366 |
+
raise ValueError("`reward_model` is necessary for PPO training.")
|
367 |
+
|
368 |
+
if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
|
369 |
+
raise ValueError("`reward_model_type` cannot be lora for Freeze/Full PPO training.")
|
370 |
+
|
371 |
+
if self.stage == "dpo" and self.pref_loss != "sigmoid" and self.dpo_label_smoothing > 1e-6:
|
372 |
+
raise ValueError("`dpo_label_smoothing` is only valid for sigmoid loss function.")
|
373 |
+
|
374 |
+
if self.use_llama_pro and self.finetuning_type == "full":
|
375 |
+
raise ValueError("`use_llama_pro` is only valid for Freeze or LoRA training.")
|
376 |
+
|
377 |
+
if self.finetuning_type == "lora" and (self.use_galore or self.use_badam):
|
378 |
+
raise ValueError("Cannot use LoRA with GaLore or BAdam together.")
|
379 |
+
|
380 |
+
if self.use_galore and self.use_badam:
|
381 |
+
raise ValueError("Cannot use GaLore with BAdam together.")
|
382 |
+
|
383 |
+
if self.pissa_init and (self.stage in ["ppo", "kto"] or self.use_ref_model):
|
384 |
+
raise ValueError("Cannot use PiSSA for current training stage.")
|
385 |
+
|
386 |
+
if self.train_mm_proj_only and self.finetuning_type != "full":
|
387 |
+
raise ValueError("`train_mm_proj_only` is only valid for full training.")
|
388 |
+
|
389 |
+
if self.finetuning_type != "lora":
|
390 |
+
if self.loraplus_lr_ratio is not None:
|
391 |
+
raise ValueError("`loraplus_lr_ratio` is only valid for LoRA training.")
|
392 |
+
|
393 |
+
if self.use_rslora:
|
394 |
+
raise ValueError("`use_rslora` is only valid for LoRA training.")
|
395 |
+
|
396 |
+
if self.use_dora:
|
397 |
+
raise ValueError("`use_dora` is only valid for LoRA training.")
|
398 |
+
|
399 |
+
if self.pissa_init:
|
400 |
+
raise ValueError("`pissa_init` is only valid for LoRA training.")
|
llama-factory/src/llamafactory/hparams/generating_args.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from dataclasses import asdict, dataclass, field
|
16 |
+
from typing import Any, Dict, Optional
|
17 |
+
|
18 |
+
|
19 |
+
@dataclass
|
20 |
+
class GeneratingArguments:
|
21 |
+
r"""
|
22 |
+
Arguments pertaining to specify the decoding parameters.
|
23 |
+
"""
|
24 |
+
|
25 |
+
do_sample: bool = field(
|
26 |
+
default=True,
|
27 |
+
metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."},
|
28 |
+
)
|
29 |
+
temperature: float = field(
|
30 |
+
default=0.95,
|
31 |
+
metadata={"help": "The value used to modulate the next token probabilities."},
|
32 |
+
)
|
33 |
+
top_p: float = field(
|
34 |
+
default=0.7,
|
35 |
+
metadata={
|
36 |
+
"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."
|
37 |
+
},
|
38 |
+
)
|
39 |
+
top_k: int = field(
|
40 |
+
default=50,
|
41 |
+
metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."},
|
42 |
+
)
|
43 |
+
num_beams: int = field(
|
44 |
+
default=1,
|
45 |
+
metadata={"help": "Number of beams for beam search. 1 means no beam search."},
|
46 |
+
)
|
47 |
+
max_length: int = field(
|
48 |
+
default=1024,
|
49 |
+
metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."},
|
50 |
+
)
|
51 |
+
max_new_tokens: int = field(
|
52 |
+
default=1024,
|
53 |
+
metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."},
|
54 |
+
)
|
55 |
+
repetition_penalty: float = field(
|
56 |
+
default=1.0,
|
57 |
+
metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."},
|
58 |
+
)
|
59 |
+
length_penalty: float = field(
|
60 |
+
default=1.0,
|
61 |
+
metadata={"help": "Exponential penalty to the length that is used with beam-based generation."},
|
62 |
+
)
|
63 |
+
default_system: Optional[str] = field(
|
64 |
+
default=None,
|
65 |
+
metadata={"help": "Default system message to use in chat completion."},
|
66 |
+
)
|
67 |
+
|
68 |
+
def to_dict(self) -> Dict[str, Any]:
|
69 |
+
args = asdict(self)
|
70 |
+
if args.get("max_new_tokens", -1) > 0:
|
71 |
+
args.pop("max_length", None)
|
72 |
+
else:
|
73 |
+
args.pop("max_new_tokens", None)
|
74 |
+
return args
|