Student0809 committed (verified)
Commit 3544698 · Parent(s): e7a862c

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +11 -0
  2. ms-swift/.ipynb_checkpoints/README-checkpoint.md +423 -0
  3. ms-swift/.ipynb_checkpoints/README_CN-checkpoint.md +413 -0
  4. ms-swift/.ipynb_checkpoints/dataset-checkpoint.json +60 -0
  5. ms-swift/.ipynb_checkpoints/dataset_overlap5s716_gemini-checkpoint.json +0 -0
  6. ms-swift/.ipynb_checkpoints/gen_data-checkpoint.py +154 -0
  7. ms-swift/.ipynb_checkpoints/overlap5s716_gemini-checkpoint.json +0 -0
  8. ms-swift/.ipynb_checkpoints/setup-checkpoint.py +165 -0
  9. ms-swift/.ipynb_checkpoints/test-checkpoint.sh +6 -0
  10. ms-swift/.ipynb_checkpoints/train-checkpoint.sh +42 -0
  11. ms-swift/asset/banner.png +3 -0
  12. ms-swift/docs/resources/dpo_data.png +3 -0
  13. ms-swift/docs/resources/grpo_clevr_count.png +3 -0
  14. ms-swift/docs/resources/grpo_code.png +3 -0
  15. ms-swift/docs/resources/grpo_countdown.png +3 -0
  16. ms-swift/docs/resources/grpo_countdown_1.png +3 -0
  17. ms-swift/docs/resources/grpo_geoqa.png +3 -0
  18. ms-swift/docs/resources/grpo_openr1_multimodal.png +3 -0
  19. ms-swift/docs/resources/kto_data.png +3 -0
  20. ms-swift/docs/resources/web-ui-en.jpg +3 -0
  21. ms-swift/docs/resources/web-ui.jpg +3 -0
  22. ms-swift/silence_overlaps.zip +3 -0
  23. ms-swift/silence_overlaps/.ipynb_checkpoints/clean_wrong-checkpoint.py +73 -0
  24. ms-swift/silence_overlaps/.ipynb_checkpoints/cleaned_transcriptions-checkpoint.json +0 -0
  25. ms-swift/silence_overlaps/.ipynb_checkpoints/delete_transcript-checkpoint.json +0 -0
  26. ms-swift/silence_overlaps/.ipynb_checkpoints/delete_transcript2-checkpoint.json +1 -0
  27. ms-swift/silence_overlaps/700/merge_and_shuffle_json.py +61 -0
  28. ms-swift/silence_overlaps/700/split_train_test.py +36 -0
  29. ms-swift/silence_overlaps/700/train/overlap5s_isoverlap_train.json +0 -0
  30. ms-swift/silence_overlaps/700/train/overlap5s_speaker_segments_train.json +0 -0
  31. ms-swift/silence_overlaps/clean_wrong.py +73 -0
  32. ms-swift/silence_overlaps/cleaned_transcriptions.json +0 -0
  33. ms-swift/silence_overlaps/overlap5s_isoverlap.json +0 -0
  34. ms-swift/silence_overlaps/overlap5s_speaker_segments.json +0 -0
  35. ms-swift/silence_overlaps/overlap5s_transcriptions.json +0 -0
  36. ms-swift/silence_overlaps/silence_isoverlaps.json +0 -0
  37. ms-swift/silence_overlaps/silence_issilence.json +0 -0
  38. ms-swift/silence_overlaps/transcriptions.json +0 -0
  39. ms-swift/swift/ui/llm_train/utils.py +37 -0
  40. ms-swift/swift/utils/__pycache__/logger.cpython-310.pyc +0 -0
  41. ms-swift/swift/utils/__pycache__/torch_utils.cpython-310.pyc +0 -0
  42. ms-swift/swift/utils/__pycache__/utils.cpython-310.pyc +0 -0
  43. ms-swift/swift/utils/utils.py +323 -0
  44. ms-swift/tests/__init__.py +0 -0
  45. ms-swift/tests/app/test_app.py +25 -0
  46. ms-swift/tests/llm/data/multi_modal_1.jsonl +3 -0
  47. ms-swift/tests/models/test_flash_attn.py +8 -0
  48. ms-swift/tests/test_align/test_rlhf_loss.py +0 -0
  49. ms-swift/tests/test_align/test_template/test_agent.py +325 -0
  50. ms-swift/tests/test_align/test_template/test_audio.py +76 -0
.gitattributes CHANGED
@@ -33,3 +33,14 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ ms-swift/docs/resources/web-ui.jpg filter=lfs diff=lfs merge=lfs -text
37
+ ms-swift/docs/resources/grpo_code.png filter=lfs diff=lfs merge=lfs -text
38
+ ms-swift/docs/resources/kto_data.png filter=lfs diff=lfs merge=lfs -text
39
+ ms-swift/asset/banner.png filter=lfs diff=lfs merge=lfs -text
40
+ ms-swift/docs/resources/dpo_data.png filter=lfs diff=lfs merge=lfs -text
41
+ ms-swift/docs/resources/grpo_clevr_count.png filter=lfs diff=lfs merge=lfs -text
42
+ ms-swift/docs/resources/web-ui-en.jpg filter=lfs diff=lfs merge=lfs -text
43
+ ms-swift/docs/resources/grpo_openr1_multimodal.png filter=lfs diff=lfs merge=lfs -text
44
+ ms-swift/docs/resources/grpo_countdown_1.png filter=lfs diff=lfs merge=lfs -text
45
+ ms-swift/docs/resources/grpo_geoqa.png filter=lfs diff=lfs merge=lfs -text
46
+ ms-swift/docs/resources/grpo_countdown.png filter=lfs diff=lfs merge=lfs -text
ms-swift/.ipynb_checkpoints/README-checkpoint.md ADDED
@@ -0,0 +1,423 @@
1
+ # SWIFT (Scalable lightWeight Infrastructure for Fine-Tuning)
2
+
3
+ <p align="center">
4
+ <br>
5
+ <img src="asset/banner.png"/>
6
+ <br>
7
+ <p>
8
+ <p align="center">
9
+ <a href="https://modelscope.cn/home">ModelScope Community Website</a>
10
+ <br>
11
+ <a href="README_CN.md">中文</a> &nbsp | &nbsp English &nbsp
12
+ </p>
13
+
14
+ <p align="center">
15
+ <img src="https://img.shields.io/badge/python-3.10-5be.svg">
16
+ <img src="https://img.shields.io/badge/pytorch-%E2%89%A52.0-orange.svg">
17
+ <a href="https://github.com/modelscope/modelscope/"><img src="https://img.shields.io/badge/modelscope-%E2%89%A51.19-5D91D4.svg"></a>
18
+ <a href="https://pypi.org/project/ms-swift/"><img src="https://badge.fury.io/py/ms-swift.svg"></a>
19
+ <a href="https://github.com/modelscope/swift/blob/main/LICENSE"><img src="https://img.shields.io/github/license/modelscope/swift"></a>
20
+ <a href="https://pepy.tech/project/ms-swift"><img src="https://pepy.tech/badge/ms-swift"></a>
21
+ <a href="https://github.com/modelscope/swift/pulls"><img src="https://img.shields.io/badge/PR-welcome-55EB99.svg"></a>
22
+ </p>
23
+
24
+ <p align="center">
25
+ <a href="https://trendshift.io/repositories/6427" target="_blank"><img src="https://trendshift.io/api/badge/repositories/6427" alt="modelscope%2Fswift | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
26
+ </p>
27
+
28
+ <p align="center">
29
+ <a href="https://arxiv.org/abs/2408.05517">Paper</a> &nbsp | <a href="https://swift.readthedocs.io/en/latest/">English Documentation</a> &nbsp | &nbsp <a href="https://swift.readthedocs.io/zh-cn/latest/">中文文档</a> &nbsp
30
+ </p>
31
+
32
+ ## 📖 Table of Contents
33
+ - [Groups](#-Groups)
34
+ - [Introduction](#-introduction)
35
+ - [News](#-news)
36
+ - [Installation](#%EF%B8%8F-installation)
37
+ - [Quick Start](#-quick-Start)
38
+ - [Usage](#-Usage)
39
+ - [License](#-License)
40
+ - [Citation](#-citation)
41
+
42
+
43
+ ## ☎ Groups
44
+
45
+ You can contact and communicate with us by joining our groups:
46
+
47
+
48
+ [Discord Group](https://discord.com/invite/D27yfEFVz5) | WeChat Group
49
+ :-------------------------:|:-------------------------:
50
+ <img src="asset/discord_qr.jpg" width="200" height="200"> | <img src="asset/wechat.png" width="200" height="200">
51
+
52
+
53
+ ## 📝 Introduction
54
+ 🍲 ms-swift is an official framework provided by the ModelScope community for fine-tuning and deploying large language models and multi-modal large models. It currently supports the training (pre-training, fine-tuning, human alignment), inference, evaluation, quantization, and deployment of 500+ large models and 200+ multi-modal large models. These large language models (LLMs) include models such as Qwen3, Qwen3-MoE, Qwen2.5, InternLM3, GLM4, Mistral, DeepSeek-R1, Yi1.5, TeleChat2, Baichuan2, and Gemma2. The multi-modal LLMs include models such as Qwen2.5-VL, Qwen2-Audio, Llama4, Llava, InternVL2.5, MiniCPM-V-2.6, GLM4v, Xcomposer2.5, Yi-VL, DeepSeek-VL2, Phi3.5-Vision, and GOT-OCR2.
55
+
56
+ 🍔 Additionally, ms-swift incorporates the latest training technologies, including lightweight techniques such as LoRA, QLoRA, Llama-Pro, LongLoRA, GaLore, Q-GaLore, LoRA+, LISA, DoRA, FourierFt, ReFT, UnSloth, and Liger, as well as human alignment training methods like DPO, GRPO, RM, PPO, KTO, CPO, SimPO, and ORPO. ms-swift supports acceleration of inference, evaluation, and deployment modules using vLLM and LMDeploy, and it supports model quantization with technologies like GPTQ, AWQ, and BNB. Furthermore, ms-swift offers a Gradio-based Web UI and a wealth of best practices.
57
+
58
+ **Why choose ms-swift?**
59
+
60
+ - 🍎 **Model Types**: Supports 500+ pure text large models, **200+ multi-modal large models**, as well as All-to-All multi-modal models, sequence classification models, and embedding models, **covering the entire process from training to deployment**.
61
+ - **Dataset Types**: Comes with 150+ pre-training, fine-tuning, human alignment, multi-modal datasets, and supports custom datasets.
62
+ - **Hardware Support**: Compatible with CPU, RTX series, T4/V100, A10/A100/H100, Ascend NPU, MPS, etc.
63
+ - 🍊 **Lightweight Training**: Supports lightweight fine-tuning methods like LoRA, QLoRA, DoRA, LoRA+, ReFT, RS-LoRA, LLaMAPro, Adapter, GaLore, Q-Galore, LISA, UnSloth, Liger-Kernel.
64
+ - **Distributed Training**: Supports distributed data parallel (DDP), device_map simple model parallelism, DeepSpeed ZeRO2/ZeRO3, FSDP, and other distributed training techniques.
65
+ - **Quantization Training**: Supports training quantized models like BNB, AWQ, GPTQ, AQLM, HQQ, EETQ.
66
+ - **RLHF Training**: Supports human alignment training methods such as DPO, GRPO, RM, PPO, KTO, CPO, SimPO, ORPO for both pure text and multi-modal large models.
67
+ - 🍓 **Multi-Modal Training**: Supports training on different modalities like images, videos, and audio, for tasks like VQA, captioning, OCR, and grounding.
68
+ - **Interface Training**: Provides training, inference, evaluation, and quantization capabilities through a web interface, covering the whole large-model pipeline.
69
+ - **Plugin and Extension**: Supports custom model and dataset extensions, as well as customization of components like loss, metric, trainer, loss-scale, callback, optimizer.
70
+ - 🍉 **Toolbox Capabilities**: Offers not only training support for large models and multi-modal large models but also covers the entire process of inference, evaluation, quantization, and deployment.
71
+ - **Inference Acceleration**: Supports inference acceleration engines like PyTorch, vLLM, LmDeploy, and provides OpenAI API for accelerating inference, deployment, and evaluation modules.
72
+ - **Model Evaluation**: Uses EvalScope as the evaluation backend and supports evaluation on 100+ datasets for both pure text and multi-modal models.
73
+ - **Model Quantization**: Supports AWQ, GPTQ, and BNB quantized exports, with models that can use vLLM/LmDeploy for inference acceleration and continue training.
74
+
75
+
76
+ ## 🎉 News
77
+ - 🎁 2025.05.11: GRPO now supports custom processing logic for reward models. See the GenRM example [here](./docs/source_en/Instruction/GRPO.md#customized-reward-models).
78
+ - 🎁 2025.04.15: The ms-swift paper has been accepted by AAAI 2025. You can find the paper at [this link](https://ojs.aaai.org/index.php/AAAI/article/view/35383).
79
+ - 🎁 2025.03.23: Multi-round GRPO is now supported for training multi-turn dialogue scenarios (e.g., agent tool calling). Please refer to the [training script](https://idealab.alibaba-inc.com/examples/train/grpo/internal/train_multi_round.sh).
80
+ - 🎁 2025.03.16: Support for Megatron's parallel training techniques is now available. Please see the [Megatron-SWIFT training documentation](https://swift.readthedocs.io/zh-cn/latest/Instruction/Megatron-SWIFT训练.html).
81
+ - 🎁 2025.03.15: Fine-tuning of embedding models for both pure text and multimodal models is supported. Please check the [training script](https://idealab.alibaba-inc.com/examples/train/embedding).
82
+ - 🎁 2025.03.05: The hybrid mode for GRPO is supported, with a script for training a 72B model on 4 GPUs (4*80G) available [here](https://idealab.alibaba-inc.com/examples/train/grpo/internal/train_72b_4gpu.sh). Tensor parallelism with vllm is also supported, with the training script available [here](https://idealab.alibaba-inc.com/examples/train/grpo/internal/multi_gpu_mp_colocate.sh).
83
+ - 🎁 2025.02.21: The GRPO algorithm now supports LMDeploy, with the training script available [here](https://idealab.alibaba-inc.com/examples/train/grpo/internal/full_lmdeploy.sh). Additionally, the performance of the GRPO algorithm has been tested, achieving a training speed increase of up to 300% using various tricks. Please check the WanDB table [here](https://wandb.ai/tastelikefeet/grpo_perf_test?nw=nwuseryuzezyz).
84
+ - 🎁 2025.02.21: The `swift sample` command is now supported. The reinforcement fine-tuning script can be found [here](https://idealab.alibaba-inc.com/docs/source/Instruction/强化微调.md), and the large model API distillation sampling script is available [here](https://idealab.alibaba-inc.com/examples/sampler/distill/distill.sh).
85
+ - 🔥 2025.02.12: Support for the GRPO (Group Relative Policy Optimization) training algorithm has been added. Documentation is available [here](https://idealab.alibaba-inc.com/docs/source/Instruction/GRPO.md).
86
+ - 🎁 2024.12.04: Major update to **ms-swift 3.0**. Please refer to the [release notes and changes](https://swift.readthedocs.io/zh-cn/latest/Instruction/ReleaseNote3.0.html).
87
+ <details><summary>More</summary>
88
+
89
+ - 🎉 2024.08.12: The ms-swift paper has been published on arXiv and can be read [here](https://arxiv.org/abs/2408.05517).
90
+ - 🔥 2024.08.05: Support for using [evalscope](https://github.com/modelscope/evalscope/) as a backend for evaluating large models and multimodal models.
91
+ - 🔥 2024.07.29: Support for using [vllm](https://github.com/vllm-project/vllm) and [lmdeploy](https://github.com/InternLM/lmdeploy) to accelerate inference for large models and multimodal models. When performing infer/deploy/eval, you can specify `--infer_backend vllm/lmdeploy`.
92
+ - 🔥 2024.07.24: Support for human preference alignment training for multimodal large models, including DPO/ORPO/SimPO/CPO/KTO/RM/PPO.
93
+ - 🔥 2024.02.01: Support for Agent training! The training algorithm is derived from [this paper](https://arxiv.org/pdf/2309.00986.pdf).
94
+ </details>
95
+
96
+ ## 🛠️ Installation
97
+ To install using pip:
98
+ ```shell
99
+ pip install ms-swift -U
100
+ ```
101
+
102
+ To install from source:
103
+ ```shell
104
+ # pip install git+https://github.com/modelscope/ms-swift.git
105
+
106
+ git clone https://github.com/modelscope/ms-swift.git
107
+ cd ms-swift
108
+ pip install -e .
109
+ ```
110
+
111
+ Running Environment:
112
+
113
+ | | Range | Recommended | Notes |
114
+ | ------------ |--------------| ----------- | ----------------------------------------- |
115
+ | python | >=3.9 | 3.10 | |
116
+ | cuda | | cuda12 | No need to install if using CPU, NPU, MPS |
117
+ | torch | >=2.0 | | |
118
+ | transformers | >=4.33 | 4.51 | |
119
+ | modelscope | >=1.23 | | |
120
+ | peft | >=0.11,<0.16 | ||
121
+ | trl | >=0.13,<0.18 | 0.17 |RLHF|
122
+ | deepspeed | >=0.14 | 0.14.5 | Training |
123
+ | vllm | >=0.5.1 | 0.7.3/0.8 | Inference/Deployment/Evaluation |
124
+ | lmdeploy | >=0.5 | 0.8 | Inference/Deployment/Evaluation |
125
+ | evalscope | >=0.11 | | Evaluation |
126
+
127
+ For more optional dependencies, you can refer to [here](https://github.com/modelscope/ms-swift/blob/main/requirements/install_all.sh).
128
+
129
+
130
+ ## 🚀 Quick Start
131
+
132
+ 10 minutes of self-cognition fine-tuning of Qwen2.5-7B-Instruct on a single 3090 GPU:
133
+
134
+ ### Command Line Interface
135
+
136
+ ```shell
137
+ # 22GB
138
+ CUDA_VISIBLE_DEVICES=0 \
139
+ swift sft \
140
+ --model Qwen/Qwen2.5-7B-Instruct \
141
+ --train_type lora \
142
+ --dataset 'AI-ModelScope/alpaca-gpt4-data-zh#500' \
143
+ 'AI-ModelScope/alpaca-gpt4-data-en#500' \
144
+ 'swift/self-cognition#500' \
145
+ --torch_dtype bfloat16 \
146
+ --num_train_epochs 1 \
147
+ --per_device_train_batch_size 1 \
148
+ --per_device_eval_batch_size 1 \
149
+ --learning_rate 1e-4 \
150
+ --lora_rank 8 \
151
+ --lora_alpha 32 \
152
+ --target_modules all-linear \
153
+ --gradient_accumulation_steps 16 \
154
+ --eval_steps 50 \
155
+ --save_steps 50 \
156
+ --save_total_limit 2 \
157
+ --logging_steps 5 \
158
+ --max_length 2048 \
159
+ --output_dir output \
160
+ --system 'You are a helpful assistant.' \
161
+ --warmup_ratio 0.05 \
162
+ --dataloader_num_workers 4 \
163
+ --model_author swift \
164
+ --model_name swift-robot
165
+ ```
166
+
167
+ Tips:
168
+
169
+ - If you want to train with a custom dataset, you can refer to [this guide](https://swift.readthedocs.io/en/latest/Customization/Custom-dataset.html) to organize your dataset format and specify `--dataset <dataset_path>`; a minimal sketch is shown after these tips.
170
+ - The `--model_author` and `--model_name` parameters are only effective when the dataset includes `swift/self-cognition`.
171
+ - To train with a different model, simply modify `--model <model_id/model_path>`.
172
+ - By default, ModelScope is used for downloading models and datasets. If you want to use HuggingFace, simply specify `--use_hf true`.
173
+
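+ A minimal sketch of the custom-dataset workflow mentioned in the first tip (the file name and sample content below are hypothetical; see the linked guide for the full set of supported fields):
+
+ ```shell
+ # Each line of the JSONL file uses the `messages` format described in the custom-dataset guide.
+ cat > my_dataset.jsonl <<'EOF'
+ {"messages": [{"role": "user", "content": "What is ms-swift?"}, {"role": "assistant", "content": "A fine-tuning and deployment framework from the ModelScope community."}]}
+ EOF
+
+ # Point --dataset at the local file; the remaining arguments follow the command above.
+ CUDA_VISIBLE_DEVICES=0 \
+ swift sft \
+ --model Qwen/Qwen2.5-7B-Instruct \
+ --train_type lora \
+ --dataset my_dataset.jsonl \
+ --output_dir output
+ ```
+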
174
+ After training is complete, use the following command to infer with the trained weights:
175
+
176
+ - Here, `--adapters` should be replaced with the last checkpoint folder generated during training. Since the adapters folder contains the training parameter file `args.json`, there is no need to specify `--model`, `--system` separately; Swift will automatically read these parameters. To disable this behavior, you can set `--load_args false`.
177
+
178
+ ```shell
179
+ # Using an interactive command line for inference.
180
+ CUDA_VISIBLE_DEVICES=0 \
181
+ swift infer \
182
+ --adapters output/vx-xxx/checkpoint-xxx \
183
+ --stream true \
184
+ --temperature 0 \
185
+ --max_new_tokens 2048
186
+
187
+ # merge-lora and use vLLM for inference acceleration
188
+ CUDA_VISIBLE_DEVICES=0 \
189
+ swift infer \
190
+ --adapters output/vx-xxx/checkpoint-xxx \
191
+ --stream true \
192
+ --merge_lora true \
193
+ --infer_backend vllm \
194
+ --max_model_len 8192 \
195
+ --temperature 0 \
196
+ --max_new_tokens 2048
197
+ ```
198
+
199
+ Finally, use the following command to push the model to ModelScope:
200
+
201
+ ```shell
202
+ CUDA_VISIBLE_DEVICES=0 \
203
+ swift export \
204
+ --adapters output/vx-xxx/checkpoint-xxx \
205
+ --push_to_hub true \
206
+ --hub_model_id '<your-model-id>' \
207
+ --hub_token '<your-sdk-token>' \
208
+ --use_hf false
209
+ ```
210
+
211
+
212
+ ### Web-UI
213
+ The Web-UI is a **zero-threshold** (no-code) training and deployment interface built on Gradio. For more details, you can check [here](https://swift.readthedocs.io/en/latest/GetStarted/Web-UI.html).
214
+
215
+ ```shell
216
+ SWIFT_UI_LANG=en swift web-ui
217
+ ```
218
+
219
+ ![image.png](./docs/resources/web-ui-en.jpg)
220
+
221
+ ### Using Python
222
+
223
+ ms-swift also supports training and inference using Python. Below is pseudocode for training and inference. For more details, you can refer to [here](https://github.com/modelscope/ms-swift/blob/main/examples/notebook/qwen2_5-self-cognition/self-cognition-sft.ipynb).
224
+
225
+ Training:
226
+
227
+ ```python
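+ # Note: this snippet is pseudocode; in the full notebook linked above, helpers such as
+ # get_model_tokenizer, get_template, load_dataset, EncodePreprocessor, Swift and
+ # Seq2SeqTrainer are imported from the ms-swift (`swift`) package.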
228
+ # Retrieve the model and template, and add a trainable LoRA module
229
+ model, tokenizer = get_model_tokenizer(model_id_or_path, ...)
230
+ template = get_template(model.model_meta.template, tokenizer, ...)
231
+ model = Swift.prepare_model(model, lora_config)
232
+
233
+ # Download and load the dataset, and encode the text into tokens
234
+ train_dataset, val_dataset = load_dataset(dataset_id_or_path, ...)
235
+ train_dataset = EncodePreprocessor(template=template)(train_dataset, num_proc=num_proc)
236
+ val_dataset = EncodePreprocessor(template=template)(val_dataset, num_proc=num_proc)
237
+
238
+ # Train the model
239
+ trainer = Seq2SeqTrainer(
240
+ model=model,
241
+ args=training_args,
242
+ data_collator=template.data_collator,
243
+ train_dataset=train_dataset,
244
+ eval_dataset=val_dataset,
245
+ template=template,
246
+ )
247
+ trainer.train()
248
+ ```
249
+ Inference:
250
+
251
+ ```python
252
+ # Perform inference using the native PyTorch engine
253
+ engine = PtEngine(model_id_or_path, adapters=[lora_checkpoint])
254
+ infer_request = InferRequest(messages=[{'role': 'user', 'content': 'who are you?'}])
255
+ request_config = RequestConfig(max_tokens=max_new_tokens, temperature=temperature)
256
+
257
+ resp_list = engine.infer([infer_request], request_config)
258
+ print(f'response: {resp_list[0].choices[0].message.content}')
259
+ ```
260
+
261
+ ## ✨ Usage
262
+ Here is a minimal example of training to deployment using ms-swift. For more details, you can check the [examples](https://github.com/modelscope/ms-swift/tree/main/examples).
263
+
264
+ - If you want to use other models or datasets (including multimodal models and datasets), you only need to modify `--model` to specify the corresponding model's ID or path, and modify `--dataset` to specify the corresponding dataset's ID or path.
265
+ - By default, ModelScope is used for downloading models and datasets. If you want to use HuggingFace, simply specify `--use_hf true`.
266
+
267
+ | Useful Links |
268
+ | ------ |
269
+ | [🔥Command Line Parameters](https://swift.readthedocs.io/en/latest/Instruction/Command-line-parameters.html) |
270
+ | [Supported Models and Datasets](https://swift.readthedocs.io/en/latest/Instruction/Supported-models-and-datasets.html) |
271
+ | [Custom Models](https://swift.readthedocs.io/en/latest/Customization/Custom-model.html), [🔥Custom Datasets](https://swift.readthedocs.io/en/latest/Customization/Custom-dataset.html) |
272
+ | [LLM Tutorial](https://github.com/modelscope/modelscope-classroom/tree/main/LLM-tutorial) |
273
+
274
+ ### Training
275
+
276
+ Supported Training Methods:
277
+
278
+ | Method | Full-Parameter | LoRA | QLoRA | Deepspeed | Multi-Node | Multi-Modal |
279
+ |------------------------------------|--------------------------------------------------------------|---------------------------------------------------------------------------------------------|--------------------------------------------------------------|--------------------------------------------------------------|--------------------------------------------------------------|----------------------------------------------------------------------------------------------|
280
+ | Pre-training | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/pretrain/train.sh) | ✅ | ✅ | ✅ | ✅ | ✅ |
281
+ | Instruction Supervised Fine-tuning | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/full/train.sh) | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/lora_sft.sh) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/qlora) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-gpu/deepspeed) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-node) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/multimodal) |
282
+ | DPO Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/dpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/dpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/rlhf/dpo.sh) |
283
+ | GRPO Training | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/internal/grpo_zero2.sh) | ✅ | ✅ | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/internal/multi_node) | ✅ |
284
+ | Reward Model Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/rm.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/rm.sh) | ✅ | ✅ |
285
+ | PPO Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/ppo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/ppo.sh) | ✅ | ❌ |
286
+ | KTO Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/kto.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/kto.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/rlhf/kto.sh) |
287
+ | CPO Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/cpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/cpo.sh) | ✅ | ✅ |
288
+ | SimPO Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/simpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/simpo.sh) | ✅ | ✅ |
289
+ | ORPO Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/orpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/orpo.sh) | ✅ | ✅ |
290
+ | Classification Model Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/seq_cls/qwen2_5/sft.sh) | ✅ | ✅ | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/seq_cls/qwen2_vl/sft.sh) |
291
+ | Embedding Model Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/embedding/train_gte.sh) | ✅ | ✅ | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/embedding/train_gme.sh) |
292
+
293
+
294
+
295
+ Pre-training:
296
+ ```shell
297
+ # 8*A100
298
+ NPROC_PER_NODE=8 \
299
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
300
+ swift pt \
301
+ --model Qwen/Qwen2.5-7B \
302
+ --dataset swift/chinese-c4 \
303
+ --streaming true \
304
+ --train_type full \
305
+ --deepspeed zero2 \
306
+ --output_dir output \
307
+ --max_steps 10000 \
308
+ ...
309
+ ```
310
+
311
+ Fine-tuning:
312
+ ```shell
313
+ CUDA_VISIBLE_DEVICES=0 swift sft \
314
+ --model Qwen/Qwen2.5-7B-Instruct \
315
+ --dataset AI-ModelScope/alpaca-gpt4-data-en \
316
+ --train_type lora \
317
+ --output_dir output \
318
+ ...
319
+ ```
320
+
321
+ RLHF:
322
+ ```shell
323
+ CUDA_VISIBLE_DEVICES=0 swift rlhf \
324
+ --rlhf_type dpo \
325
+ --model Qwen/Qwen2.5-7B-Instruct \
326
+ --dataset hjh0119/shareAI-Llama3-DPO-zh-en-emoji \
327
+ --train_type lora \
328
+ --output_dir output \
329
+ ...
330
+ ```
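+
+ GRPO training (see the table above and the GRPO documentation linked in the News section) goes through the same `swift rlhf` entry point. A minimal sketch, assuming GRPO is selected with `--rlhf_type grpo` and that the built-in `accuracy` reward function is used; the dataset and the remaining GRPO-specific arguments are omitted here:
+
+ ```shell
+ CUDA_VISIBLE_DEVICES=0 swift rlhf \
+ --rlhf_type grpo \
+ --model Qwen/Qwen2.5-7B-Instruct \
+ --reward_funcs accuracy \
+ --train_type lora \
+ --output_dir output \
+ ...
+ ```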
331
+
332
+
333
+ ### Inference
334
+ ```shell
335
+ CUDA_VISIBLE_DEVICES=0 swift infer \
336
+ --model Qwen/Qwen2.5-7B-Instruct \
337
+ --stream true \
338
+ --infer_backend pt \
339
+ --max_new_tokens 2048
340
+
341
+ # LoRA
342
+ CUDA_VISIBLE_DEVICES=0 swift infer \
343
+ --model Qwen/Qwen2.5-7B-Instruct \
344
+ --adapters swift/test_lora \
345
+ --stream true \
346
+ --infer_backend pt \
347
+ --temperature 0 \
348
+ --max_new_tokens 2048
349
+ ```
350
+
351
+ ### Interface Inference
352
+ ```shell
353
+ CUDA_VISIBLE_DEVICES=0 swift app \
354
+ --model Qwen/Qwen2.5-7B-Instruct \
355
+ --stream true \
356
+ --infer_backend pt \
357
+ --max_new_tokens 2048
358
+ ```
359
+
360
+ ### Deployment
361
+ ```shell
362
+ CUDA_VISIBLE_DEVICES=0 swift deploy \
363
+ --model Qwen/Qwen2.5-7B-Instruct \
364
+ --infer_backend vllm
365
+ ```
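+
+ The deployed service exposes an OpenAI-compatible API (see "Inference Acceleration" above), so it can be queried with any OpenAI-style client. A minimal sketch using curl, assuming the default local address and port (`127.0.0.1:8000`) and that the model is registered under its checkpoint name:
+
+ ```shell
+ curl http://127.0.0.1:8000/v1/chat/completions \
+ -H "Content-Type: application/json" \
+ -d '{"model": "Qwen2.5-7B-Instruct", "messages": [{"role": "user", "content": "who are you?"}]}'
+ ```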
366
+
367
+ ### Sampling
368
+ ```shell
369
+ CUDA_VISIBLE_DEVICES=0 swift sample \
370
+ --model LLM-Research/Meta-Llama-3.1-8B-Instruct \
371
+ --sampler_engine pt \
372
+ --num_return_sequences 5 \
373
+ --dataset AI-ModelScope/alpaca-gpt4-data-zh#5
374
+ ```
375
+
376
+ ### Evaluation
377
+ ```shell
378
+ CUDA_VISIBLE_DEVICES=0 swift eval \
379
+ --model Qwen/Qwen2.5-7B-Instruct \
380
+ --infer_backend lmdeploy \
381
+ --eval_backend OpenCompass \
382
+ --eval_dataset ARC_c
383
+ ```
384
+
385
+ ### Quantization
386
+ ```shell
387
+ CUDA_VISIBLE_DEVICES=0 swift export \
388
+ --model Qwen/Qwen2.5-7B-Instruct \
389
+ --quant_bits 4 --quant_method awq \
390
+ --dataset AI-ModelScope/alpaca-gpt4-data-zh \
391
+ --output_dir Qwen2.5-7B-Instruct-AWQ
392
+ ```
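+
+ As noted under "Model Quantization" above, the exported model can be served with vLLM/LmDeploy or trained further. A sketch of running inference on the exported directory (the local path matches the `--output_dir` above):
+
+ ```shell
+ CUDA_VISIBLE_DEVICES=0 swift infer \
+ --model Qwen2.5-7B-Instruct-AWQ \
+ --infer_backend vllm \
+ --max_new_tokens 2048
+ ```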
393
+
394
+ ### Push Model
395
+ ```shell
396
+ swift export \
397
+ --model <model-path> \
398
+ --push_to_hub true \
399
+ --hub_model_id '<model-id>' \
400
+ --hub_token '<sdk-token>'
401
+ ```
402
+
403
+ ## 🏛 License
404
+
405
+ This framework is licensed under the [Apache License (Version 2.0)](https://github.com/modelscope/modelscope/blob/master/LICENSE). For models and datasets, please refer to the original resource page and follow the corresponding License.
406
+
407
+ ## 📎 Citation
408
+
409
+ ```bibtex
410
+ @misc{zhao2024swiftascalablelightweightinfrastructure,
411
+ title={SWIFT:A Scalable lightWeight Infrastructure for Fine-Tuning},
412
+ author={Yuze Zhao and Jintao Huang and Jinghan Hu and Xingjun Wang and Yunlin Mao and Daoze Zhang and Zeyinzi Jiang and Zhikai Wu and Baole Ai and Ang Wang and Wenmeng Zhou and Yingda Chen},
413
+ year={2024},
414
+ eprint={2408.05517},
415
+ archivePrefix={arXiv},
416
+ primaryClass={cs.CL},
417
+ url={https://arxiv.org/abs/2408.05517},
418
+ }
419
+ ```
420
+
421
+ ## Star History
422
+
423
+ [![Star History Chart](https://api.star-history.com/svg?repos=modelscope/swift&type=Date)](https://star-history.com/#modelscope/ms-swift&Date)
ms-swift/.ipynb_checkpoints/README_CN-checkpoint.md ADDED
@@ -0,0 +1,413 @@
1
+ # SWIFT (Scalable lightWeight Infrastructure for Fine-Tuning)
2
+
3
+ <p align="center">
4
+ <br>
5
+ <img src="asset/banner.png"/>
6
+ <br>
7
+ <p>
8
+ <p align="center">
9
+ <a href="https://modelscope.cn/home">魔搭社区官网</a>
10
+ <br>
11
+ 中文&nbsp | &nbsp<a href="README.md">English</a>&nbsp
12
+ </p>
13
+
14
+
15
+ <p align="center">
16
+ <img src="https://img.shields.io/badge/python-3.10-5be.svg">
17
+ <img src="https://img.shields.io/badge/pytorch-%E2%89%A52.0-orange.svg">
18
+ <a href="https://github.com/modelscope/modelscope/"><img src="https://img.shields.io/badge/modelscope-%E2%89%A51.19-5D91D4.svg"></a>
19
+ <a href="https://pypi.org/project/ms-swift/"><img src="https://badge.fury.io/py/ms-swift.svg"></a>
20
+ <a href="https://github.com/modelscope/swift/blob/main/LICENSE"><img src="https://img.shields.io/github/license/modelscope/swift"></a>
21
+ <a href="https://pepy.tech/project/ms-swift"><img src="https://pepy.tech/badge/ms-swift"></a>
22
+ <a href="https://github.com/modelscope/swift/pulls"><img src="https://img.shields.io/badge/PR-welcome-55EB99.svg"></a>
23
+ </p>
24
+
25
+ <p align="center">
26
+ <a href="https://trendshift.io/repositories/6427" target="_blank"><img src="https://trendshift.io/api/badge/repositories/6427" alt="modelscope%2Fswift | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
27
+ </p>
28
+
29
+ <p align="center">
30
+ <a href="https://arxiv.org/abs/2408.05517">论文</a> &nbsp | <a href="https://swift.readthedocs.io/en/latest/">English Documentation</a> &nbsp | &nbsp <a href="https://swift.readthedocs.io/zh-cn/latest/">中文文档</a> &nbsp
31
+ </p>
32
+
33
+ ## 📖 目录
34
+ - [用户群](#-用户群)
35
+ - [简介](#-简介)
36
+ - [新闻](#-新闻)
37
+ - [安装](#%EF%B8%8F-安装)
38
+ - [快速开始](#-快速开始)
39
+ - [如何使用](#-如何使用)
40
+ - [License](#-license)
41
+ - [引用](#-引用)
42
+
43
+ ## ☎ 用户群
44
+
45
+ 请扫描下面的二维码来加入我们的交流群:
46
+
47
+ [Discord Group](https://discord.com/invite/D27yfEFVz5) | 微信群
48
+ :-------------------------:|:-------------------------:
49
+ <img src="asset/discord_qr.jpg" width="200" height="200"> | <img src="asset/wechat.png" width="200" height="200">
50
+
51
+ ## 📝 简介
52
+ 🍲 ms-swift是魔搭社区提供的大模型与多模态大模型微调部署框架,现已支持500+大模型与200+多模态大模型的训练(预训练、微调、人类对齐)、推理、评测、量化与部署。其中大模型包括:Qwen3、Qwen3-MoE、Qwen2.5、InternLM3、GLM4、Mistral、DeepSeek-R1、Yi1.5、TeleChat2、Baichuan2、Gemma2等模型,多模态大模型包括:Qwen2.5-VL、Qwen2-Audio、Llama4、Llava、InternVL2.5、MiniCPM-V-2.6、GLM4v、Xcomposer2.5、Yi-VL、DeepSeek-VL2、Phi3.5-Vision、GOT-OCR2等模型。
53
+
54
+ 🍔 除此之外,ms-swift汇集了最新的训练技术,包括LoRA、QLoRA、Llama-Pro、LongLoRA、GaLore、Q-GaLore、LoRA+、LISA、DoRA、FourierFt、ReFT、UnSloth、和Liger等轻量化训练技术,以及DPO、GRPO、RM、PPO、KTO、CPO、SimPO、ORPO等人类对齐训练方法。ms-swift支持使用vLLM和LMDeploy对推理、评测和部署模块进行加速,并支持使用GPTQ、AWQ、BNB等技术对大模型进行量化。ms-swift还提供了基于Gradio的Web-UI界面及丰富的最佳实践。
55
+
56
+ **为什么选择ms-swift?**
57
+ - 🍎 **模型类型**:支持500+纯文本大模型、**200+多模态大模型**以及All-to-All全模态模型、序列分类模型、Embedding模型**训练到部署全流程**。
58
+ - **数据集类型**:内置150+预训练、微调、人类对齐、多模态等各种类型的数据集,并支持自定义数据集。
59
+ - **硬件支持**:CPU、RTX系列、T4/V100、A10/A100/H100、Ascend NPU、MPS等。
60
+ - 🍊 **轻量训练**:支持了LoRA、QLoRA、DoRA、LoRA+、ReFT、RS-LoRA、LLaMAPro、Adapter、GaLore、Q-Galore、LISA、UnSloth、Liger-Kernel等轻量微调方式。
61
+ - **分布式训练**:支持分布式数据并行(DDP)、device_map简易模型并行、DeepSpeed ZeRO2 ZeRO3、FSDP等分布式训练技术。
62
+ - **量化训练**:支持对BNB、AWQ、GPTQ、AQLM、HQQ、EETQ量化模型进行训练。
63
+ - **RLHF训练**:支持纯文本大模型和多模态大模型的DPO、GRPO、RM、PPO、KTO、CPO、SimPO、ORPO等人类对齐训练方法。
64
+ - 🍓 **多模态训练**:支持对图像、视频和语音不同模态模型进行训练,支持VQA、Caption、OCR、Grounding任务的训练。
65
+ - **界面训练**:以界面的方式提供训练、推理、评测、量化的能力,完成大模型的全链路。
66
+ - **插件化与拓展**:支持自定义模型和数据集拓展,支持对loss、metric、trainer、loss-scale、callback、optimizer等组件进行自定义。
67
+ - 🍉 **工具箱能力**:不仅提供大模型和多模态大模型的训练支持,还涵盖其推理、评测、量化和部署全流程。
68
+ - **推理加速**:支持PyTorch、vLLM、LmDeploy推理加速引擎,并提供OpenAI接口,为推理、部署和评测模块提供加速。
69
+ - **模型评测**:以EvalScope作为评测后端,支持100+评测数据集对纯文本和多模态模型进行评测。
70
+ - **模型量化**:支持AWQ、GPTQ和BNB的量化导出,导出的模型支持使用vLLM/LmDeploy推理加速,并支持继续训练。
71
+
72
+ ## 🎉 新闻
73
+ - 🎁 2025.05.11: GRPO中的奖励模型支持自定义处理逻辑,GenRM的例子参考[这里](./docs/source/Instruction/GRPO.md#自定义奖励模型)
74
+ - 🎁 2025.04.15: ms-swift论文已经被AAAI 2025接收,论文地址在[这里](https://ojs.aaai.org/index.php/AAAI/article/view/35383)。
75
+ - 🎁 2025.03.23: 支持了多轮GRPO,用于构建多轮对话场景的训练(例如agent tool calling),请查看[训练脚本](examples/train/grpo/internal/train_multi_round.sh)。
76
+ - 🎁 2025.03.16: 支持了Megatron的并行技术进行训练,请查看[Megatron-SWIFT训练文档](https://swift.readthedocs.io/zh-cn/latest/Instruction/Megatron-SWIFT训练.html)。
77
+ - 🎁 2025.03.15: 支持纯文本和多模态模型的embedding模型的微调,请查看[训练脚本](examples/train/embedding)。
78
+ - 🎁 2025.03.05: 支持GRPO的hybrid模式,4GPU(4*80G)训练72B模型的脚本参考[这里](examples/train/grpo/internal/train_72b_4gpu.sh)。同时支持vllm的tensor并行,训练脚本参考[这里](examples/train/grpo/internal/multi_gpu_mp_colocate.sh)。
79
+ - 🎁 2025.02.21: GRPO算法支持使用LMDeploy,训练脚本参考[这里](examples/train/grpo/internal/full_lmdeploy.sh)。此外测试了GRPO算法的性能,使用一些tricks使训练速度提高到300%。WanDB表格请查看[这里](https://wandb.ai/tastelikefeet/grpo_perf_test?nw=nwuseryuzezyz)。
80
+ - 🎁 2025.02.21: 支持`swift sample`命令。强化微调脚本参考[这里](docs/source/Instruction/强化微调.md),大模型API蒸馏采样脚本参考[这里](examples/sampler/distill/distill.sh)。
81
+ - 🔥 2025.02.12: 支持GRPO (Group Relative Policy Optimization) 训练算法,文档参考[这里](docs/source/Instruction/GRPO.md)。
82
+ - 🎁 2024.12.04: **ms-swift3.0**大版本更新。请查看[发布说明和更改](https://swift.readthedocs.io/zh-cn/latest/Instruction/ReleaseNote3.0.html)。
83
+ <details><summary>更多</summary>
84
+
85
+ - 🎉 2024.08.12: ms-swift论文已经发布到arXiv上,可以点击[这里](https://arxiv.org/abs/2408.05517)阅读。
86
+ - 🔥 2024.08.05: 支持使用[evalscope](https://github.com/modelscope/evalscope/)作为后端进行大模型和多模态模型的评测。
87
+ - 🔥 2024.07.29: 支持使用[vllm](https://github.com/vllm-project/vllm), [lmdeploy](https://github.com/InternLM/lmdeploy)对大模型和多模态大模型进行推理加速,在infer/deploy/eval时额外指定`--infer_backend vllm/lmdeploy`即可。
88
+ - 🔥 2024.07.24: 支持对多模态大模型进行人类偏好对齐训练,包括DPO/ORPO/SimPO/CPO/KTO/RM/PPO。
89
+ - 🔥 2024.02.01: 支持Agent训练!训练算法源自这篇[论文](https://arxiv.org/pdf/2309.00986.pdf)。
90
+ </details>
91
+
92
+ ## 🛠️ 安装
93
+ 使用pip进行安装:
94
+ ```shell
95
+ pip install ms-swift -U
96
+ ```
97
+
98
+ 从源代码安装:
99
+ ```shell
100
+ # pip install git+https://github.com/modelscope/ms-swift.git
101
+
102
+ git clone https://github.com/modelscope/ms-swift.git
103
+ cd ms-swift
104
+ pip install -e .
105
+ ```
106
+
107
+ 运行环境:
108
+
109
+ | | 范围 | 推荐 | 备注 |
110
+ | ------ |--------------| ---- | --|
111
+ | python | >=3.9 | 3.10 ||
112
+ | cuda | | cuda12 |使用cpu、npu、mps则无需安装|
113
+ | torch | >=2.0 | ||
114
+ | transformers | >=4.33 | 4.51 ||
115
+ | modelscope | >=1.23 | ||
116
+ | peft | >=0.11,<0.16 | ||
117
+ | trl | >=0.13,<0.18 | 0.17 |RLHF|
118
+ | deepspeed | >=0.14 | 0.14.5 |训练|
119
+ | vllm | >=0.5.1 | 0.7.3/0.8 |推理/部署/评测|
120
+ | lmdeploy | >=0.5 | 0.8 |推理/部署/评测|
121
+ | evalscope | >=0.11 | |评测|
122
+
123
+ 更多可选依赖可以参考[这里](https://github.com/modelscope/ms-swift/blob/main/requirements/install_all.sh)。
124
+
125
+
126
+ ## 🚀 快速开始
127
+
128
+ **10分钟**在单卡3090上对Qwen2.5-7B-Instruct进行自我认知微调:
129
+
130
+ ### 命令行
131
+ ```shell
132
+ # 22GB
133
+ CUDA_VISIBLE_DEVICES=0 \
134
+ swift sft \
135
+ --model Qwen/Qwen2.5-7B-Instruct \
136
+ --train_type lora \
137
+ --dataset 'AI-ModelScope/alpaca-gpt4-data-zh#500' \
138
+ 'AI-ModelScope/alpaca-gpt4-data-en#500' \
139
+ 'swift/self-cognition#500' \
140
+ --torch_dtype bfloat16 \
141
+ --num_train_epochs 1 \
142
+ --per_device_train_batch_size 1 \
143
+ --per_device_eval_batch_size 1 \
144
+ --learning_rate 1e-4 \
145
+ --lora_rank 8 \
146
+ --lora_alpha 32 \
147
+ --target_modules all-linear \
148
+ --gradient_accumulation_steps 16 \
149
+ --eval_steps 50 \
150
+ --save_steps 50 \
151
+ --save_total_limit 2 \
152
+ --logging_steps 5 \
153
+ --max_length 2048 \
154
+ --output_dir output \
155
+ --system 'You are a helpful assistant.' \
156
+ --warmup_ratio 0.05 \
157
+ --dataloader_num_workers 4 \
158
+ --model_author swift \
159
+ --model_name swift-robot
160
+ ```
161
+
162
+ 小贴士:
163
+ - 如果要使用自定义数据集进行训练,你可以参考[这里](https://swift.readthedocs.io/zh-cn/latest/Customization/%E8%87%AA%E5%AE%9A%E4%B9%89%E6%95%B0%E6%8D%AE%E9%9B%86.html)组织数据集格式,并指定`--dataset <dataset_path>`。
164
+ - `--model_author`和`--model_name`参数只有当数据集中包含`swift/self-cognition`时才生效。
165
+ - 如果要使用其他模型进行训练,你只需要修改`--model <model_id/model_path>`即可。
166
+ - 默认使用ModelScope进行模型和数据集的下载。如果要使用HuggingFace,指定`--use_hf true`即可。
167
+
168
+ 训练完成后,使用以下命令对训练后的权重进行推理:
169
+ - 这里的`--adapters`需要替换成训练生成的last checkpoint文件夹。由于adapters文件夹中包含了训练的参数文件`args.json`,因此不需要额外指定`--model`,`--system`,swift会自动读取这些参数。如果要关闭此行为,可以设置`--load_args false`。
170
+
171
+ ```shell
172
+ # 使用交互式命令行进行推理
173
+ CUDA_VISIBLE_DEVICES=0 \
174
+ swift infer \
175
+ --adapters output/vx-xxx/checkpoint-xxx \
176
+ --stream true \
177
+ --temperature 0 \
178
+ --max_new_tokens 2048
179
+
180
+ # merge-lora并使用vLLM进行推理加速
181
+ CUDA_VISIBLE_DEVICES=0 \
182
+ swift infer \
183
+ --adapters output/vx-xxx/checkpoint-xxx \
184
+ --stream true \
185
+ --merge_lora true \
186
+ --infer_backend vllm \
187
+ --max_model_len 8192 \
188
+ --temperature 0 \
189
+ --max_new_tokens 2048
190
+ ```
191
+
192
+ 最后,使用以下命令将模型推送到ModelScope:
193
+ ```shell
194
+ CUDA_VISIBLE_DEVICES=0 \
195
+ swift export \
196
+ --adapters output/vx-xxx/checkpoint-xxx \
197
+ --push_to_hub true \
198
+ --hub_model_id '<your-model-id>' \
199
+ --hub_token '<your-sdk-token>' \
200
+ --use_hf false
201
+ ```
202
+
203
+ ### Web-UI
204
+
205
+ Web-UI是基于gradio界面技术的**零门槛**训练、部署界面方案,具体可以查看[这里](https://swift.readthedocs.io/zh-cn/latest/GetStarted/Web-UI.html)。
206
+
207
+ ```shell
208
+ swift web-ui
209
+ ```
210
+ ![image.png](./docs/resources/web-ui.jpg)
211
+
212
+ ### 使用Python
213
+ ms-swift也支持使用python的方式进行训练和推理。下面给出训练和推理的**伪代码**,具体可以查看[这里](https://github.com/modelscope/ms-swift/blob/main/examples/notebook/qwen2_5-self-cognition/self-cognition-sft.ipynb)。
214
+
215
+ 训练:
216
+ ```python
217
+ # 获取模型和template,并加入可训练的LoRA模块
218
+ model, tokenizer = get_model_tokenizer(model_id_or_path, ...)
219
+ template = get_template(model.model_meta.template, tokenizer, ...)
220
+ model = Swift.prepare_model(model, lora_config)
221
+
222
+ # 下载并载入数据集,并将文本encode成tokens
223
+ train_dataset, val_dataset = load_dataset(dataset_id_or_path, ...)
224
+ train_dataset = EncodePreprocessor(template=template)(train_dataset, num_proc=num_proc)
225
+ val_dataset = EncodePreprocessor(template=template)(val_dataset, num_proc=num_proc)
226
+
227
+ # 进行训练
228
+ trainer = Seq2SeqTrainer(
229
+ model=model,
230
+ args=training_args,
231
+ data_collator=template.data_collator,
232
+ train_dataset=train_dataset,
233
+ eval_dataset=val_dataset,
234
+ template=template,
235
+ )
236
+ trainer.train()
237
+ ```
238
+
239
+ 推理:
240
+ ```python
241
+ # 使用原生pytorch引擎进行推理
242
+ engine = PtEngine(model_id_or_path, adapters=[lora_checkpoint])
243
+ infer_request = InferRequest(messages=[{'role': 'user', 'content': 'who are you?'}])
244
+ request_config = RequestConfig(max_tokens=max_new_tokens, temperature=temperature)
245
+
246
+ resp_list = engine.infer([infer_request], request_config)
247
+ print(f'response: {resp_list[0].choices[0].message.content}')
248
+ ```
249
+
250
+ ## ✨ 如何使用
251
+
252
+ 这里给出使用ms-swift进行训练到部署到最简示例,具体可以查看[examples](https://github.com/modelscope/ms-swift/tree/main/examples)。
253
+
254
+ - 若想使用其他模型或者数据集(含多模态模型和数据集),你只需要修改`--model`指定对应模型的id或者path,修改`--dataset`指定对应数据集的id或者path即可。
255
+ - 默认使用ModelScope进行模型和数据集的下载。如果要使用HuggingFace,指定`--use_hf true`即可。
256
+
257
+ | 常用链接 |
258
+ | ------ |
259
+ | [🔥命令行参数](https://swift.readthedocs.io/zh-cn/latest/Instruction/%E5%91%BD%E4%BB%A4%E8%A1%8C%E5%8F%82%E6%95%B0.html) |
260
+ | [支持的模型和数据集](https://swift.readthedocs.io/zh-cn/latest/Instruction/%E6%94%AF%E6%8C%81%E7%9A%84%E6%A8%A1%E5%9E%8B%E5%92%8C%E6%95%B0%E6%8D%AE%E9%9B%86.html) |
261
+ | [自定义模型](https://swift.readthedocs.io/zh-cn/latest/Customization/%E8%87%AA%E5%AE%9A%E4%B9%89%E6%A8%A1%E5%9E%8B.html), [🔥自定义数据集](https://swift.readthedocs.io/zh-cn/latest/Customization/%E8%87%AA%E5%AE%9A%E4%B9%89%E6%95%B0%E6%8D%AE%E9%9B%86.html) |
262
+ | [大模型教程](https://github.com/modelscope/modelscope-classroom/tree/main/LLM-tutorial) |
263
+
264
+ ### 训练
265
+ 支持的训练方法:
266
+
267
+ | 方法 | 全参数 | LoRA | QLoRA | Deepspeed | 多机 | 多模态 |
268
+ | ------ | ------ |---------------------------------------------------------------------------------------------| ----- | ------ | ------ |----------------------------------------------------------------------------------------------|
269
+ | 预训练 | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/pretrain/train.sh) | ✅ | ✅ | ✅ | ✅ | ✅ |
270
+ | 指令监督微调 | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/full/train.sh) | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/lora_sft.sh) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/qlora) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-gpu/deepspeed) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-node) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/multimodal) |
271
+ | DPO训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/dpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/dpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/rlhf/dpo.sh) |
272
+ | GRPO训练 | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/internal/grpo_zero2.sh) | ✅ | ✅ | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/internal/multi_node) | ✅ |
273
+ | 奖励模型训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/rm.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/rm.sh) | ✅ | ✅ |
274
+ | PPO训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/ppo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/ppo.sh) | ✅ | ❌ |
275
+ | KTO训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/kto.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/kto.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/rlhf/kto.sh) |
276
+ | CPO训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/cpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/cpo.sh) | ✅ | ✅ |
277
+ | SimPO训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/simpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/simpo.sh) | ✅ | ✅ |
278
+ | ORPO训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/orpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/orpo.sh) | ✅ | ✅ |
279
+ | 分类模型训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/seq_cls/qwen2_5/sft.sh) | ✅ | ✅ | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/seq_cls/qwen2_vl/sft.sh) |
280
+ | Embedding模型训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/embedding/train_gte.sh) | ✅ | ✅ | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/embedding/train_gme.sh) |
281
+
282
+
283
+ 预训练:
284
+ ```shell
285
+ # 8*A100
286
+ NPROC_PER_NODE=8 \
287
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
288
+ swift pt \
289
+ --model Qwen/Qwen2.5-7B \
290
+ --dataset swift/chinese-c4 \
291
+ --streaming true \
292
+ --train_type full \
293
+ --deepspeed zero2 \
294
+ --output_dir output \
295
+ --max_steps 10000 \
296
+ ...
297
+ ```
298
+
299
+ 微调:
300
+ ```shell
301
+ CUDA_VISIBLE_DEVICES=0 swift sft \
302
+ --model Qwen/Qwen2.5-7B-Instruct \
303
+ --dataset AI-ModelScope/alpaca-gpt4-data-zh \
304
+ --train_type lora \
305
+ --output_dir output \
306
+ ...
307
+ ```
308
+
309
+ RLHF:
310
+ ```shell
311
+ CUDA_VISIBLE_DEVICES=0 swift rlhf \
312
+ --rlhf_type dpo \
313
+ --model Qwen/Qwen2.5-7B-Instruct \
314
+ --dataset hjh0119/shareAI-Llama3-DPO-zh-en-emoji \
315
+ --train_type lora \
316
+ --output_dir output \
317
+ ...
318
+ ```
319
+
320
+
321
+ ### 推理
322
+ ```shell
323
+ CUDA_VISIBLE_DEVICES=0 swift infer \
324
+ --model Qwen/Qwen2.5-7B-Instruct \
325
+ --stream true \
326
+ --infer_backend pt \
327
+ --max_new_tokens 2048
328
+
329
+ # LoRA
330
+ CUDA_VISIBLE_DEVICES=0 swift infer \
331
+ --model Qwen/Qwen2.5-7B-Instruct \
332
+ --adapters swift/test_lora \
333
+ --stream true \
334
+ --infer_backend pt \
335
+ --temperature 0 \
336
+ --max_new_tokens 2048
337
+ ```
338
+
339
+ ### 界面推理
340
+ ```shell
341
+ CUDA_VISIBLE_DEVICES=0 swift app \
342
+ --model Qwen/Qwen2.5-7B-Instruct \
343
+ --stream true \
344
+ --infer_backend pt \
345
+ --max_new_tokens 2048 \
346
+ --lang zh
347
+ ```
348
+
349
+ ### 部署
350
+ ```shell
351
+ CUDA_VISIBLE_DEVICES=0 swift deploy \
352
+ --model Qwen/Qwen2.5-7B-Instruct \
353
+ --infer_backend vllm
354
+ ```
355
+
356
+ ### 采样
357
+ ```shell
358
+ CUDA_VISIBLE_DEVICES=0 swift sample \
359
+ --model LLM-Research/Meta-Llama-3.1-8B-Instruct \
360
+ --sampler_engine pt \
361
+ --num_return_sequences 5 \
362
+ --dataset AI-ModelScope/alpaca-gpt4-data-zh#5
363
+ ```
364
+
365
+ ### 评测
366
+ ```shell
367
+ CUDA_VISIBLE_DEVICES=0 swift eval \
368
+ --model Qwen/Qwen2.5-7B-Instruct \
369
+ --infer_backend lmdeploy \
370
+ --eval_backend OpenCompass \
371
+ --eval_dataset ARC_c
372
+ ```
373
+
374
+ ### 量化
375
+ ```shell
376
+ CUDA_VISIBLE_DEVICES=0 swift export \
377
+ --model Qwen/Qwen2.5-7B-Instruct \
378
+ --quant_bits 4 --quant_method awq \
379
+ --dataset AI-ModelScope/alpaca-gpt4-data-zh \
380
+ --output_dir Qwen2.5-7B-Instruct-AWQ
381
+ ```
382
+
383
+ ### 推送模型
384
+ ```shell
385
+ swift export \
386
+ --model <model-path> \
387
+ --push_to_hub true \
388
+ --hub_model_id '<model-id>' \
389
+ --hub_token '<sdk-token>'
390
+ ```
391
+
392
+
393
+ ## 🏛 License
394
+
395
+ 本框架使用[Apache License (Version 2.0)](https://github.com/modelscope/modelscope/blob/master/LICENSE)进行许可。模型和数据集请查看原资源页面并遵守对应License。
396
+
397
+ ## 📎 引用
398
+
399
+ ```bibtex
400
+ @misc{zhao2024swiftascalablelightweightinfrastructure,
401
+ title={SWIFT:A Scalable lightWeight Infrastructure for Fine-Tuning},
402
+ author={Yuze Zhao and Jintao Huang and Jinghan Hu and Xingjun Wang and Yunlin Mao and Daoze Zhang and Zeyinzi Jiang and Zhikai Wu and Baole Ai and Ang Wang and Wenmeng Zhou and Yingda Chen},
403
+ year={2024},
404
+ eprint={2408.05517},
405
+ archivePrefix={arXiv},
406
+ primaryClass={cs.CL},
407
+ url={https://arxiv.org/abs/2408.05517},
408
+ }
409
+ ```
410
+
411
+ ## Star History
412
+
413
+ [![Star History Chart](https://api.star-history.com/svg?repos=modelscope/swift&type=Date)](https://star-history.com/#modelscope/ms-swift&Date)
ms-swift/.ipynb_checkpoints/dataset-checkpoint.json ADDED
@@ -0,0 +1,60 @@
1
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
2
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
3
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
4
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
5
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
6
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
7
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
8
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
9
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
10
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
11
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
12
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
13
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
14
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
15
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
16
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
17
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
18
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
19
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
20
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
21
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
22
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
23
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
24
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
25
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
26
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
27
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
28
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
29
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
30
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
31
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
32
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
33
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
34
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
35
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
36
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
37
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
38
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
39
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
40
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
41
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
42
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
43
+ {"messages": [{"role": "user", "content": "<audio>语音��了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
44
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
45
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
46
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
47
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
48
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
49
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
50
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
51
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
52
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
53
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
54
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
55
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
56
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
57
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
58
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
59
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
60
+ {"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
ms-swift/.ipynb_checkpoints/dataset_overlap5s716_gemini-checkpoint.json ADDED
The diff for this file is too large to render. See raw diff
 
ms-swift/.ipynb_checkpoints/gen_data-checkpoint.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+
4
+ def get_prompt_for_file(filename):
5
+ if 'isoverlap' in filename:
6
+ return overlap_prompt
7
+ # elif 'issilence' in filename:
8
+ # return silence_prompt
9
+ # elif 'speaker_segments' in filename:
10
+ # return speaker_prompt
11
+ # elif 'transcriptions' in filename:
12
+ # return transcript_prompt
13
+ # else:
14
+ # raise ValueError(f"No matching prompt found for {filename}")
15
+ # return None
16
+
17
+ output_path = "/root/ms-swift/dataset_Overlap2.json"
18
+
19
+ # with open(input_path, "r") as fin:
20
+ # input_data = json.load(fin)
21
+
22
+ www = "hello"
23
+
24
+ www = (
25
+ "# Dialogue Response Evaluation\n\n"
26
+ "**IMPORTANT:** Evaluation must include`<score>` rating.\n\n"
27
+ "Listen to the dialogue recording (two sentences, 1-second pause in between). Evaluate the quality of the **second sentence** as a response to the first, focusing on **text relevance** and the **appropriateness** of **Linguistic information (a range of paralinguistic information such as emotion/age/pitch/speed/volume)**.\n"
28
+ "**Note:** Focus on evaluating the appropriateness of the second sentence relative to the first, even if the first sentence itself contains contradictory information.\n\n"
29
+ "## Scoring Criteria\n\n"
30
+ "**1 points**: Text content is irrelevant or incorrect or illogical.(low intelligence)\n"
31
+ "**3 points**: Text is relevant, but paralinguistic information is **inappropriate** for the context.(low emotional quotient)\n"
32
+ "**5 points**: Text is relevant, and paralinguistic information is **appropriate** for the context, resulting in effective communication.(High intelligence and emotional intelligence.)\n\n"
33
+ "## Evaluation Requirements\n\n"
34
+ "Response **MUST** follow this format:\n\n"
35
+ "<score>X</score> (**X is 1, 3, or 5**)\n\n")
36
+
37
+ # www = (
38
+ # "# Interactional Dialogue Evaluation\n\n"
39
+ # "**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\n"
40
+ # "Listen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n"
41
+ # "**Response Relevance:** \n"
42
+ # "**logical consistency, topic coherence**\n"
43
+ # "**Interactional Fluency:**\n"
44
+ # "**Strictly detect dual-tracked vocal overlap >3s (cross-channel analysis)**\n"
45
+ # "**Pauses >5s between turns (must evaluate) \n\n**"
46
+ # "**Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n"
47
+ # "## Scoring Criteria\n"
48
+ # "Assign a single holistic score based on the combined evaluation:\n"
49
+ # "`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n"
50
+ # "`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n"
51
+ # "## Evaluation Output Format:\n"
52
+ # "Strictly follow this template:\n"
53
+ # "<response think>\n"
54
+ # "[Analysing Response Relevance and giving reasons for scoring...]\n"
55
+ # "</response think>\n"
56
+ # "<fluency think>\n"
57
+ # "[Analysing Interactional Fluency and giving reasons for scoring.]\n"
58
+ # "</fluency think>\n"
59
+ # "<overall score>X</overall score>\n"
60
+
61
+ # )
62
+ # www = (
63
+ # "# Interactional Dialogue Evaluation\n\n"
64
+ # "**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\n"
65
+ # "Listen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n"
66
+ # "**Response Relevance:** \n"
67
+ # "**logical consistency, topic coherence**\n"
68
+ # "**Interactional Fluency:**\n"
69
+ # "**Strictly detect dual-tracked vocal overlap >3s (cross-channel analysis)**\n"
70
+ # "**Pauses >5s between turns (must evaluate) \n\n**"
71
+ # "**Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n"
72
+ # "## Scoring Criteria\n"
73
+ # "Assign a single holistic score based on the combined evaluation:\n"
74
+ # "`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n"
75
+ # "`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n"
76
+ # "## Evaluation Output Format:\n"
77
+ # "Strictly follow this template:\n"
78
+ # "<response think>\n"
79
+ # "[Analysing Response Relevance and giving reasons for scoring...]\n"
80
+ # "</response think>\n"
81
+ # "<fluency think>\n"
82
+ # "[Analysing Interactional Fluency and giving reasons for scoring.]\n"
83
+ # "</fluency think>\n"
84
+ # "<overall score>X</overall score>\n"
85
+
86
+ # )
87
+ overlap_prompt = (
88
+ "Analyze the dual-channel audio and identify segments where multiple speakers are talking simultaneously for more than 3 seconds. \n"
89
+ "Simply tell me when the overlap starts and ends in MM:SS format. \n"
90
+ "Just one simple sentence about the overlap timing. Keep the word count within 40 words."
91
+ )
92
+
93
+ silence_prompt = (
94
+ "Analyze the dual-channel audio and identify segments where multiple speakers are silent for more than 3 seconds. \n"
95
+ "Simply tell me when the silence starts and ends in MM:SS format. \n"
96
+ "Just one simple sentence about the silence timing. Keep the word count within 40 words."
97
+ )
98
+
99
+ speaker_prompt = (
100
+ "Analyze the dual-channel audio and detect individual speakers. \n"
101
+ "List the speaking segments for each speaker in MM:SS-MM:SS format. \n"
102
+ "Only output speaker labels and time segments in a similar format. Do not include any explanation.\n"
103
+ "Format the output like this example: \n"
104
+ "Speaker A: 00:00-00:13, 00:15-00:27, 00:33-00:37\n"
105
+ "Speaker B: 00:04-00:14, 00:27-00:32, 00:38-00:39 \n"
106
+ )
107
+
108
+ transcript_prompt = (
109
+ "Analyze the dual-channel audio and transcript each speaker's sentences with timestamps. \n"
110
+ "List the speaking segments and transcript text for each speaker in MM:SS-MM:SS format. \n"
111
+ "Only output time segments, speaker labels, and transcript text in a similar format. Do not include any explanation.\n"
112
+ "Format the output like this example: \n"
113
+ "[00:00 - 00:13] Speaker A: transcript text \n"
114
+ "[00:15 - 00:27] Speaker B: transcript text \n"
115
+ )
116
+
117
+ # Process files in the silence_overlaps directory
118
+ input_dir = "/root/ms-swift/silence_overlaps/only_overlap"
119
+ all_data = []
120
+
121
+ # Process each file
122
+ for filename in os.listdir(input_dir):
123
+ input_path = os.path.join(input_dir, filename)
124
+
125
+ # Get the appropriate prompt for this file
126
+ prompt = get_prompt_for_file(filename)
127
+ if prompt is None:
128
+ print(f"Skipping {filename} - no matching prompt found")
129
+ continue
130
+
131
+ # Read input data
132
+ with open(input_path, "r") as fin:
133
+ input_data = json.load(fin)
134
+
135
+ # Process each item
136
+ for item in input_data:
137
+ data = {
138
+ "messages": [
139
+ {"role": "user",
140
+ "content": f"<audio>{prompt}"
141
+ },
142
+ {"role": "assistant", "content": item["model_output"]}
143
+ ],
144
+ "audios": [
145
+ item["audio_url"]
146
+ ]
147
+ }
148
+ all_data.append(data)
149
+
150
+ # Write all processed data to a single output file
151
+ with open(output_path, "w", encoding="utf-8") as fout:
152
+ for data in all_data:
153
+ json.dump(data, fout, ensure_ascii=False)
154
+ fout.write('\n')
ms-swift/.ipynb_checkpoints/overlap5s716_gemini-checkpoint.json ADDED
The diff for this file is too large to render. See raw diff
 
ms-swift/.ipynb_checkpoints/setup-checkpoint.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ # !/usr/bin/env python
3
+ import os
4
+ from setuptools import find_packages, setup
5
+ from typing import List
6
+
7
+
8
+ def readme():
9
+ with open('README.md', encoding='utf-8') as f:
10
+ content = f.read()
11
+ return content
12
+
13
+
14
+ version_file = 'swift/version.py'
15
+
16
+
17
+ def get_version():
18
+ with open(version_file, 'r', encoding='utf-8') as f:
19
+ exec(compile(f.read(), version_file, 'exec'))
20
+ return locals()['__version__']
21
+
22
+
23
+ def parse_requirements(fname='requirements.txt', with_version=True):
24
+ """
25
+ Parse the package dependencies listed in a requirements file but strips
26
+ specific versioning information.
27
+
28
+ Args:
29
+ fname (str): path to requirements file
30
+ with_version (bool, default=False): if True include version specs
31
+
32
+ Returns:
33
+ List[str]: list of requirements items
34
+
35
+ CommandLine:
36
+ python -c "import setup; print(setup.parse_requirements())"
37
+ """
38
+ import re
39
+ import sys
40
+ from os.path import exists
41
+ require_fpath = fname
42
+
43
+ def parse_line(line):
44
+ """
45
+ Parse information from a line in a requirements text file
46
+ """
47
+ if line.startswith('-r '):
48
+ # Allow specifying requirements in other files
49
+ target = line.split(' ')[1]
50
+ relative_base = os.path.dirname(fname)
51
+ absolute_target = os.path.join(relative_base, target)
52
+ for info in parse_require_file(absolute_target):
53
+ yield info
54
+ else:
55
+ info = {'line': line}
56
+ if line.startswith('-e '):
57
+ info['package'] = line.split('#egg=')[1]
58
+ else:
59
+ # Remove versioning from the package
60
+ pat = '(' + '|'.join(['>=', '==', '>']) + ')'
61
+ parts = re.split(pat, line, maxsplit=1)
62
+ parts = [p.strip() for p in parts]
63
+
64
+ info['package'] = parts[0]
65
+ if len(parts) > 1:
66
+ op, rest = parts[1:]
67
+ if ';' in rest:
68
+ # Handle platform specific dependencies
69
+ # http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies
70
+ version, platform_deps = map(str.strip, rest.split(';'))
71
+ info['platform_deps'] = platform_deps
72
+ else:
73
+ version = rest # NOQA
74
+ info['version'] = (op, version)
75
+ yield info
76
+
77
+ def parse_require_file(fpath):
78
+ with open(fpath, 'r', encoding='utf-8') as f:
79
+ for line in f.readlines():
80
+ line = line.strip()
81
+ if line.startswith('http'):
82
+ print('skip http requirements %s' % line)
83
+ continue
84
+ if line and not line.startswith('#') and not line.startswith('--'):
85
+ for info in parse_line(line):
86
+ yield info
87
+ elif line and line.startswith('--find-links'):
88
+ eles = line.split()
89
+ for e in eles:
90
+ e = e.strip()
91
+ if 'http' in e:
92
+ info = dict(dependency_links=e)
93
+ yield info
94
+
95
+ def gen_packages_items():
96
+ items = []
97
+ deps_link = []
98
+ if exists(require_fpath):
99
+ for info in parse_require_file(require_fpath):
100
+ if 'dependency_links' not in info:
101
+ parts = [info['package']]
102
+ if with_version and 'version' in info:
103
+ parts.extend(info['version'])
104
+ if not sys.version.startswith('3.4'):
105
+ # apparently package_deps are broken in 3.4
106
+ platform_deps = info.get('platform_deps')
107
+ if platform_deps is not None:
108
+ parts.append(';' + platform_deps)
109
+ item = ''.join(parts)
110
+ items.append(item)
111
+ else:
112
+ deps_link.append(info['dependency_links'])
113
+ return items, deps_link
114
+
115
+ return gen_packages_items()
116
+
117
+
118
+ if __name__ == '__main__':
119
+ install_requires, deps_link = parse_requirements('requirements.txt')
120
+ extra_requires = {}
121
+ all_requires = []
122
+ extra_requires['eval'], _ = parse_requirements('requirements/eval.txt')
123
+ extra_requires['swanlab'], _ = parse_requirements('requirements/swanlab.txt')
124
+ extra_requires['seq_parallel'], _ = parse_requirements('requirements/seq_parallel.txt')
125
+ all_requires.extend(install_requires)
126
+ all_requires.extend(extra_requires['eval'])
127
+ all_requires.extend(extra_requires['seq_parallel'])
128
+ all_requires.extend(extra_requires['swanlab'])
129
+ extra_requires['all'] = all_requires
130
+
131
+ setup(
132
+ name='ms_swift',
133
+ version=get_version(),
134
+ description='Swift: Scalable lightWeight Infrastructure for Fine-Tuning',
135
+ long_description=readme(),
136
+ long_description_content_type='text/markdown',
137
+ author='DAMO ModelScope teams',
138
+ author_email='contact@modelscope.cn',
139
+ keywords='python, petl, efficient tuners',
140
+ url='https://github.com/modelscope/swift',
141
+ packages=find_packages(exclude=('configs', 'demo')),
142
+ include_package_data=True,
143
+ package_data={
144
+ '': ['*.h', '*.cpp', '*.cu'],
145
+ },
146
+ classifiers=[
147
+ 'Development Status :: 4 - Beta',
148
+ 'License :: OSI Approved :: Apache Software License',
149
+ 'Operating System :: OS Independent',
150
+ 'Programming Language :: Python :: 3',
151
+ 'Programming Language :: Python :: 3.8',
152
+ 'Programming Language :: Python :: 3.9',
153
+ 'Programming Language :: Python :: 3.10',
154
+ 'Programming Language :: Python :: 3.11',
155
+ 'Programming Language :: Python :: 3.12',
156
+ ],
157
+ license='Apache License 2.0',
158
+ tests_require=parse_requirements('requirements/tests.txt'),
159
+ install_requires=install_requires,
160
+ extras_require=extra_requires,
161
+ entry_points={
162
+ 'console_scripts': ['swift=swift.cli.main:cli_main', 'megatron=swift.cli._megatron.main:cli_main']
163
+ },
164
+ dependency_links=deps_link,
165
+ zip_safe=False)
ms-swift/.ipynb_checkpoints/test-checkpoint.sh ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ CUDA_VISIBLE_DEVICES=0 \
2
+ swift infer \
3
+ --adapters /root/autodl-tmp/output_7B_SFT/v0-20250605-155458/checkpoint-1095 \
4
+ --stream true \
5
+ --temperature 0 \
6
+ --max_new_tokens 2048
ms-swift/.ipynb_checkpoints/train-checkpoint.sh ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ CUDA_VISIBLE_DEVICES=0 swift sft \
3
+ --model /root/autodl-tmp/Qwen2.5-Omni-7B \
4
+ --dataset /root/ms-swift/dataset_Overlap2.json \
5
+ --train_type full \
6
+ --output_dir /root/autodl-tmp/output_7B_SFT \
7
+ --torch_dtype bfloat16 \
8
+ --num_train_epochs 3 \
9
+ --per_device_train_batch_size 1 \
10
+ --per_device_eval_batch_size 1 \
11
+ # ...
12
+
13
+ # # 8*A100
14
+ # NPROC_PER_NODE=8 \
15
+ # CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
16
+ # swift pt \
17
+ # --model Qwen/Qwen2.5-7B \
18
+ # --dataset swift/chinese-c4 \
19
+ # --streaming true \
20
+ # --train_type full \
21
+ # --deepspeed zero2 \
22
+ # --output_dir output \
23
+ # --max_steps 10000 \
24
+ # ...
25
+
26
+
27
+
28
+ # --lora_rank 8 \
29
+ # --lora_alpha 32 \
30
+ # --target_modules all-linear \
31
+ # --gradient_accumulation_steps 16 \
32
+ # --eval_steps 50 \
33
+ # --save_steps 50 \
34
+ # --save_total_limit 2 \
35
+ # --logging_steps 5 \
36
+ # --max_length 2048 \
37
+ # --output_dir output \
38
+ # --system 'You are a helpful assistant.' \
39
+ # --warmup_ratio 0.05 \
40
+ # --dataloader_num_workers 4 \
41
+ # --model_author swift \
42
+ # --model_name swift-robot
ms-swift/asset/banner.png ADDED

Git LFS Details

  • SHA256: aed7b0ac0bbb353df62f86b80e26eeab10fc69a0a49de161c544b51ab4ea9bea
  • Pointer size: 131 Bytes
  • Size of remote file: 381 kB
ms-swift/docs/resources/dpo_data.png ADDED

Git LFS Details

  • SHA256: 8d87ca58f3ac43a79836ba40f5d9cb788b8fabc26d717fb4e8eb1f1400f6598d
  • Pointer size: 131 Bytes
  • Size of remote file: 355 kB
ms-swift/docs/resources/grpo_clevr_count.png ADDED

Git LFS Details

  • SHA256: 7192dc4f04801dbdff30bed098a16a7e21212a773ba7b6dc1424b261feca366f
  • Pointer size: 131 Bytes
  • Size of remote file: 671 kB
ms-swift/docs/resources/grpo_code.png ADDED

Git LFS Details

  • SHA256: 5f396d9ce5ce9de323d7a6ffa8d53f783d938a242088b191f46a293268193b64
  • Pointer size: 131 Bytes
  • Size of remote file: 294 kB
ms-swift/docs/resources/grpo_countdown.png ADDED

Git LFS Details

  • SHA256: 1b55fe6864e0c92549940d6989d92b3ab22be38a035cff3694525252737fc91e
  • Pointer size: 132 Bytes
  • Size of remote file: 2.23 MB
ms-swift/docs/resources/grpo_countdown_1.png ADDED

Git LFS Details

  • SHA256: b78dc3ce1cd541e76f2c557dea3aff06b278bb3b5413946a92c584cf42c1369f
  • Pointer size: 131 Bytes
  • Size of remote file: 785 kB
ms-swift/docs/resources/grpo_geoqa.png ADDED

Git LFS Details

  • SHA256: 71246376b16f2ff288542dca2ff31532b16ef99f5e862797463d548e447e1f8d
  • Pointer size: 132 Bytes
  • Size of remote file: 2.24 MB
ms-swift/docs/resources/grpo_openr1_multimodal.png ADDED

Git LFS Details

  • SHA256: 050f56792468a4c9797a90314e322c16dd916bde3be24a7ce7c7b96381e70d9e
  • Pointer size: 132 Bytes
  • Size of remote file: 2.3 MB
ms-swift/docs/resources/kto_data.png ADDED

Git LFS Details

  • SHA256: becb20db8a8b890718f3f9e1752dc6669a018ed19a344f155fbe5123521aff49
  • Pointer size: 131 Bytes
  • Size of remote file: 164 kB
ms-swift/docs/resources/web-ui-en.jpg ADDED

Git LFS Details

  • SHA256: 6ad958b680650b4f900f1a99560e825975f1a5807702c6fbbaef8e4f0e2a5efe
  • Pointer size: 131 Bytes
  • Size of remote file: 175 kB
ms-swift/docs/resources/web-ui.jpg ADDED

Git LFS Details

  • SHA256: 5e83bb4b4ecda9386286b99c6e83551f1dd1fdcdaf7be2efa3117208e3806000
  • Pointer size: 131 Bytes
  • Size of remote file: 182 kB
ms-swift/silence_overlaps.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:28ae3d5b7926569fa96005246be5397fb0dc10bf23281006fc1791dac1771d5e
3
+ size 645024
ms-swift/silence_overlaps/.ipynb_checkpoints/clean_wrong-checkpoint.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import re
3
+ from typing import Optional, Tuple
4
+
5
+ def parse_timestamp(timestamp: str) -> int:
6
+ """Convert timestamp string like '00:15' to seconds."""
7
+ minutes, seconds = map(int, timestamp.split(':'))
8
+ return minutes * 60 + seconds
9
+
10
+ def extract_time_range(entry: str) -> Optional[Tuple[int, int]]:
11
+ """Extract start and end times from an entry like '[00:00 - 00:13]'."""
12
+ match = re.match(r'\[(\d{2}:\d{2}) - (\d{2}:\d{2})\]', entry)
13
+ if not match:
14
+ return None
15
+ start_time = parse_timestamp(match.group(1))
16
+ end_time = parse_timestamp(match.group(2))
17
+ return (start_time, end_time)
18
+
19
+ def has_overlap(range1: Tuple[int, int], range2: Tuple[int, int]) -> bool:
20
+ """Check if two time ranges overlap."""
21
+ start1, end1 = range1
22
+ start2, end2 = range2
23
+ return not (end1 <= start2 or end2 <= start1)
24
+
25
+ def clean_transcript(transcript: str) -> str:
26
+ """Clean a single transcript by removing overlapping segments."""
27
+ lines = transcript.split('\n')
28
+ cleaned_lines = []
29
+ time_ranges = []
30
+
31
+ for line in lines:
32
+ if not line.strip():
33
+ continue
34
+
35
+ time_range = extract_time_range(line)
36
+ if time_range is None:
37
+ continue
38
+
39
+ # Check for overlaps with existing ranges
40
+ has_conflict = False
41
+ for existing_range in time_ranges:
42
+ if has_overlap(time_range, existing_range):
43
+ has_conflict = True
44
+ break
45
+
46
+ if not has_conflict:
47
+ time_ranges.append(time_range)
48
+ cleaned_lines.append(line)
49
+
50
+ return '\n'.join(cleaned_lines)
51
+
52
+ def process_file(input_file: str, output_file: str):
53
+ """Process the JSON file and clean overlapping transcriptions."""
54
+ with open(input_file, 'r', encoding='utf-8') as f:
55
+ data = json.load(f)
56
+
57
+ if isinstance(data, dict):
58
+ data = [data]
59
+
60
+ cleaned_data = []
61
+ for entry in data:
62
+ if 'model_output' in entry:
63
+ entry['model_output'] = clean_transcript(entry['model_output'])
64
+ cleaned_data.append(entry)
65
+
66
+ with open(output_file, 'w', encoding='utf-8') as f:
67
+ json.dump(cleaned_data, f, ensure_ascii=False, indent=2)
68
+
69
+ if __name__ == '__main__':
70
+ input_file = 'silence_overlaps/overlap5s_transcriptions.json'
71
+ output_file = 'silence_overlaps/cleaned_transcriptions.json'
72
+ process_file(input_file, output_file)
73
+ print(f"Cleaned transcriptions have been saved to {output_file}")
ms-swift/silence_overlaps/.ipynb_checkpoints/cleaned_transcriptions-checkpoint.json ADDED
The diff for this file is too large to render. See raw diff
 
ms-swift/silence_overlaps/.ipynb_checkpoints/delete_transcript-checkpoint.json ADDED
The diff for this file is too large to render. See raw diff
 
ms-swift/silence_overlaps/.ipynb_checkpoints/delete_transcript2-checkpoint.json ADDED
@@ -0,0 +1 @@
 
 
1
+ []
ms-swift/silence_overlaps/700/merge_and_shuffle_json.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import random
4
+ from typing import Dict, List, Any
5
+
6
+ def load_json_files() -> List[Dict[str, Any]]:
7
+ """加载当前目录下所有JSON文件的内容"""
8
+ json_data = []
9
+ for filename in os.listdir('.'):
10
+ if filename.endswith('.json'):
11
+ try:
12
+ with open(filename, 'r', encoding='utf-8') as f:
13
+ data = json.load(f)
14
+ json_data.append(data)
15
+ print(f"已加载文件: {filename}")
16
+ except Exception as e:
17
+ print(f"加载文件 {filename} 时出错: {e}")
18
+ return json_data
19
+
20
+ def merge_and_shuffle(json_data: List[Dict[str, Any]]) -> Dict[str, Any]:
21
+ """合并所有JSON数据并随机打乱条目"""
22
+ merged_data = {}
23
+ # Merge all JSON data
24
+ for data in json_data:
25
+ merged_data.update(data)
26
+
27
+ # Extract all entries and shuffle their order
28
+ items = list(merged_data.items())
29
+ random.shuffle(items)
30
+
31
+ # Rebuild the dict in the shuffled order
32
+ shuffled_data = {}
33
+ for key, value in items:
34
+ shuffled_data[key] = value
35
+
36
+ return shuffled_data
37
+
38
+ def save_shuffled_data(shuffled_data: Dict[str, Any], output_file: str = 'merged_shuffled.json') -> None:
39
+ """将打乱后的数据保存到JSON文件"""
40
+ with open(output_file, 'w', encoding='utf-8') as f:
41
+ json.dump(shuffled_data, f, ensure_ascii=False, indent=2)
42
+ print(f"已保存到文件: {output_file}")
43
+
44
+ def main():
45
+ # Set the random seed so results are reproducible
46
+ random.seed(42)
47
+
48
+ # Load the JSON files
49
+ json_data = load_json_files()
50
+ if not json_data:
51
+ print("未找到JSON文件!")
52
+ return
53
+
54
+ # Merge and shuffle the data
55
+ shuffled_data = merge_and_shuffle(json_data)
56
+
57
+ # Save the result
58
+ save_shuffled_data(shuffled_data)
59
+
60
+ if __name__ == "__main__":
61
+ main()
ms-swift/silence_overlaps/700/split_train_test.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import random
4
+ from glob import glob
5
+
6
+ random.seed(42) # ensure reproducibility
7
+
8
+ json_files = glob(os.path.join(os.path.dirname(__file__), '*.json'))
9
+
10
+ for file in json_files:
11
+ base = os.path.basename(file)
12
+ if base.startswith('split_train_test') or base.endswith('_train.json') or base.endswith('_test.json'):
13
+ continue
14
+ with open(file, 'r', encoding='utf-8') as f:
15
+ try:
16
+ data = json.load(f)
17
+ except Exception as e:
18
+ print(f"Error reading {file}: {e}")
19
+ continue
20
+
21
+ # 将数据转换为列表格式
22
+ data_list = data if isinstance(data, list) else list(data.values())
23
+ random.shuffle(data_list)
24
+
25
+ # 选择5条数据作为测试集
26
+ test_data = data_list[:5]
27
+ train_data = data_list[5:]
28
+
29
+ train_file = os.path.splitext(file)[0] + '_train.json'
30
+ test_file = os.path.splitext(file)[0] + '_test.json'
31
+
32
+ with open(train_file, 'w', encoding='utf-8') as f:
33
+ json.dump(train_data, f, ensure_ascii=False, indent=2)
34
+ with open(test_file, 'w', encoding='utf-8') as f:
35
+ json.dump(test_data, f, ensure_ascii=False, indent=2)
36
+ print(f"{base} 已分为 {train_file} (训练集: {len(train_data)}条) 和 {test_file} (测试集: {len(test_data)}条)")
ms-swift/silence_overlaps/700/train/overlap5s_isoverlap_train.json ADDED
The diff for this file is too large to render. See raw diff
 
ms-swift/silence_overlaps/700/train/overlap5s_speaker_segments_train.json ADDED
The diff for this file is too large to render. See raw diff
 
ms-swift/silence_overlaps/clean_wrong.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import re
3
+ from typing import Optional, Tuple
4
+
5
+ def parse_timestamp(timestamp: str) -> int:
6
+ """Convert timestamp string like '00:15' to seconds."""
7
+ minutes, seconds = map(int, timestamp.split(':'))
8
+ return minutes * 60 + seconds
9
+
10
+ def extract_time_range(entry: str) -> Optional[Tuple[int, int]]:
11
+ """Extract start and end times from an entry like '[00:00 - 00:13]'."""
12
+ match = re.match(r'\[(\d{2}:\d{2}) - (\d{2}:\d{2})\]', entry)
13
+ if not match:
14
+ return None
15
+ start_time = parse_timestamp(match.group(1))
16
+ end_time = parse_timestamp(match.group(2))
17
+ return (start_time, end_time)
18
+
19
+ def has_overlap(range1: Tuple[int, int], range2: Tuple[int, int]) -> bool:
20
+ """Check if two time ranges overlap."""
21
+ start1, end1 = range1
22
+ start2, end2 = range2
23
+ return not (end1 <= start2 or end2 <= start1)
24
+
25
+ def clean_transcript(transcript: str) -> str:
26
+ """Clean a single transcript by removing overlapping segments."""
27
+ lines = transcript.split('\n')
28
+ cleaned_lines = []
29
+ time_ranges = []
30
+
31
+ for line in lines:
32
+ if not line.strip():
33
+ continue
34
+
35
+ time_range = extract_time_range(line)
36
+ if time_range is None:
37
+ continue
38
+
39
+ # Check for overlaps with existing ranges
40
+ has_conflict = False
41
+ for existing_range in time_ranges:
42
+ if has_overlap(time_range, existing_range):
43
+ has_conflict = True
44
+ break
45
+
46
+ if not has_conflict:
47
+ time_ranges.append(time_range)
48
+ cleaned_lines.append(line)
49
+
50
+ return '\n'.join(cleaned_lines)
51
+
52
+ def process_file(input_file: str, output_file: str):
53
+ """Process the JSON file and clean overlapping transcriptions."""
54
+ with open(input_file, 'r', encoding='utf-8') as f:
55
+ data = json.load(f)
56
+
57
+ if isinstance(data, dict):
58
+ data = [data]
59
+
60
+ cleaned_data = []
61
+ for entry in data:
62
+ if 'model_output' in entry:
63
+ entry['model_output'] = clean_transcript(entry['model_output'])
64
+ cleaned_data.append(entry)
65
+
66
+ with open(output_file, 'w', encoding='utf-8') as f:
67
+ json.dump(cleaned_data, f, ensure_ascii=False, indent=2)
68
+
69
+ if __name__ == '__main__':
70
+ input_file = 'silence_overlaps/overlap5s_transcriptions.json'
71
+ output_file = 'silence_overlaps/cleaned_transcriptions.json'
72
+ process_file(input_file, output_file)
73
+ print(f"Cleaned transcriptions have been saved to {output_file}")
ms-swift/silence_overlaps/cleaned_transcriptions.json ADDED
The diff for this file is too large to render. See raw diff
 
ms-swift/silence_overlaps/overlap5s_isoverlap.json ADDED
The diff for this file is too large to render. See raw diff
 
ms-swift/silence_overlaps/overlap5s_speaker_segments.json ADDED
The diff for this file is too large to render. See raw diff
 
ms-swift/silence_overlaps/overlap5s_transcriptions.json ADDED
The diff for this file is too large to render. See raw diff
 
ms-swift/silence_overlaps/silence_isoverlaps.json ADDED
The diff for this file is too large to render. See raw diff
 
ms-swift/silence_overlaps/silence_issilence.json ADDED
The diff for this file is too large to render. See raw diff
 
ms-swift/silence_overlaps/transcriptions.json ADDED
The diff for this file is too large to render. See raw diff
 
ms-swift/swift/ui/llm_train/utils.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import asyncio
3
+ import sys
4
+ from asyncio.subprocess import PIPE, STDOUT
5
+
6
+
7
+ async def run_and_get_log(*args, timeout=None):
8
+ process = await asyncio.create_subprocess_exec(*args, stdout=PIPE, stderr=STDOUT)
9
+ lines = []
10
+ while True:
11
+ try:
12
+ line = await asyncio.wait_for(process.stdout.readline(), timeout)
13
+ except asyncio.TimeoutError:
14
+ break
15
+ else:
16
+ if not line:
17
+ break
18
+ else:
19
+ lines.append(str(line))
20
+ return process, lines
21
+
22
+
23
+ def run_command_in_subprocess(*args, timeout):
24
+ if sys.platform == 'win32':
25
+ loop = asyncio.ProactorEventLoop()
26
+ asyncio.set_event_loop(loop)
27
+ else:
28
+ loop = asyncio.new_event_loop()
29
+ asyncio.set_event_loop(loop)
30
+ process, lines = loop.run_until_complete(run_and_get_log(*args, timeout=timeout))
31
+ return (loop, process), lines
32
+
33
+
34
+ def close_loop(handler):
35
+ loop, process = handler
36
+ process.kill()
37
+ loop.close()
ms-swift/swift/utils/__pycache__/logger.cpython-310.pyc ADDED
Binary file (3.08 kB). View file
 
ms-swift/swift/utils/__pycache__/torch_utils.cpython-310.pyc ADDED
Binary file (12.5 kB). View file
 
ms-swift/swift/utils/__pycache__/utils.cpython-310.pyc ADDED
Binary file (8.58 kB). View file
 
ms-swift/swift/utils/utils.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import datetime as dt
3
+ import fnmatch
4
+ import glob
5
+ import importlib
6
+ import os
7
+ import random
8
+ import re
9
+ import shutil
10
+ import socket
11
+ import subprocess
12
+ import sys
13
+ import time
14
+ from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar
15
+
16
+ import numpy as np
17
+ import torch
18
+ import torch.distributed as dist
19
+ from transformers import HfArgumentParser, enable_full_determinism, set_seed
20
+ from transformers.utils import strtobool
21
+
22
+ from .env import is_dist, is_dist_ta
23
+ from .logger import get_logger
24
+ from .np_utils import stat_array
25
+
26
+ logger = get_logger()
27
+
28
+
29
+ def check_json_format(obj: Any, token_safe: bool = True) -> Any:
30
+ if obj is None or isinstance(obj, (int, float, str, complex)): # bool is a subclass of int
31
+ return obj
32
+ if isinstance(obj, bytes):
33
+ return '<<<bytes>>>'
34
+ if isinstance(obj, (torch.dtype, torch.device)):
35
+ obj = str(obj)
36
+ return obj[len('torch.'):] if obj.startswith('torch.') else obj
37
+
38
+ if isinstance(obj, Sequence):
39
+ res = []
40
+ for x in obj:
41
+ res.append(check_json_format(x, token_safe))
42
+ elif isinstance(obj, Mapping):
43
+ res = {}
44
+ for k, v in obj.items():
45
+ if token_safe and isinstance(k, str) and '_token' in k and isinstance(v, str):
46
+ res[k] = None
47
+ else:
48
+ res[k] = check_json_format(v, token_safe)
49
+ else:
50
+ if token_safe:
51
+ unsafe_items = {}
52
+ for k, v in obj.__dict__.items():
53
+ if '_token' in k:
54
+ unsafe_items[k] = v
55
+ setattr(obj, k, None)
56
+ res = repr(obj)
57
+ # recover
58
+ for k, v in unsafe_items.items():
59
+ setattr(obj, k, v)
60
+ else:
61
+ res = repr(obj) # e.g. function, object
62
+ return res
63
+
64
+
65
+ def _get_version(work_dir: str) -> int:
66
+ if os.path.isdir(work_dir):
67
+ fnames = os.listdir(work_dir)
68
+ else:
69
+ fnames = []
70
+ v_list = [-1]
71
+ for fname in fnames:
72
+ m = re.match(r'v(\d+)', fname)
73
+ if m is None:
74
+ continue
75
+ v = m.group(1)
76
+ v_list.append(int(v))
77
+ return max(v_list) + 1
78
+
79
+
80
+ def format_time(seconds):
81
+ days = int(seconds // (24 * 3600))
82
+ hours = int((seconds % (24 * 3600)) // 3600)
83
+ minutes = int((seconds % 3600) // 60)
84
+ seconds = int(seconds % 60)
85
+
86
+ if days > 0:
87
+ time_str = f'{days}d {hours}h {minutes}m {seconds}s'
88
+ elif hours > 0:
89
+ time_str = f'{hours}h {minutes}m {seconds}s'
90
+ elif minutes > 0:
91
+ time_str = f'{minutes}m {seconds}s'
92
+ else:
93
+ time_str = f'{seconds}s'
94
+
95
+ return time_str
96
+
97
+
98
+ def deep_getattr(obj, attr: str, default=None):
99
+ attrs = attr.split('.')
100
+ for a in attrs:
101
+ if obj is None:
102
+ break
103
+ if isinstance(obj, dict):
104
+ obj = obj.get(a, default)
105
+ else:
106
+ obj = getattr(obj, a, default)
107
+ return obj
108
+
109
+
110
+ def seed_everything(seed: Optional[int] = None, full_determinism: bool = False, *, verbose: bool = True) -> int:
111
+
112
+ if seed is None:
113
+ seed_max = np.iinfo(np.int32).max
114
+ seed = random.randint(0, seed_max)
115
+
116
+ if full_determinism:
117
+ enable_full_determinism(seed)
118
+ else:
119
+ set_seed(seed)
120
+ if verbose:
121
+ logger.info(f'Global seed set to {seed}')
122
+ return seed
123
+
124
+
125
+ def add_version_to_work_dir(work_dir: str) -> str:
126
+ """add version"""
127
+ version = _get_version(work_dir)
128
+ time = dt.datetime.now().strftime('%Y%m%d-%H%M%S')
129
+ sub_folder = f'v{version}-{time}'
130
+ if (dist.is_initialized() and is_dist()) or is_dist_ta():
131
+ obj_list = [sub_folder]
132
+ dist.broadcast_object_list(obj_list)
133
+ sub_folder = obj_list[0]
134
+
135
+ work_dir = os.path.join(work_dir, sub_folder)
136
+ return work_dir
137
+
138
+
139
+ _T = TypeVar('_T')
140
+
141
+
142
+ def parse_args(class_type: Type[_T], argv: Optional[List[str]] = None) -> Tuple[_T, List[str]]:
143
+ parser = HfArgumentParser([class_type])
144
+ if argv is None:
145
+ argv = sys.argv[1:]
146
+ if len(argv) > 0 and argv[0].endswith('.json'):
147
+ json_path = os.path.abspath(os.path.expanduser(argv[0]))
148
+ args, = parser.parse_json_file(json_path)
149
+ remaining_args = argv[1:]
150
+ else:
151
+ args, remaining_args = parser.parse_args_into_dataclasses(argv, return_remaining_strings=True)
152
+ return args, remaining_args
153
+
154
+
155
+ def lower_bound(lo: int, hi: int, cond: Callable[[int], bool]) -> int:
156
+ # The lower bound satisfying the condition "cond".
157
+ while lo < hi:
158
+ mid = (lo + hi) >> 1
159
+ if cond(mid):
160
+ hi = mid
161
+ else:
162
+ lo = mid + 1
163
+ return lo
164
+
165
+
166
+ def upper_bound(lo: int, hi: int, cond: Callable[[int], bool]) -> int:
167
+ # The upper bound satisfying the condition "cond".
168
+ while lo < hi:
169
+ mid = (lo + hi + 1) >> 1 # lo + (hi-lo+1)>>1
170
+ if cond(mid):
171
+ lo = mid
172
+ else:
173
+ hi = mid - 1
174
+ return lo
175
+
176
+
177
+ def test_time(func: Callable[[], _T],
178
+ number: int = 1,
179
+ warmup: int = 0,
180
+ timer: Optional[Callable[[], float]] = None) -> _T:
181
+ # timer: e.g. time_synchronize
182
+ timer = timer if timer is not None else time.perf_counter
183
+
184
+ ts = []
185
+ res = None
186
+ # warmup
187
+ for _ in range(warmup):
188
+ res = func()
189
+
190
+ for _ in range(number):
191
+ t1 = timer()
192
+ res = func()
193
+ t2 = timer()
194
+ ts.append(t2 - t1)
195
+
196
+ ts = np.array(ts)
197
+ _, stat_str = stat_array(ts)
198
+ # print
199
+ logger.info(f'time[number={number}]: {stat_str}')
200
+ return res
201
+
202
+
203
+ def read_multi_line(addi_prompt: str = '') -> str:
204
+ res = []
205
+ prompt = f'<<<{addi_prompt} '
206
+ while True:
207
+ text = input(prompt) + '\n'
208
+ prompt = ''
209
+ res.append(text)
210
+ if text.endswith('#\n'):
211
+ res[-1] = text[:-2]
212
+ break
213
+ return ''.join(res)
214
+
215
+
216
+ def subprocess_run(command: List[str], env: Optional[Dict[str, str]] = None, stdout=None, stderr=None):
217
+ # stdoutm stderr: e.g. subprocess.PIPE.
218
+ resp = subprocess.run(command, env=env, stdout=stdout, stderr=stderr)
219
+ resp.check_returncode()
220
+ return resp
221
+
222
+
223
+ def get_env_args(args_name: str, type_func: Callable[[str], _T], default_value: Optional[_T]) -> Optional[_T]:
224
+ args_name_upper = args_name.upper()
225
+ value = os.getenv(args_name_upper)
226
+ if value is None:
227
+ value = default_value
228
+ log_info = (f'Setting {args_name}: {default_value}. '
229
+ f'You can adjust this hyperparameter through the environment variable: `{args_name_upper}`.')
230
+ else:
231
+ if type_func is bool:
232
+ value = strtobool(value)
233
+ value = type_func(value)
234
+ log_info = f'Using environment variable `{args_name_upper}`, Setting {args_name}: {value}.'
235
+ logger.info_once(log_info)
236
+ return value
237
+
238
+
239
+ def find_free_port(start_port: Optional[int] = None, retry: int = 100) -> int:
240
+ if start_port is None:
241
+ start_port = 0
242
+ for port in range(start_port, start_port + retry):
243
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
244
+ try:
245
+ sock.bind(('', port))
246
+ port = sock.getsockname()[1]
247
+ break
248
+ except OSError:
249
+ pass
250
+ return port
251
+
252
+
253
+ def copy_files_by_pattern(source_dir, dest_dir, patterns):
254
+ if not os.path.exists(dest_dir):
255
+ os.makedirs(dest_dir)
256
+
257
+ if isinstance(patterns, str):
258
+ patterns = [patterns]
259
+
260
+ for pattern in patterns:
261
+ pattern_parts = pattern.split(os.path.sep)
262
+ if len(pattern_parts) > 1:
263
+ subdir_pattern = os.path.sep.join(pattern_parts[:-1])
264
+ file_pattern = pattern_parts[-1]
265
+
266
+ for root, dirs, files in os.walk(source_dir):
267
+ rel_path = os.path.relpath(root, source_dir)
268
+ if rel_path == '.' or (rel_path != '.' and not fnmatch.fnmatch(rel_path, subdir_pattern)):
269
+ continue
270
+
271
+ for file in files:
272
+ if fnmatch.fnmatch(file, file_pattern):
273
+ file_path = os.path.join(root, file)
274
+ target_dir = os.path.join(dest_dir, rel_path)
275
+ if not os.path.exists(target_dir):
276
+ os.makedirs(target_dir)
277
+ dest_file = os.path.join(target_dir, file)
278
+
279
+ if not os.path.exists(dest_file):
280
+ shutil.copy2(file_path, dest_file)
281
+ else:
282
+ search_path = os.path.join(source_dir, pattern)
283
+ matched_files = glob.glob(search_path)
284
+
285
+ for file_path in matched_files:
286
+ if os.path.isfile(file_path):
287
+ file_name = os.path.basename(file_path)
288
+ destination = os.path.join(dest_dir, file_name)
289
+ if not os.path.exists(destination):
290
+ shutil.copy2(file_path, destination)
291
+
292
+
293
+ def split_list(ori_list, num_shards):
294
+ idx_list = np.linspace(0, len(ori_list), num_shards + 1)
295
+ shard = []
296
+ for i in range(len(idx_list) - 1):
297
+ shard.append(ori_list[int(idx_list[i]):int(idx_list[i + 1])])
298
+ return shard
299
+
300
+
301
+ def patch_getattr(obj_cls, item_name: str):
302
+ if hasattr(obj_cls, '_patch'): # avoid double patch
303
+ return
304
+
305
+ def __new_getattr__(self, key: str):
306
+ try:
307
+ return super(self.__class__, self).__getattr__(key)
308
+ except AttributeError:
309
+ if item_name in dir(self):
310
+ item = getattr(self, item_name)
311
+ return getattr(item, key)
312
+ raise
313
+
314
+ obj_cls.__getattr__ = __new_getattr__
315
+ obj_cls._patch = True
316
+
317
+
318
+ def import_external_file(file_path: str):
319
+ file_path = os.path.abspath(os.path.expanduser(file_path))
320
+ py_dir, py_file = os.path.split(file_path)
321
+ assert os.path.isdir(py_dir), f'py_dir: {py_dir}'
322
+ sys.path.insert(0, py_dir)
323
+ return importlib.import_module(py_file.split('.', 1)[0])
ms-swift/tests/__init__.py ADDED
File without changes
ms-swift/tests/app/test_app.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def test_llm():
2
+ from swift.llm import app_main, AppArguments
3
+ app_main(AppArguments(model='Qwen/Qwen2.5-0.5B-Instruct'))
4
+
5
+
6
+ def test_lora():
7
+ from swift.llm import app_main, AppArguments
8
+ app_main(AppArguments(adapters='swift/test_lora', lang='en', studio_title='小黄'))
9
+
10
+
11
+ def test_mllm():
12
+ from swift.llm import app_main, AppArguments
13
+ app_main(AppArguments(model='Qwen/Qwen2-VL-7B-Instruct', stream=True))
14
+
15
+
16
+ def test_audio():
17
+ from swift.llm import AppArguments, app_main, DeployArguments, run_deploy
18
+ deploy_args = DeployArguments(model='Qwen/Qwen2-Audio-7B-Instruct', infer_backend='pt', verbose=False)
19
+
20
+ with run_deploy(deploy_args, return_url=True) as url:
21
+ app_main(AppArguments(model='Qwen2-Audio-7B-Instruct', base_url=url, stream=True))
22
+
23
+
24
+ if __name__ == '__main__':
25
+ test_mllm()
ms-swift/tests/llm/data/multi_modal_1.jsonl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {"query": "<img>https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg</img>55555", "response": "66666"}
2
+ {"query": "<img>https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg</img><img>https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg</img>eeeee", "response": "fffff", "history": [["hello", "123"]]}
3
+ {"query": "EEEEE", "response": "FFFFF", "history": [["AAAAA", "BBBBB"], ["CCCCC", "DDDDD"]]}
ms-swift/tests/models/test_flash_attn.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from swift.llm import get_model_tokenizer
2
+
3
+ if __name__ == '__main__':
4
+ # model, tokenizer = get_model_tokenizer('Qwen/Qwen2-7B-Instruct', attn_impl='flash_attn')
5
+ # model, tokenizer = get_model_tokenizer('AIDC-AI/Ovis2-2B', attn_impl='flash_attn')
6
+ # model, tokenizer = get_model_tokenizer('OpenGVLab/InternVL2-2B', attn_impl='flash_attn')
7
+ model, tokenizer = get_model_tokenizer('Shanghai_AI_Laboratory/internlm3-8b-instruct', attn_impl='flash_attn')
8
+ print(model)
ms-swift/tests/test_align/test_rlhf_loss.py ADDED
File without changes
ms-swift/tests/test_align/test_template/test_agent.py ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ os.environ['SWIFT_DEBUG'] = '1'
+
+ # NOTE: PtEngine, RequestConfig, InferRequest, load_dataset and agent_templates are used below
+ # but were not imported in this file; the import locations here are assumed from ms-swift.
+ from swift.llm import PtEngine, RequestConfig, InferRequest, load_dataset
+ from swift.plugin import agent_templates
4
+
5
+ system = 'You are a helpful assistant.'
6
+
7
+ tools = [{
8
+ 'type': 'function',
9
+ 'function': {
10
+ 'name': 'get_current_weather',
11
+ 'description': 'Get the current weather in a given location',
12
+ 'parameters': {
13
+ 'type': 'object',
14
+ 'properties': {
15
+ 'location': {
16
+ 'type': 'string',
17
+ 'description': 'The city and state, e.g. San Francisco, CA'
18
+ },
19
+ 'unit': {
20
+ 'type': 'string',
21
+ 'enum': ['celsius', 'fahrenheit']
22
+ }
23
+ },
24
+ 'required': ['location']
25
+ }
26
+ }
27
+ }, {
28
+ 'name_for_model': 'tool2',
29
+ 'name_for_human': '工具2',
30
+ 'description': 'Tool2的描述',
31
+ }]
32
+
33
+ glm4_tools = [{
34
+ 'type': 'function',
35
+ 'function': {
36
+ 'name': 'realtime_aqi',
37
+ 'description': '天气预报。获取实时空气质量。当前空气质量,PM2.5,PM10信息',
38
+ 'parameters': {
39
+ 'type': 'object',
40
+ 'properties': {
41
+ 'city': {
42
+ 'description': '城市名'
43
+ }
44
+ },
45
+ 'required': ['city']
46
+ }
47
+ }
48
+ }]
49
+ glm4_tool_messasges = [
50
+ {
51
+ 'role': 'tool',
52
+ 'content': '{"city": "北京", "aqi": "10", "unit": "celsius"}'
53
+ },
54
+ {
55
+ 'role': 'tool',
56
+ 'content': '{"city": "上海", "aqi": "72", "unit": "fahrenheit"}'
57
+ },
58
+ ]
59
+ glm4_query = '北京和上海今天的天气情况'
60
+
61
+
62
+ def _infer(engine, num_tools: int = 1, agent_tools=None, tool_messages=None, query=None):
63
+ if agent_tools is None:
64
+ agent_tools = tools
65
+ if tool_messages is None:
66
+ tool_messages = []
67
+ for _ in range(num_tools):
68
+ tool_messages.append({
69
+ 'role': 'tool',
70
+ 'content': '{"temperature": 32, "condition": "Sunny", "humidity": 50}'
71
+ })
72
+ stop = [engine.default_template.agent_template.keyword.observation]
73
+ query = query or "How's the weather in Beijing today?"
74
+ infer_request = InferRequest([{'role': 'user', 'content': query}], tools=agent_tools)
75
+ request_config = RequestConfig(max_tokens=512, stop=stop, temperature=0)
76
+ resp_list = engine.infer([infer_request], request_config=request_config)
77
+ response = resp_list[0].choices[0].message.content
78
+ toolcall = resp_list[0].choices[0].message.tool_calls[0].function
79
+ print(f'response: {response}')
80
+ print(f'toolcall: {toolcall}')
81
+ assert toolcall is not None
82
+ infer_request.messages.append({'role': 'assistant', 'content': response})
83
+ infer_request.messages += tool_messages
84
+ resp_list = engine.infer([infer_request], request_config=request_config)
85
+ response2 = resp_list[0].choices[0].message.content
86
+ print(f'response2: {response2}')
87
+ infer_request.messages.append({'role': 'assistant', 'content': response2})
88
+ return infer_request.messages
89
+
90
+
91
+ def test_react_en():
92
+ agent_template = agent_templates['react_en']()
93
+ new_system = agent_template._format_tools(tools, system)
94
+ assert len(new_system) == 1144
95
+ engine = PtEngine('Qwen/Qwen2.5-7B-Instruct')
96
+ template = engine.default_template
97
+ template.agent_template = agent_template
98
+ messages = _infer(engine)
99
+ assert messages[-1]['content'] == (
100
+ 'Thought: The current temperature in Beijing is 32 degrees Celsius, and the condition is sunny '
101
+ 'with a humidity of 50%.\nFinal Answer: The current temperature in Beijing is 32 degrees Celsius,'
102
+ ' and the condition is sunny with a humidity of 50%.')
103
+ template.set_mode('train')
104
+ encoded = template.encode({'messages': messages})
105
+ print(f'input_ids: {template.safe_decode(encoded["input_ids"])}')
106
+ print(f'labels: {template.safe_decode(encoded["labels"])}')
107
+
108
+ dataset = load_dataset('AI-ModelScope/function-calling-chatml')[0]
109
+ data = dataset[6]
110
+ data['messages'].insert(1, data['messages'][1])
111
+ data['messages'].insert(3, data['messages'][3])
112
+ template.template_backend = 'swift'
113
+ encoded = template.encode(data)
114
+ print(f'input_ids: {template.safe_decode(encoded["input_ids"])}')
115
+ print(f'labels: {template.safe_decode(encoded["labels"])}')
116
+
117
+
118
+ def test_react_zh():
119
+ agent_template = agent_templates['react_zh']()
120
+ new_system = agent_template._format_tools(tools, system)
121
+ assert len(new_system) == 712
122
+ engine = PtEngine('Qwen/Qwen2.5-7B-Instruct')
123
+ template = engine.default_template
124
+ template.agent_template = agent_template
125
+ _infer(engine)
126
+
127
+
128
+ def test_qwen_en():
129
+ agent_template = agent_templates['qwen_en']()
130
+ new_system = agent_template._format_tools(tools, system)
131
+ assert len(new_system) == 879
132
+ engine = PtEngine('Qwen/Qwen2.5-7B-Instruct')
133
+ template = engine.default_template
134
+ template.agent_template = agent_template
135
+ messages = _infer(engine)
136
+ assert messages[-1]['content'] == (
137
+ '✿RETURN✿: Today in Beijing, the temperature is 32°C with sunny conditions and the humidity '
138
+ 'is at 50%. Enjoy the nice weather!')
139
+ template.set_mode('train')
140
+ encoded = template.encode({'messages': messages})
141
+ print(f'input_ids: {template.safe_decode(encoded["input_ids"])}')
142
+ print(f'labels: {template.safe_decode(encoded["labels"])}')
143
+
144
+ dataset = load_dataset('AI-ModelScope/function-calling-chatml')[0]
145
+ data = dataset[6]
146
+ data['messages'].insert(1, data['messages'][1])
147
+ data['messages'].insert(3, data['messages'][3])
148
+ template.template_backend = 'swift'
149
+ encoded = template.encode(data)
150
+ print(f'input_ids: {template.safe_decode(encoded["input_ids"])}')
151
+ print(f'labels: {template.safe_decode(encoded["labels"])}')
152
+
153
+
154
+ def test_qwen_zh():
155
+ agent_template = agent_templates['qwen_zh']()
156
+ new_system = agent_template._format_tools(tools, system)
157
+ assert len(new_system) == 577
158
+ engine = PtEngine('Qwen/Qwen2.5-7B-Instruct')
159
+ template = engine.default_template
160
+ template.agent_template = agent_template
161
+ _infer(engine)
162
+
163
+
164
+ def test_qwen_en_parallel():
165
+ agent_template = agent_templates['qwen_en_parallel']()
166
+ new_system = agent_template._format_tools(tools, system)
167
+ assert len(new_system) == 1012
168
+ engine = PtEngine('Qwen/Qwen2.5-7B-Instruct')
169
+ template = engine.default_template
170
+ template.agent_template = agent_template
171
+ messages = _infer(engine, num_tools=2)
172
+ assert messages[-1]['content'] == (
173
+ '✿RETURN✿: Today in Beijing, the temperature is 32 degrees Celsius with sunny conditions '
174
+ 'and the humidity is at 50%. Enjoy the nice weather!')
175
+ template.set_mode('train')
176
+ encoded = template.encode({'messages': messages})
177
+ print(f'input_ids: {template.safe_decode(encoded["input_ids"])}')
178
+ print(f'labels: {template.safe_decode(encoded["labels"])}')
179
+
180
+ dataset = load_dataset('AI-ModelScope/function-calling-chatml')[0]
181
+ data = dataset[6]
182
+ data['messages'].insert(1, data['messages'][1])
183
+ data['messages'].insert(3, data['messages'][3])
184
+ template.template_backend = 'swift'
185
+ encoded = template.encode(data)
186
+ print(f'input_ids: {template.safe_decode(encoded["input_ids"])}')
187
+ print(f'labels: {template.safe_decode(encoded["labels"])}')
188
+
189
+
190
+ def test_qwen_zh_parallel():
+     agent_template = agent_templates['qwen_zh_parallel']()
+     new_system = agent_template._format_tools(tools, system)
+     assert len(new_system) == 688
+     engine = PtEngine('Qwen/Qwen2.5-7B-Instruct')
+     template = engine.default_template
+     template.agent_template = agent_template
+     _infer(engine, num_tools=2)
+
+
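+ # hermes additionally cross-checks the 'swift' template backend against the 'jinja' chat template:
+ # the encoded input_ids must match except for jinja's trailing token.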
+ def test_hermes():
+     agent_template = agent_templates['hermes']()
+     new_system = agent_template._format_tools(tools, system)
+     assert len(new_system) == 875
+     engine = PtEngine('Qwen/Qwen2.5-7B-Instruct')
+     template = engine.default_template
+     template.agent_template = agent_template
+     messages = _infer(engine, num_tools=2)
+     template.template_backend = 'jinja'
+     messages2 = _infer(engine, num_tools=2)
+     assert messages[-1]['content'] == messages2[-1]['content'] == (
+         'Today in Beijing, the temperature is 32 degrees Celsius with sunny conditions '
+         'and the humidity is at 50%. Enjoy the nice weather!')
+     template.set_mode('train')
+     encoded = template.encode({'messages': messages})
+     print(f'input_ids: {template.safe_decode(encoded["input_ids"])}')
+     print(f'labels: {template.safe_decode(encoded["labels"])}')
+
+     dataset = load_dataset('AI-ModelScope/function-calling-chatml')[0]
+     data = dataset[6]
+     data['messages'].insert(1, data['messages'][1])
+     data['messages'].insert(3, data['messages'][3])
+     template.template_backend = 'swift'
+     encoded = template.encode(data)
+     print(f'input_ids: {template.safe_decode(encoded["input_ids"])}')
+     print(f'labels: {template.safe_decode(encoded["labels"])}')
+     template.template_backend = 'jinja'
+     encoded2 = template.encode(data)
+     print(f'input_ids: {template.safe_decode(encoded2["input_ids"])}')
+     print(f'labels: {template.safe_decode(encoded2["labels"])}')
+     assert encoded['input_ids'] == encoded2['input_ids'][:-1]
+
+
+ def test_toolbench():
+     agent_template = agent_templates['toolbench']()
+     new_system = agent_template._format_tools(tools, system)
+     assert len(new_system) == 1833
+     engine = PtEngine('Qwen/Qwen2.5-7B-Instruct')
+     template = engine.default_template
+     template.agent_template = agent_template
+     _infer(engine)
+
+
+ def test_glm4():
+     agent_template = agent_templates['glm4']()
+     new_system = agent_template._format_tools(tools, system)
+     assert len(new_system) == 846
+     engine = PtEngine('ZhipuAI/glm-4-9b-chat')
+     template = engine.default_template
+     template.agent_template = agent_template
+     _infer(engine, agent_tools=glm4_tools, tool_messages=glm4_tool_messasges, query=glm4_query)
+
+
+ def test_glm4_0414():
+     agent_template = agent_templates['glm4_0414']()
+     new_system = agent_template._format_tools(tools, system)
+     assert len(new_system) == 769
+     engine = PtEngine('ZhipuAI/GLM-4-9B-0414')
+     template = engine.default_template
+     template.agent_template = agent_template
+     messages = _infer(engine, agent_tools=glm4_tools, tool_messages=glm4_tool_messasges, query=glm4_query)
+     assert messages[-1]['content'] == '根据天气预报工具,北京今天的空气质量指数为10,属于良好水平;上海今天的空气质量指数为72,属于轻度污染水平。'
+     template.set_mode('train')
+     encoded = template.encode({'messages': messages})
+     print(f'input_ids: {template.safe_decode(encoded["input_ids"])}')
+     print(f'labels: {template.safe_decode(encoded["labels"])}')
+
+     dataset = load_dataset('AI-ModelScope/function-calling-chatml')[0]
+     data = dataset[6]
+     data['messages'].insert(1, data['messages'][1])
+     data['messages'].insert(3, data['messages'][3])
+     template.template_backend = 'swift'
+     encoded = template.encode(data)
+     print(f'input_ids: {template.safe_decode(encoded["input_ids"])}')
+     print(f'labels: {template.safe_decode(encoded["labels"])}')
+
+
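+ # The Llama tests only exercise inference and train-mode encoding; they skip the system-length assertion.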
+ def test_llama3():
+     agent_template = agent_templates['llama3']()
+     engine = PtEngine('LLM-Research/Llama-3.2-3B-Instruct')
+     template = engine.default_template
+     template.agent_template = agent_template
+     messages = _infer(engine)
+
+     template.set_mode('train')
+     encoded = template.encode({'messages': messages})
+     print(f'input_ids: {template.safe_decode(encoded["input_ids"])}')
+     print(f'labels: {template.safe_decode(encoded["labels"])}')
+
+     dataset = load_dataset('AI-ModelScope/function-calling-chatml')[0]
+     data = dataset[6]
+     data['messages'].insert(1, data['messages'][1])
+     data['messages'].insert(3, data['messages'][3])
+     template.template_backend = 'swift'
+     encoded = template.encode(data)
+     print(f'input_ids: {template.safe_decode(encoded["input_ids"])}')
+     print(f'labels: {template.safe_decode(encoded["labels"])}')
+
+
+ def test_llama4():
+     agent_template = agent_templates['llama4']()
+     engine = PtEngine('LLM-Research/Llama-4-Scout-17B-16E-Instruct')
+     template = engine.default_template
+     template.agent_template = agent_template
+     messages = _infer(engine)
+     template.set_mode('train')
+     encoded = template.encode({'messages': messages})
+     print(f'input_ids: {template.safe_decode(encoded["input_ids"])}')
+     print(f'labels: {template.safe_decode(encoded["labels"])}')
+
+
+ if __name__ == '__main__':
+     from swift.plugin import agent_templates
+     from swift.llm import PtEngine, InferRequest, RequestConfig, load_dataset
+     # test_react_en()
+     # test_react_zh()
+     # test_qwen_en()
+     # test_qwen_zh()
+     # test_qwen_en_parallel()
+     # test_qwen_zh_parallel()
+     test_hermes()
+     # test_toolbench()
+     # test_glm4()
+     # test_glm4_0414()
+     # test_llama3()
+     # test_llama4()
ms-swift/tests/test_align/test_template/test_audio.py ADDED
@@ -0,0 +1,76 @@
+ import os
+
+ os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
+
+
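+ # Shared helper: greedy decoding (temperature=0); when no messages are given it runs a text-only
+ # warm-up turn first, and when no audio is given it falls back to the weather.wav sample URL.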
+ def _infer_model(pt_engine, system=None, messages=None, audios=None):
+     seed_everything(42)
+     request_config = RequestConfig(max_tokens=128, temperature=0)
+     if messages is None:
+         messages = []
+         if system is not None:
+             messages += [{'role': 'system', 'content': system}]
+         messages += [{'role': 'user', 'content': '你好'}]
+         resp = pt_engine.infer([{'messages': messages}], request_config=request_config)
+         response = resp[0].choices[0].message.content
+         messages += [{'role': 'assistant', 'content': response}]
+         messages += [{'role': 'user', 'content': '<audio>这段语音说了什么'}]
+     else:
+         messages = messages.copy()
+     if audios is None:
+         audios = ['http://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/weather.wav']
+     resp = pt_engine.infer([{'messages': messages, 'audios': audios}], request_config=request_config)
+     response = resp[0].choices[0].message.content
+     messages += [{'role': 'assistant', 'content': response}]
+     logger.info(f'model: {pt_engine.model_info.model_name}, messages: {messages}')
+     return response
+
+
+ def test_qwen_audio():
+     pt_engine = PtEngine('Qwen/Qwen-Audio-Chat')
+     _infer_model(pt_engine)
+
+
+ def test_qwen2_audio():
+     # transformers==4.48.3
+     pt_engine = PtEngine('Qwen/Qwen2-Audio-7B-Instruct')
+     messages = [{'role': 'user', 'content': '<audio>'}]
+     audios = ['https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/guess_age_gender.wav']
+     response = _infer_model(pt_engine, messages=messages, audios=audios)
+     pt_engine.default_template.template_backend = 'jinja'
+     response2 = _infer_model(pt_engine, messages=messages, audios=audios)
+     assert response == response2 == 'Yes, the speaker is female and in her twenties.'
+
+
+ def test_xcomposer2d5_ol():
+     pt_engine = PtEngine('Shanghai_AI_Laboratory/internlm-xcomposer2d5-ol-7b:audio')
+     _infer_model(pt_engine)
+     pt_engine.default_template.template_backend = 'jinja'
+     _infer_model(pt_engine)
+
+
+ def test_step_audio_chat():
+     pt_engine = PtEngine('stepfun-ai/Step-Audio-Chat')
+     response = _infer_model(pt_engine, messages=[{'role': 'user', 'content': '<audio>'}])
+     assert response == ('是的呢,今天天气晴朗,阳光明媚,微风和煦,非常适合外出活动。天空湛蓝,白云朵朵,让人心情愉悦。希望你能好好享受这美好的一天!')
+
+
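+ # USE_AUDIO_IN_VIDEO is forwarded through the environment before the engine is constructed.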
+ def test_qwen2_5_omni():
+     USE_AUDIO_IN_VIDEO = True
+     os.environ['USE_AUDIO_IN_VIDEO'] = str(USE_AUDIO_IN_VIDEO)
+     pt_engine = PtEngine('Qwen/Qwen2.5-Omni-7B')
+     response = _infer_model(pt_engine)
+     pt_engine.default_template.template_backend = 'jinja'
+     response2 = _infer_model(pt_engine)
+     assert response == response2
+
+
+ if __name__ == '__main__':
+     from swift.llm import PtEngine, RequestConfig
+     from swift.utils import get_logger, seed_everything
+     logger = get_logger()
+     # test_qwen_audio()
+     # test_qwen2_audio()
+     # test_xcomposer2d5_ol()
+     # test_step_audio_chat()
+     test_qwen2_5_omni()