Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +11 -0
- ms-swift/.ipynb_checkpoints/README-checkpoint.md +423 -0
- ms-swift/.ipynb_checkpoints/README_CN-checkpoint.md +413 -0
- ms-swift/.ipynb_checkpoints/dataset-checkpoint.json +60 -0
- ms-swift/.ipynb_checkpoints/dataset_overlap5s716_gemini-checkpoint.json +0 -0
- ms-swift/.ipynb_checkpoints/gen_data-checkpoint.py +154 -0
- ms-swift/.ipynb_checkpoints/overlap5s716_gemini-checkpoint.json +0 -0
- ms-swift/.ipynb_checkpoints/setup-checkpoint.py +165 -0
- ms-swift/.ipynb_checkpoints/test-checkpoint.sh +6 -0
- ms-swift/.ipynb_checkpoints/train-checkpoint.sh +42 -0
- ms-swift/asset/banner.png +3 -0
- ms-swift/docs/resources/dpo_data.png +3 -0
- ms-swift/docs/resources/grpo_clevr_count.png +3 -0
- ms-swift/docs/resources/grpo_code.png +3 -0
- ms-swift/docs/resources/grpo_countdown.png +3 -0
- ms-swift/docs/resources/grpo_countdown_1.png +3 -0
- ms-swift/docs/resources/grpo_geoqa.png +3 -0
- ms-swift/docs/resources/grpo_openr1_multimodal.png +3 -0
- ms-swift/docs/resources/kto_data.png +3 -0
- ms-swift/docs/resources/web-ui-en.jpg +3 -0
- ms-swift/docs/resources/web-ui.jpg +3 -0
- ms-swift/silence_overlaps.zip +3 -0
- ms-swift/silence_overlaps/.ipynb_checkpoints/clean_wrong-checkpoint.py +73 -0
- ms-swift/silence_overlaps/.ipynb_checkpoints/cleaned_transcriptions-checkpoint.json +0 -0
- ms-swift/silence_overlaps/.ipynb_checkpoints/delete_transcript-checkpoint.json +0 -0
- ms-swift/silence_overlaps/.ipynb_checkpoints/delete_transcript2-checkpoint.json +1 -0
- ms-swift/silence_overlaps/700/merge_and_shuffle_json.py +61 -0
- ms-swift/silence_overlaps/700/split_train_test.py +36 -0
- ms-swift/silence_overlaps/700/train/overlap5s_isoverlap_train.json +0 -0
- ms-swift/silence_overlaps/700/train/overlap5s_speaker_segments_train.json +0 -0
- ms-swift/silence_overlaps/clean_wrong.py +73 -0
- ms-swift/silence_overlaps/cleaned_transcriptions.json +0 -0
- ms-swift/silence_overlaps/overlap5s_isoverlap.json +0 -0
- ms-swift/silence_overlaps/overlap5s_speaker_segments.json +0 -0
- ms-swift/silence_overlaps/overlap5s_transcriptions.json +0 -0
- ms-swift/silence_overlaps/silence_isoverlaps.json +0 -0
- ms-swift/silence_overlaps/silence_issilence.json +0 -0
- ms-swift/silence_overlaps/transcriptions.json +0 -0
- ms-swift/swift/ui/llm_train/utils.py +37 -0
- ms-swift/swift/utils/__pycache__/logger.cpython-310.pyc +0 -0
- ms-swift/swift/utils/__pycache__/torch_utils.cpython-310.pyc +0 -0
- ms-swift/swift/utils/__pycache__/utils.cpython-310.pyc +0 -0
- ms-swift/swift/utils/utils.py +323 -0
- ms-swift/tests/__init__.py +0 -0
- ms-swift/tests/app/test_app.py +25 -0
- ms-swift/tests/llm/data/multi_modal_1.jsonl +3 -0
- ms-swift/tests/models/test_flash_attn.py +8 -0
- ms-swift/tests/test_align/test_rlhf_loss.py +0 -0
- ms-swift/tests/test_align/test_template/test_agent.py +325 -0
- ms-swift/tests/test_align/test_template/test_audio.py +76 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,14 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
ms-swift/docs/resources/web-ui.jpg filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
ms-swift/docs/resources/grpo_code.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
ms-swift/docs/resources/kto_data.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
ms-swift/asset/banner.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
ms-swift/docs/resources/dpo_data.png filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
ms-swift/docs/resources/grpo_clevr_count.png filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
ms-swift/docs/resources/web-ui-en.jpg filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
ms-swift/docs/resources/grpo_openr1_multimodal.png filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
ms-swift/docs/resources/grpo_countdown_1.png filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
ms-swift/docs/resources/grpo_geoqa.png filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
ms-swift/docs/resources/grpo_countdown.png filter=lfs diff=lfs merge=lfs -text
|
ms-swift/.ipynb_checkpoints/README-checkpoint.md
ADDED
|
@@ -0,0 +1,423 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SWIFT (Scalable lightWeight Infrastructure for Fine-Tuning)
|
| 2 |
+
|
| 3 |
+
<p align="center">
|
| 4 |
+
<br>
|
| 5 |
+
<img src="asset/banner.png"/>
|
| 6 |
+
<br>
|
| 7 |
+
<p>
|
| 8 |
+
<p align="center">
|
| 9 |
+
<a href="https://modelscope.cn/home">ModelScope Community Website</a>
|
| 10 |
+
<br>
|
| 11 |
+
<a href="README_CN.md">中文</a>   |   English  
|
| 12 |
+
</p>
|
| 13 |
+
|
| 14 |
+
<p align="center">
|
| 15 |
+
<img src="https://img.shields.io/badge/python-3.10-5be.svg">
|
| 16 |
+
<img src="https://img.shields.io/badge/pytorch-%E2%89%A52.0-orange.svg">
|
| 17 |
+
<a href="https://github.com/modelscope/modelscope/"><img src="https://img.shields.io/badge/modelscope-%E2%89%A51.19-5D91D4.svg"></a>
|
| 18 |
+
<a href="https://pypi.org/project/ms-swift/"><img src="https://badge.fury.io/py/ms-swift.svg"></a>
|
| 19 |
+
<a href="https://github.com/modelscope/swift/blob/main/LICENSE"><img src="https://img.shields.io/github/license/modelscope/swift"></a>
|
| 20 |
+
<a href="https://pepy.tech/project/ms-swift"><img src="https://pepy.tech/badge/ms-swift"></a>
|
| 21 |
+
<a href="https://github.com/modelscope/swift/pulls"><img src="https://img.shields.io/badge/PR-welcome-55EB99.svg"></a>
|
| 22 |
+
</p>
|
| 23 |
+
|
| 24 |
+
<p align="center">
|
| 25 |
+
<a href="https://trendshift.io/repositories/6427" target="_blank"><img src="https://trendshift.io/api/badge/repositories/6427" alt="modelscope%2Fswift | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
| 26 |
+
</p>
|
| 27 |
+
|
| 28 |
+
<p align="center">
|
| 29 |
+
<a href="https://arxiv.org/abs/2408.05517">Paper</a>   | <a href="https://swift.readthedocs.io/en/latest/">English Documentation</a>   |   <a href="https://swift.readthedocs.io/zh-cn/latest/">中文文档</a>  
|
| 30 |
+
</p>
|
| 31 |
+
|
| 32 |
+
## 📖 Table of Contents
|
| 33 |
+
- [Groups](#-Groups)
|
| 34 |
+
- [Introduction](#-introduction)
|
| 35 |
+
- [News](#-news)
|
| 36 |
+
- [Installation](#%EF%B8%8F-installation)
|
| 37 |
+
- [Quick Start](#-quick-Start)
|
| 38 |
+
- [Usage](#-Usage)
|
| 39 |
+
- [License](#-License)
|
| 40 |
+
- [Citation](#-citation)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
## ☎ Groups
|
| 44 |
+
|
| 45 |
+
You can contact us and communicate with us by adding our group:
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
[Discord Group](https://discord.com/invite/D27yfEFVz5) | WeChat Group
|
| 49 |
+
:-------------------------:|:-------------------------:
|
| 50 |
+
<img src="asset/discord_qr.jpg" width="200" height="200"> | <img src="asset/wechat.png" width="200" height="200">
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
## 📝 Introduction
|
| 54 |
+
🍲 ms-swift is an official framework provided by the ModelScope community for fine-tuning and deploying large language models and multi-modal large models. It currently supports the training (pre-training, fine-tuning, human alignment), inference, evaluation, quantization, and deployment of 500+ large models and 200+ multi-modal large models. These large language models (LLMs) include models such as Qwen3, Qwen3-MoE, Qwen2.5, InternLM3, GLM4, Mistral, DeepSeek-R1, Yi1.5, TeleChat2, Baichuan2, and Gemma2. The multi-modal LLMs include models such as Qwen2.5-VL, Qwen2-Audio, Llama3.4, Llava, InternVL2.5, MiniCPM-V-2.6, GLM4v, Xcomposer2.5, Yi-VL, DeepSeek-VL2, Phi3.5-Vision, and GOT-OCR2.
|
| 55 |
+
|
| 56 |
+
🍔 Additionally, ms-swift incorporates the latest training technologies, including lightweight techniques such as LoRA, QLoRA, Llama-Pro, LongLoRA, GaLore, Q-GaLore, LoRA+, LISA, DoRA, FourierFt, ReFT, UnSloth, and Liger, as well as human alignment training methods like DPO, GRPO, RM, PPO, KTO, CPO, SimPO, and ORPO. ms-swift supports acceleration of inference, evaluation, and deployment modules using vLLM and LMDeploy, and it supports model quantization with technologies like GPTQ, AWQ, and BNB. Furthermore, ms-swift offers a Gradio-based Web UI and a wealth of best practices.
|
| 57 |
+
|
| 58 |
+
**Why choose ms-swift?**
|
| 59 |
+
|
| 60 |
+
- 🍎 **Model Types**: Supports 500+ pure text large models, **200+ multi-modal large models**, as well as All-to-All multi-modal models, sequence classification models, and embedding models, **covering the entire process from training to deployment**.
|
| 61 |
+
- **Dataset Types**: Comes with 150+ pre-training, fine-tuning, human alignment, multi-modal datasets, and supports custom datasets.
|
| 62 |
+
- **Hardware Support**: Compatible with CPU, RTX series, T4/V100, A10/A100/H100, Ascend NPU, MPS, etc.
|
| 63 |
+
- 🍊 **Lightweight Training**: Supports lightweight fine-tuning methods like LoRA, QLoRA, DoRA, LoRA+, ReFT, RS-LoRA, LLaMAPro, Adapter, GaLore, Q-Galore, LISA, UnSloth, Liger-Kernel.
|
| 64 |
+
- **Distributed Training**: Supports distributed data parallel (DDP), device_map simple model parallelism, DeepSpeed ZeRO2/ZeRO3, FSDP, and other distributed training techniques.
|
| 65 |
+
- **Quantization Training**: Supports training quantized models like BNB, AWQ, GPTQ, AQLM, HQQ, EETQ.
|
| 66 |
+
- **RLHF Training**: Supports human alignment training methods such as DPO, GRPO, RM, PPO, KTO, CPO, SimPO, ORPO for both pure text and multi-modal large models.
|
| 67 |
+
- 🍓 **Multi-Modal Training**: Supports training on different modalities like images, videos, and audio, for tasks like VQA, captioning, OCR, and grounding.
|
| 68 |
+
- **Interface Training**: Provides capabilities for training, inference, evaluation, quantization through an interface, completing the whole large model pipeline.
|
| 69 |
+
- **Plugin and Extension**: Supports custom model and dataset extensions, as well as customization of components like loss, metric, trainer, loss-scale, callback, optimizer.
|
| 70 |
+
- 🍉 **Toolbox Capabilities**: Offers not only training support for large models and multi-modal large models but also covers the entire process of inference, evaluation, quantization, and deployment.
|
| 71 |
+
- **Inference Acceleration**: Supports inference acceleration engines like PyTorch, vLLM, LmDeploy, and provides OpenAI API for accelerating inference, deployment, and evaluation modules.
|
| 72 |
+
- **Model Evaluation**: Uses EvalScope as the evaluation backend and supports evaluation on 100+ datasets for both pure text and multi-modal models.
|
| 73 |
+
- **Model Quantization**: Supports AWQ, GPTQ, and BNB quantized exports, with models that can use vLLM/LmDeploy for inference acceleration and continue training.
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
## 🎉 News
|
| 77 |
+
- 🎁 2025.05.11: GRPO now supports custom processing logic for reward models. See the GenRM example [here](./docs/source_en/Instruction/GRPO.md#customized-reward-models) .
|
| 78 |
+
- 🎁 2025.04.15: The ms-swift paper has been accepted by AAAI 2025. You can find the paper at [this link](https://ojs.aaai.org/index.php/AAAI/article/view/35383).
|
| 79 |
+
- 🎁 2025.03.23: Multi-round GRPO is now supported for training multi-turn dialogue scenarios (e.g., agent tool calling). Please refer to the [training script](https://idealab.alibaba-inc.com/examples/train/grpo/internal/train_multi_round.sh).
|
| 80 |
+
- 🎁 2025.03.16: Support for Megatron's parallel training techniques is now available. Please see the [Megatron-SWIFT training documentation](https://swift.readthedocs.io/zh-cn/latest/Instruction/Megatron-SWIFT训练.html).
|
| 81 |
+
- 🎁 2025.03.15: Fine-tuning of embedding models for both pure text and multimodal models is supported. Please check the [training script](https://idealab.alibaba-inc.com/examples/train/embedding).
|
| 82 |
+
- 🎁 2025.03.05: The hybrid mode for GRPO is supported, with a script for training a 72B model on 4 GPUs (4*80G) available [here](https://idealab.alibaba-inc.com/examples/train/grpo/internal/train_72b_4gpu.sh). Tensor parallelism with vllm is also supported, with the training script available [here](https://idealab.alibaba-inc.com/examples/train/grpo/internal/multi_gpu_mp_colocate.sh).
|
| 83 |
+
- 🎁 2025.02.21: The GRPO algorithm now supports LMDeploy, with the training script available [here](https://idealab.alibaba-inc.com/examples/train/grpo/internal/full_lmdeploy.sh). Additionally, the performance of the GRPO algorithm has been tested, achieving a training speed increase of up to 300% using various tricks. Please check the WanDB table [here](https://wandb.ai/tastelikefeet/grpo_perf_test?nw=nwuseryuzezyz).
|
| 84 |
+
- 🎁 2025.02.21: The `swift sample` command is now supported. The reinforcement fine-tuning script can be found [here](https://idealab.alibaba-inc.com/docs/source/Instruction/强化微调.md), and the large model API distillation sampling script is available [here](https://idealab.alibaba-inc.com/examples/sampler/distill/distill.sh).
|
| 85 |
+
- 🔥 2025.02.12: Support for the GRPO (Group Relative Policy Optimization) training algorithm has been added. Documentation is available [here](https://idealab.alibaba-inc.com/docs/source/Instruction/GRPO.md).
|
| 86 |
+
- 🎁 2024.12.04: Major update to **ms-swift 3.0**. Please refer to the [release notes and changes](https://swift.readthedocs.io/zh-cn/latest/Instruction/ReleaseNote3.0.html).
|
| 87 |
+
<details><summary>More</summary>
|
| 88 |
+
|
| 89 |
+
- 🎉 2024.08.12: The ms-swift paper has been published on arXiv and can be read [here](https://arxiv.org/abs/2408.05517).
|
| 90 |
+
- 🔥 2024.08.05: Support for using [evalscope](https://github.com/modelscope/evalscope/) as a backend for evaluating large models and multimodal models.
|
| 91 |
+
- 🔥 2024.07.29: Support for using [vllm](https://github.com/vllm-project/vllm) and [lmdeploy](https://github.com/InternLM/lmdeploy) to accelerate inference for large models and multimodal models. When performing infer/deploy/eval, you can specify `--infer_backend vllm/lmdeploy`.
|
| 92 |
+
- 🔥 2024.07.24: Support for human preference alignment training for multimodal large models, including DPO/ORPO/SimPO/CPO/KTO/RM/PPO.
|
| 93 |
+
- 🔥 2024.02.01: Support for Agent training! The training algorithm is derived from [this paper](https://arxiv.org/pdf/2309.00986.pdf).
|
| 94 |
+
</details>
|
| 95 |
+
|
| 96 |
+
## 🛠️ Installation
|
| 97 |
+
To install using pip:
|
| 98 |
+
```shell
|
| 99 |
+
pip install ms-swift -U
|
| 100 |
+
```
|
| 101 |
+
|
| 102 |
+
To install from source:
|
| 103 |
+
```shell
|
| 104 |
+
# pip install git+https://github.com/modelscope/ms-swift.git
|
| 105 |
+
|
| 106 |
+
git clone https://github.com/modelscope/ms-swift.git
|
| 107 |
+
cd ms-swift
|
| 108 |
+
pip install -e .
|
| 109 |
+
```
|
| 110 |
+
|
| 111 |
+
Running Environment:
|
| 112 |
+
|
| 113 |
+
| | Range | Recommended | Notes |
|
| 114 |
+
| ------------ |--------------| ----------- | ----------------------------------------- |
|
| 115 |
+
| python | >=3.9 | 3.10 | |
|
| 116 |
+
| cuda | | cuda12 | No need to install if using CPU, NPU, MPS |
|
| 117 |
+
| torch | >=2.0 | | |
|
| 118 |
+
| transformers | >=4.33 | 4.51 | |
|
| 119 |
+
| modelscope | >=1.23 | | |
|
| 120 |
+
| peft | >=0.11,<0.16 | ||
|
| 121 |
+
| trl | >=0.13,<0.18 | 0.17 |RLHF|
|
| 122 |
+
| deepspeed | >=0.14 | 0.14.5 | Training |
|
| 123 |
+
| vllm | >=0.5.1 | 0.7.3/0.8 | Inference/Deployment/Evaluation |
|
| 124 |
+
| lmdeploy | >=0.5 | 0.8 | Inference/Deployment/Evaluation |
|
| 125 |
+
| evalscope | >=0.11 | | Evaluation |
|
| 126 |
+
|
| 127 |
+
For more optional dependencies, you can refer to [here](https://github.com/modelscope/ms-swift/blob/main/requirements/install_all.sh).
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
## 🚀 Quick Start
|
| 131 |
+
|
| 132 |
+
10 minutes of self-cognition fine-tuning of Qwen2.5-7B-Instruct on a single 3090 GPU:
|
| 133 |
+
|
| 134 |
+
### Command Line Interface
|
| 135 |
+
|
| 136 |
+
```shell
|
| 137 |
+
# 22GB
|
| 138 |
+
CUDA_VISIBLE_DEVICES=0 \
|
| 139 |
+
swift sft \
|
| 140 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 141 |
+
--train_type lora \
|
| 142 |
+
--dataset 'AI-ModelScope/alpaca-gpt4-data-zh#500' \
|
| 143 |
+
'AI-ModelScope/alpaca-gpt4-data-en#500' \
|
| 144 |
+
'swift/self-cognition#500' \
|
| 145 |
+
--torch_dtype bfloat16 \
|
| 146 |
+
--num_train_epochs 1 \
|
| 147 |
+
--per_device_train_batch_size 1 \
|
| 148 |
+
--per_device_eval_batch_size 1 \
|
| 149 |
+
--learning_rate 1e-4 \
|
| 150 |
+
--lora_rank 8 \
|
| 151 |
+
--lora_alpha 32 \
|
| 152 |
+
--target_modules all-linear \
|
| 153 |
+
--gradient_accumulation_steps 16 \
|
| 154 |
+
--eval_steps 50 \
|
| 155 |
+
--save_steps 50 \
|
| 156 |
+
--save_total_limit 2 \
|
| 157 |
+
--logging_steps 5 \
|
| 158 |
+
--max_length 2048 \
|
| 159 |
+
--output_dir output \
|
| 160 |
+
--system 'You are a helpful assistant.' \
|
| 161 |
+
--warmup_ratio 0.05 \
|
| 162 |
+
--dataloader_num_workers 4 \
|
| 163 |
+
--model_author swift \
|
| 164 |
+
--model_name swift-robot
|
| 165 |
+
```
|
| 166 |
+
|
| 167 |
+
Tips:
|
| 168 |
+
|
| 169 |
+
- If you want to train with a custom dataset, you can refer to [this guide](https://swift.readthedocs.io/en/latest/Customization/Custom-dataset.html) to organize your dataset format and specify `--dataset <dataset_path>`.
|
| 170 |
+
- The `--model_author` and `--model_name` parameters are only effective when the dataset includes `swift/self-cognition`.
|
| 171 |
+
- To train with a different model, simply modify `--model <model_id/model_path>`.
|
| 172 |
+
- By default, ModelScope is used for downloading models and datasets. If you want to use HuggingFace, simply specify `--use_hf true`.
|
| 173 |
+
|
| 174 |
+
After training is complete, use the following command to infer with the trained weights:
|
| 175 |
+
|
| 176 |
+
- Here, `--adapters` should be replaced with the last checkpoint folder generated during training. Since the adapters folder contains the training parameter file `args.json`, there is no need to specify `--model`, `--system` separately; Swift will automatically read these parameters. To disable this behavior, you can set `--load_args false`.
|
| 177 |
+
|
| 178 |
+
```shell
|
| 179 |
+
# Using an interactive command line for inference.
|
| 180 |
+
CUDA_VISIBLE_DEVICES=0 \
|
| 181 |
+
swift infer \
|
| 182 |
+
--adapters output/vx-xxx/checkpoint-xxx \
|
| 183 |
+
--stream true \
|
| 184 |
+
--temperature 0 \
|
| 185 |
+
--max_new_tokens 2048
|
| 186 |
+
|
| 187 |
+
# merge-lora and use vLLM for inference acceleration
|
| 188 |
+
CUDA_VISIBLE_DEVICES=0 \
|
| 189 |
+
swift infer \
|
| 190 |
+
--adapters output/vx-xxx/checkpoint-xxx \
|
| 191 |
+
--stream true \
|
| 192 |
+
--merge_lora true \
|
| 193 |
+
--infer_backend vllm \
|
| 194 |
+
--max_model_len 8192 \
|
| 195 |
+
--temperature 0 \
|
| 196 |
+
--max_new_tokens 2048
|
| 197 |
+
```
|
| 198 |
+
|
| 199 |
+
Finally, use the following command to push the model to ModelScope:
|
| 200 |
+
|
| 201 |
+
```shell
|
| 202 |
+
CUDA_VISIBLE_DEVICES=0 \
|
| 203 |
+
swift export \
|
| 204 |
+
--adapters output/vx-xxx/checkpoint-xxx \
|
| 205 |
+
--push_to_hub true \
|
| 206 |
+
--hub_model_id '<your-model-id>' \
|
| 207 |
+
--hub_token '<your-sdk-token>' \
|
| 208 |
+
--use_hf false
|
| 209 |
+
```
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
### Web-UI
|
| 213 |
+
The Web-UI is a **zero-threshold** training and deployment interface solution based on Gradio interface technology. For more details, you can check [here](https://swift.readthedocs.io/en/latest/GetStarted/Web-UI.html).
|
| 214 |
+
|
| 215 |
+
```shell
|
| 216 |
+
SWIFT_UI_LANG=en swift web-ui
|
| 217 |
+
```
|
| 218 |
+
|
| 219 |
+

|
| 220 |
+
|
| 221 |
+
### Using Python
|
| 222 |
+
|
| 223 |
+
ms-swift also supports training and inference using Python. Below is pseudocode for training and inference. For more details, you can refer to [here](https://github.com/modelscope/ms-swift/blob/main/examples/notebook/qwen2_5-self-cognition/self-cognition-sft.ipynb).
|
| 224 |
+
|
| 225 |
+
Training:
|
| 226 |
+
|
| 227 |
+
```python
|
| 228 |
+
# Retrieve the model and template, and add a trainable LoRA module
|
| 229 |
+
model, tokenizer = get_model_tokenizer(model_id_or_path, ...)
|
| 230 |
+
template = get_template(model.model_meta.template, tokenizer, ...)
|
| 231 |
+
model = Swift.prepare_model(model, lora_config)
|
| 232 |
+
|
| 233 |
+
# Download and load the dataset, and encode the text into tokens
|
| 234 |
+
train_dataset, val_dataset = load_dataset(dataset_id_or_path, ...)
|
| 235 |
+
train_dataset = EncodePreprocessor(template=template)(train_dataset, num_proc=num_proc)
|
| 236 |
+
val_dataset = EncodePreprocessor(template=template)(val_dataset, num_proc=num_proc)
|
| 237 |
+
|
| 238 |
+
# Train the model
|
| 239 |
+
trainer = Seq2SeqTrainer(
|
| 240 |
+
model=model,
|
| 241 |
+
args=training_args,
|
| 242 |
+
data_collator=template.data_collator,
|
| 243 |
+
train_dataset=train_dataset,
|
| 244 |
+
eval_dataset=val_dataset,
|
| 245 |
+
template=template,
|
| 246 |
+
)
|
| 247 |
+
trainer.train()
|
| 248 |
+
```
|
| 249 |
+
Inference:
|
| 250 |
+
|
| 251 |
+
```python
|
| 252 |
+
# Perform inference using the native PyTorch engine
|
| 253 |
+
engine = PtEngine(model_id_or_path, adapters=[lora_checkpoint])
|
| 254 |
+
infer_request = InferRequest(messages=[{'role': 'user', 'content': 'who are you?'}])
|
| 255 |
+
request_config = RequestConfig(max_tokens=max_new_tokens, temperature=temperature)
|
| 256 |
+
|
| 257 |
+
resp_list = engine.infer([infer_request], request_config)
|
| 258 |
+
print(f'response: {resp_list[0].choices[0].message.content}')
|
| 259 |
+
```
|
| 260 |
+
|
| 261 |
+
## ✨ Usage
|
| 262 |
+
Here is a minimal example of training to deployment using ms-swift. For more details, you can check the [examples](https://github.com/modelscope/ms-swift/tree/main/examples).
|
| 263 |
+
|
| 264 |
+
- If you want to use other models or datasets (including multimodal models and datasets), you only need to modify `--model` to specify the corresponding model's ID or path, and modify `--dataset` to specify the corresponding dataset's ID or path.
|
| 265 |
+
- By default, ModelScope is used for downloading models and datasets. If you want to use HuggingFace, simply specify `--use_hf true`.
|
| 266 |
+
|
| 267 |
+
| Useful Links |
|
| 268 |
+
| ------ |
|
| 269 |
+
| [🔥Command Line Parameters](https://swift.readthedocs.io/en/latest/Instruction/Command-line-parameters.html) |
|
| 270 |
+
| [Supported Models and Datasets](https://swift.readthedocs.io/en/latest/Instruction/Supported-models-and-datasets.html) |
|
| 271 |
+
| [Custom Models](https://swift.readthedocs.io/en/latest/Customization/Custom-model.html), [🔥Custom Datasets](https://swift.readthedocs.io/en/latest/Customization/Custom-dataset.html) |
|
| 272 |
+
| [LLM Tutorial](https://github.com/modelscope/modelscope-classroom/tree/main/LLM-tutorial) |
|
| 273 |
+
|
| 274 |
+
### Training
|
| 275 |
+
|
| 276 |
+
Supported Training Methods:
|
| 277 |
+
|
| 278 |
+
| Method | Full-Parameter | LoRA | QLoRA | Deepspeed | Multi-Node | Multi-Modal |
|
| 279 |
+
|------------------------------------|--------------------------------------------------------------|---------------------------------------------------------------------------------------------|--------------------------------------------------------------|--------------------------------------------------------------|--------------------------------------------------------------|----------------------------------------------------------------------------------------------|
|
| 280 |
+
| Pre-training | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/pretrain/train.sh) | ✅ | ✅ | ✅ | ✅ | ✅ |
|
| 281 |
+
| Instruction Supervised Fine-tuning | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/full/train.sh) | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/lora_sft.sh) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/qlora) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-gpu/deepspeed) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-node) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/multimodal) |
|
| 282 |
+
| DPO Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/dpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/dpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/rlhf/dpo.sh) |
|
| 283 |
+
| GRPO Training | [✅]((https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/internal/grpo_zero2.sh)) | ✅ | ✅ | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/internal/multi_node) | ✅ |
|
| 284 |
+
| Reward Model Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/rm.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/rm.sh) | ✅ | ✅ |
|
| 285 |
+
| PPO Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/ppo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/ppo.sh) | ✅ | ❌ |
|
| 286 |
+
| KTO Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/kto.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/kto.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/rlhf/kto.sh) |
|
| 287 |
+
| CPO Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/cpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/cpo.sh) | ✅ | ✅ |
|
| 288 |
+
| SimPO Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/simpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/simpo.sh) | ✅ | ✅ |
|
| 289 |
+
| ORPO Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/orpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/orpo.sh) | ✅ | ✅ |
|
| 290 |
+
| Classification Model Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/seq_cls/qwen2_5/sft.sh) | ✅ | ✅ | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/seq_cls/qwen2_vl/sft.sh) |
|
| 291 |
+
| Embedding Model Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/embedding/train_gte.sh) | ✅ | ✅ | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/embedding/train_gme.sh) |
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
Pre-training:
|
| 296 |
+
```shell
|
| 297 |
+
# 8*A100
|
| 298 |
+
NPROC_PER_NODE=8 \
|
| 299 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
|
| 300 |
+
swift pt \
|
| 301 |
+
--model Qwen/Qwen2.5-7B \
|
| 302 |
+
--dataset swift/chinese-c4 \
|
| 303 |
+
--streaming true \
|
| 304 |
+
--train_type full \
|
| 305 |
+
--deepspeed zero2 \
|
| 306 |
+
--output_dir output \
|
| 307 |
+
--max_steps 10000 \
|
| 308 |
+
...
|
| 309 |
+
```
|
| 310 |
+
|
| 311 |
+
Fine-tuning:
|
| 312 |
+
```shell
|
| 313 |
+
CUDA_VISIBLE_DEVICES=0 swift sft \
|
| 314 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 315 |
+
--dataset AI-ModelScope/alpaca-gpt4-data-en \
|
| 316 |
+
--train_type lora \
|
| 317 |
+
--output_dir output \
|
| 318 |
+
...
|
| 319 |
+
```
|
| 320 |
+
|
| 321 |
+
RLHF:
|
| 322 |
+
```shell
|
| 323 |
+
CUDA_VISIBLE_DEVICES=0 swift rlhf \
|
| 324 |
+
--rlhf_type dpo \
|
| 325 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 326 |
+
--dataset hjh0119/shareAI-Llama3-DPO-zh-en-emoji \
|
| 327 |
+
--train_type lora \
|
| 328 |
+
--output_dir output \
|
| 329 |
+
...
|
| 330 |
+
```
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
### Inference
|
| 334 |
+
```shell
|
| 335 |
+
CUDA_VISIBLE_DEVICES=0 swift infer \
|
| 336 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 337 |
+
--stream true \
|
| 338 |
+
--infer_backend pt \
|
| 339 |
+
--max_new_tokens 2048
|
| 340 |
+
|
| 341 |
+
# LoRA
|
| 342 |
+
CUDA_VISIBLE_DEVICES=0 swift infer \
|
| 343 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 344 |
+
--adapters swift/test_lora \
|
| 345 |
+
--stream true \
|
| 346 |
+
--infer_backend pt \
|
| 347 |
+
--temperature 0 \
|
| 348 |
+
--max_new_tokens 2048
|
| 349 |
+
```
|
| 350 |
+
|
| 351 |
+
### Interface Inference
|
| 352 |
+
```shell
|
| 353 |
+
CUDA_VISIBLE_DEVICES=0 swift app \
|
| 354 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 355 |
+
--stream true \
|
| 356 |
+
--infer_backend pt \
|
| 357 |
+
--max_new_tokens 2048
|
| 358 |
+
```
|
| 359 |
+
|
| 360 |
+
### Deployment
|
| 361 |
+
```shell
|
| 362 |
+
CUDA_VISIBLE_DEVICES=0 swift deploy \
|
| 363 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 364 |
+
--infer_backend vllm
|
| 365 |
+
```
|
| 366 |
+
|
| 367 |
+
### Sampling
|
| 368 |
+
```shell
|
| 369 |
+
CUDA_VISIBLE_DEVICES=0 swift sample \
|
| 370 |
+
--model LLM-Research/Meta-Llama-3.1-8B-Instruct \
|
| 371 |
+
--sampler_engine pt \
|
| 372 |
+
--num_return_sequences 5 \
|
| 373 |
+
--dataset AI-ModelScope/alpaca-gpt4-data-zh#5
|
| 374 |
+
```
|
| 375 |
+
|
| 376 |
+
### Evaluation
|
| 377 |
+
```shell
|
| 378 |
+
CUDA_VISIBLE_DEVICES=0 swift eval \
|
| 379 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 380 |
+
--infer_backend lmdeploy \
|
| 381 |
+
--eval_backend OpenCompass \
|
| 382 |
+
--eval_dataset ARC_c
|
| 383 |
+
```
|
| 384 |
+
|
| 385 |
+
### Quantization
|
| 386 |
+
```shell
|
| 387 |
+
CUDA_VISIBLE_DEVICES=0 swift export \
|
| 388 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 389 |
+
--quant_bits 4 --quant_method awq \
|
| 390 |
+
--dataset AI-ModelScope/alpaca-gpt4-data-zh \
|
| 391 |
+
--output_dir Qwen2.5-7B-Instruct-AWQ
|
| 392 |
+
```
|
| 393 |
+
|
| 394 |
+
### Push Model
|
| 395 |
+
```shell
|
| 396 |
+
swift export \
|
| 397 |
+
--model <model-path> \
|
| 398 |
+
--push_to_hub true \
|
| 399 |
+
--hub_model_id '<model-id>' \
|
| 400 |
+
--hub_token '<sdk-token>'
|
| 401 |
+
```
|
| 402 |
+
|
| 403 |
+
## 🏛 License
|
| 404 |
+
|
| 405 |
+
This framework is licensed under the [Apache License (Version 2.0)](https://github.com/modelscope/modelscope/blob/master/LICENSE). For models and datasets, please refer to the original resource page and follow the corresponding License.
|
| 406 |
+
|
| 407 |
+
## 📎 Citation
|
| 408 |
+
|
| 409 |
+
```bibtex
|
| 410 |
+
@misc{zhao2024swiftascalablelightweightinfrastructure,
|
| 411 |
+
title={SWIFT:A Scalable lightWeight Infrastructure for Fine-Tuning},
|
| 412 |
+
author={Yuze Zhao and Jintao Huang and Jinghan Hu and Xingjun Wang and Yunlin Mao and Daoze Zhang and Zeyinzi Jiang and Zhikai Wu and Baole Ai and Ang Wang and Wenmeng Zhou and Yingda Chen},
|
| 413 |
+
year={2024},
|
| 414 |
+
eprint={2408.05517},
|
| 415 |
+
archivePrefix={arXiv},
|
| 416 |
+
primaryClass={cs.CL},
|
| 417 |
+
url={https://arxiv.org/abs/2408.05517},
|
| 418 |
+
}
|
| 419 |
+
```
|
| 420 |
+
|
| 421 |
+
## Star History
|
| 422 |
+
|
| 423 |
+
[](https://star-history.com/#modelscope/ms-swift&Date)
|
ms-swift/.ipynb_checkpoints/README_CN-checkpoint.md
ADDED
|
@@ -0,0 +1,413 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SWIFT (Scalable lightWeight Infrastructure for Fine-Tuning)
|
| 2 |
+
|
| 3 |
+
<p align="center">
|
| 4 |
+
<br>
|
| 5 |
+
<img src="asset/banner.png"/>
|
| 6 |
+
<br>
|
| 7 |
+
<p>
|
| 8 |
+
<p align="center">
|
| 9 |
+
<a href="https://modelscope.cn/home">魔搭社区官网</a>
|
| 10 |
+
<br>
|
| 11 |
+
中文  |  <a href="README.md">English</a> 
|
| 12 |
+
</p>
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
<p align="center">
|
| 16 |
+
<img src="https://img.shields.io/badge/python-3.10-5be.svg">
|
| 17 |
+
<img src="https://img.shields.io/badge/pytorch-%E2%89%A52.0-orange.svg">
|
| 18 |
+
<a href="https://github.com/modelscope/modelscope/"><img src="https://img.shields.io/badge/modelscope-%E2%89%A51.19-5D91D4.svg"></a>
|
| 19 |
+
<a href="https://pypi.org/project/ms-swift/"><img src="https://badge.fury.io/py/ms-swift.svg"></a>
|
| 20 |
+
<a href="https://github.com/modelscope/swift/blob/main/LICENSE"><img src="https://img.shields.io/github/license/modelscope/swift"></a>
|
| 21 |
+
<a href="https://pepy.tech/project/ms-swift"><img src="https://pepy.tech/badge/ms-swift"></a>
|
| 22 |
+
<a href="https://github.com/modelscope/swift/pulls"><img src="https://img.shields.io/badge/PR-welcome-55EB99.svg"></a>
|
| 23 |
+
</p>
|
| 24 |
+
|
| 25 |
+
<p align="center">
|
| 26 |
+
<a href="https://trendshift.io/repositories/6427" target="_blank"><img src="https://trendshift.io/api/badge/repositories/6427" alt="modelscope%2Fswift | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
| 27 |
+
</p>
|
| 28 |
+
|
| 29 |
+
<p align="center">
|
| 30 |
+
<a href="https://arxiv.org/abs/2408.05517">论文</a>   | <a href="https://swift.readthedocs.io/en/latest/">English Documentation</a>   |   <a href="https://swift.readthedocs.io/zh-cn/latest/">中文文档</a>  
|
| 31 |
+
</p>
|
| 32 |
+
|
| 33 |
+
## 📖 目录
|
| 34 |
+
- [用户群](#-用户群)
|
| 35 |
+
- [简介](#-简介)
|
| 36 |
+
- [新闻](#-新闻)
|
| 37 |
+
- [安装](#%EF%B8%8F-安装)
|
| 38 |
+
- [快速开始](#-快速开始)
|
| 39 |
+
- [如何使用](#-如何使用)
|
| 40 |
+
- [License](#-license)
|
| 41 |
+
- [引用](#-引用)
|
| 42 |
+
|
| 43 |
+
## ☎ 用户群
|
| 44 |
+
|
| 45 |
+
请扫描下面的二维码来加入我们的交流群:
|
| 46 |
+
|
| 47 |
+
[Discord Group](https://discord.com/invite/D27yfEFVz5) | 微信群
|
| 48 |
+
:-------------------------:|:-------------------------:
|
| 49 |
+
<img src="asset/discord_qr.jpg" width="200" height="200"> | <img src="asset/wechat.png" width="200" height="200">
|
| 50 |
+
|
| 51 |
+
## 📝 简介
|
| 52 |
+
🍲 ms-swift是魔搭社区提供的大模型与多模态大模型微调部署框架,现已支持500+大模型与200+多模态大模型的训练(预训练、微调、人类对齐)、推理、评测、量化与部署。其中大模型包括:Qwen3、Qwen3-MoE、Qwen2.5、InternLM3、GLM4、Mistral、DeepSeek-R1、Yi1.5、TeleChat2、Baichuan2、Gemma2等模型,多模态大模型包括:Qwen2.5-VL、Qwen2-Audio、Llama4、Llava、InternVL2.5、MiniCPM-V-2.6、GLM4v、Xcomposer2.5、Yi-VL、DeepSeek-VL2、Phi3.5-Vision、GOT-OCR2等模型。
|
| 53 |
+
|
| 54 |
+
🍔 除此之外,ms-swift汇集了最新的训练技术,包括LoRA、QLoRA、Llama-Pro、LongLoRA、GaLore、Q-GaLore、LoRA+、LISA、DoRA、FourierFt、ReFT、UnSloth、和Liger等轻量化训练技术,以及DPO、GRPO、RM、PPO、KTO、CPO、SimPO、ORPO等人类对齐训练方法。ms-swift支持使用vLLM和LMDeploy对推理、评测和部署模块进行加速,并支持使用GPTQ、AWQ、BNB等技术对大模型进行量化。ms-swift还提供了基于Gradio的Web-UI界面及丰富的最佳实践。
|
| 55 |
+
|
| 56 |
+
**为什么选择ms-swift?**
|
| 57 |
+
- 🍎 **模型类型**:支持500+纯文本大模型、**200+多模态大模型**以及All-to-All全模态模型、序列分类模型、Embedding模型**训练到部署全流程**。
|
| 58 |
+
- **数据集类型**:内置150+预训练、微调、人类对齐、多模态等各种类型的数据集,并支持自定义数据集。
|
| 59 |
+
- **硬件支持**:CPU、RTX系列、T4/V100、A10/A100/H100、Ascend NPU、MPS等。
|
| 60 |
+
- 🍊 **轻量训练**:支持了LoRA、QLoRA、DoRA、LoRA+、ReFT、RS-LoRA、LLaMAPro、Adapter、GaLore、Q-Galore、LISA、UnSloth、Liger-Kernel等轻量微调方式。
|
| 61 |
+
- **分布式训练**:支持分布式数据并行(DDP)、device_map简易模型并行、DeepSpeed ZeRO2 ZeRO3、FSDP等分布式训练技术。
|
| 62 |
+
- **量化训练**:支持对BNB、AWQ、GPTQ、AQLM、HQQ、EETQ量化模型进行训练。
|
| 63 |
+
- **RLHF训练**:支持纯文本大模型和多模态大模型的DPO、GRPO、RM、PPO、KTO、CPO、SimPO、ORPO等人类对齐训练方法。
|
| 64 |
+
- 🍓 **多模态训练**:支持对图像、视频和语音不同模态模型进行训练,支持VQA、Caption、OCR、Grounding任务的训练。
|
| 65 |
+
- **界面训练**:以界面的方式提供训练、推理、评测、量化的能力,完成大模型的全链路。
|
| 66 |
+
- **插件化与拓展**:支持自定义模型和数据集拓展,支持对loss、metric、trainer、loss-scale、callback、optimizer等组件进行自定义。
|
| 67 |
+
- 🍉 **工具箱能力**:不仅提供大模型和多模态大模型的训练支持,还涵盖其推理、评测、量化和部署全流程。
|
| 68 |
+
- **推理加速**:支持PyTorch、vLLM、LmDeploy推理加速引擎,并提供OpenAI接口,为推理、部署和评测模块提供加速。
|
| 69 |
+
- **模型评测**:以EvalScope作为评测后端,支持100+评测数据集对纯��本和多模态模型进行评测。
|
| 70 |
+
- **模型量化**:支持AWQ、GPTQ和BNB的量化导出,导出的模型支持使用vLLM/LmDeploy推理加速,并支持继续训练。
|
| 71 |
+
|
| 72 |
+
## 🎉 新闻
|
| 73 |
+
- 🎁 2025.05.11: GRPO中的奖励模型支持自定义处理逻辑,GenRM的例子参考[这里](./docs/source/Instruction/GRPO.md#自定义奖励模型)
|
| 74 |
+
- 🎁 2025.04.15: ms-swift论文已经被AAAI 2025接收,论文地址在[这里](https://ojs.aaai.org/index.php/AAAI/article/view/35383)。
|
| 75 |
+
- 🎁 2025.03.23: 支持了多轮GRPO,用于构建多轮对话场景的训练(例如agent tool calling),请查看[训练脚本](examples/train/grpo/internal/train_multi_round.sh)。
|
| 76 |
+
- 🎁 2025.03.16: 支持了Megatron的并行技术进行训练,请查看[Megatron-SWIFT训练文档](https://swift.readthedocs.io/zh-cn/latest/Instruction/Megatron-SWIFT训练.html)。
|
| 77 |
+
- 🎁 2025.03.15: 支持纯文本和多模态模型的embedding模型的微调,请查看[训练脚本](examples/train/embedding)。
|
| 78 |
+
- 🎁 2025.03.05: 支持GRPO的hybrid模式,4GPU(4*80G)训练72B模型的脚本参考[这里](examples/train/grpo/internal/train_72b_4gpu.sh)。同时支持vllm的tensor并行,训练脚本参考[这里](examples/train/grpo/internal/multi_gpu_mp_colocate.sh)。
|
| 79 |
+
- 🎁 2025.02.21: GRPO算法支持使用LMDeploy,训练脚本参考[这里](examples/train/grpo/internal/full_lmdeploy.sh)。此外测试了GRPO算法的性能,使用一些tricks使训练速度提高到300%。WanDB表格请查看[这里](https://wandb.ai/tastelikefeet/grpo_perf_test?nw=nwuseryuzezyz)。
|
| 80 |
+
- 🎁 2025.02.21: 支持`swift sample`命令。强化微调脚本参考[这里](docs/source/Instruction/强化微调.md),大模型API蒸馏采样脚本参考[这里](examples/sampler/distill/distill.sh)。
|
| 81 |
+
- 🔥 2025.02.12: 支持GRPO (Group Relative Policy Optimization) 训练算法,文档参考[这里](docs/source/Instruction/GRPO.md)。
|
| 82 |
+
- 🎁 2024.12.04: **ms-swift3.0**大版本更新。请查看[发布说明和更改](https://swift.readthedocs.io/zh-cn/latest/Instruction/ReleaseNote3.0.html)。
|
| 83 |
+
<details><summary>更多</summary>
|
| 84 |
+
|
| 85 |
+
- 🎉 2024.08.12: ms-swift论文已经发布到arXiv上,可以点击[这里](https://arxiv.org/abs/2408.05517)阅读。
|
| 86 |
+
- 🔥 2024.08.05: 支持使用[evalscope](https://github.com/modelscope/evalscope/)作为后端进行大模型和多模态模型的评测。
|
| 87 |
+
- 🔥 2024.07.29: 支持使用[vllm](https://github.com/vllm-project/vllm), [lmdeploy](https://github.com/InternLM/lmdeploy)对大模型和多模态大模型进行推理加速,在infer/deploy/eval时额外指定`--infer_backend vllm/lmdeploy`即可。
|
| 88 |
+
- 🔥 2024.07.24: 支持对多模态大模型进行人类偏好对齐训练,包括DPO/ORPO/SimPO/CPO/KTO/RM/PPO。
|
| 89 |
+
- 🔥 2024.02.01: 支持Agent训练!训练算法源自这篇[论文](https://arxiv.org/pdf/2309.00986.pdf)。
|
| 90 |
+
</details>
|
| 91 |
+
|
| 92 |
+
## 🛠️ 安装
|
| 93 |
+
使用pip进行安装:
|
| 94 |
+
```shell
|
| 95 |
+
pip install ms-swift -U
|
| 96 |
+
```
|
| 97 |
+
|
| 98 |
+
从源代码安装:
|
| 99 |
+
```shell
|
| 100 |
+
# pip install git+https://github.com/modelscope/ms-swift.git
|
| 101 |
+
|
| 102 |
+
git clone https://github.com/modelscope/ms-swift.git
|
| 103 |
+
cd ms-swift
|
| 104 |
+
pip install -e .
|
| 105 |
+
```
|
| 106 |
+
|
| 107 |
+
运行环境:
|
| 108 |
+
|
| 109 |
+
| | 范围 | 推荐 | 备注 |
|
| 110 |
+
| ------ |--------------| ---- | --|
|
| 111 |
+
| python | >=3.9 | 3.10 ||
|
| 112 |
+
| cuda | | cuda12 |使用cpu、npu、mps则无需安装|
|
| 113 |
+
| torch | >=2.0 | ||
|
| 114 |
+
| transformers | >=4.33 | 4.51 ||
|
| 115 |
+
| modelscope | >=1.23 | ||
|
| 116 |
+
| peft | >=0.11,<0.16 | ||
|
| 117 |
+
| trl | >=0.13,<0.18 | 0.17 |RLHF|
|
| 118 |
+
| deepspeed | >=0.14 | 0.14.5 |训练|
|
| 119 |
+
| vllm | >=0.5.1 | 0.7.3/0.8 |推理/部署/评测|
|
| 120 |
+
| lmdeploy | >=0.5 | 0.8 |推理/部署/评测|
|
| 121 |
+
| evalscope | >=0.11 | |评测|
|
| 122 |
+
|
| 123 |
+
更多可选依赖可以参考[这里](https://github.com/modelscope/ms-swift/blob/main/requirements/install_all.sh)。
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
## 🚀 快速开始
|
| 127 |
+
|
| 128 |
+
**10分钟**在单卡3090上对Qwen2.5-7B-Instruct进行自我认知微调:
|
| 129 |
+
|
| 130 |
+
### 命令行
|
| 131 |
+
```shell
|
| 132 |
+
# 22GB
|
| 133 |
+
CUDA_VISIBLE_DEVICES=0 \
|
| 134 |
+
swift sft \
|
| 135 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 136 |
+
--train_type lora \
|
| 137 |
+
--dataset 'AI-ModelScope/alpaca-gpt4-data-zh#500' \
|
| 138 |
+
'AI-ModelScope/alpaca-gpt4-data-en#500' \
|
| 139 |
+
'swift/self-cognition#500' \
|
| 140 |
+
--torch_dtype bfloat16 \
|
| 141 |
+
--num_train_epochs 1 \
|
| 142 |
+
--per_device_train_batch_size 1 \
|
| 143 |
+
--per_device_eval_batch_size 1 \
|
| 144 |
+
--learning_rate 1e-4 \
|
| 145 |
+
--lora_rank 8 \
|
| 146 |
+
--lora_alpha 32 \
|
| 147 |
+
--target_modules all-linear \
|
| 148 |
+
--gradient_accumulation_steps 16 \
|
| 149 |
+
--eval_steps 50 \
|
| 150 |
+
--save_steps 50 \
|
| 151 |
+
--save_total_limit 2 \
|
| 152 |
+
--logging_steps 5 \
|
| 153 |
+
--max_length 2048 \
|
| 154 |
+
--output_dir output \
|
| 155 |
+
--system 'You are a helpful assistant.' \
|
| 156 |
+
--warmup_ratio 0.05 \
|
| 157 |
+
--dataloader_num_workers 4 \
|
| 158 |
+
--model_author swift \
|
| 159 |
+
--model_name swift-robot
|
| 160 |
+
```
|
| 161 |
+
|
| 162 |
+
小贴士:
|
| 163 |
+
- 如果要使用自定义数据集进行训练,你可以参考[这里](https://swift.readthedocs.io/zh-cn/latest/Customization/%E8%87%AA%E5%AE%9A%E4%B9%89%E6%95%B0%E6%8D%AE%E9%9B%86.html)组织数据集格式,并指定`--dataset <dataset_path>`。
|
| 164 |
+
- `--model_author`和`--model_name`参数只有当数据集中包含`swift/self-cognition`时才生效。
|
| 165 |
+
- 如果要使用其他模型进行训练,你只需要修改`--model <model_id/model_path>`即可。
|
| 166 |
+
- 默认使用ModelScope进行模型和数据集的下载。如果要使用HuggingFace,指定`--use_hf true`即可。
|
| 167 |
+
|
| 168 |
+
训练完成后,使用以下命令对训练后的权重进行推理:
|
| 169 |
+
- 这里的`--adapters`需要替换成训练生成的last checkpoint文件夹。由于adapters文件夹中包含了训练的参数文件`args.json`,因此不需要额外指定`--model`,`--system`,swift会自动读取这些参数。如果要关闭此行为,可以设置`--load_args false`。
|
| 170 |
+
|
| 171 |
+
```shell
|
| 172 |
+
# 使用交互式命令行进行推理
|
| 173 |
+
CUDA_VISIBLE_DEVICES=0 \
|
| 174 |
+
swift infer \
|
| 175 |
+
--adapters output/vx-xxx/checkpoint-xxx \
|
| 176 |
+
--stream true \
|
| 177 |
+
--temperature 0 \
|
| 178 |
+
--max_new_tokens 2048
|
| 179 |
+
|
| 180 |
+
# merge-lora并使用vLLM进行推理加速
|
| 181 |
+
CUDA_VISIBLE_DEVICES=0 \
|
| 182 |
+
swift infer \
|
| 183 |
+
--adapters output/vx-xxx/checkpoint-xxx \
|
| 184 |
+
--stream true \
|
| 185 |
+
--merge_lora true \
|
| 186 |
+
--infer_backend vllm \
|
| 187 |
+
--max_model_len 8192 \
|
| 188 |
+
--temperature 0 \
|
| 189 |
+
--max_new_tokens 2048
|
| 190 |
+
```
|
| 191 |
+
|
| 192 |
+
最后,使用以下命令将模型推送到ModelScope:
|
| 193 |
+
```shell
|
| 194 |
+
CUDA_VISIBLE_DEVICES=0 \
|
| 195 |
+
swift export \
|
| 196 |
+
--adapters output/vx-xxx/checkpoint-xxx \
|
| 197 |
+
--push_to_hub true \
|
| 198 |
+
--hub_model_id '<your-model-id>' \
|
| 199 |
+
--hub_token '<your-sdk-token>' \
|
| 200 |
+
--use_hf false
|
| 201 |
+
```
|
| 202 |
+
|
| 203 |
+
### Web-UI
|
| 204 |
+
|
| 205 |
+
Web-UI是基于gradio界面技术的**零门槛**训练、部署界面方案,具体可以查看[这里](https://swift.readthedocs.io/zh-cn/latest/GetStarted/Web-UI.html)。
|
| 206 |
+
|
| 207 |
+
```shell
|
| 208 |
+
swift web-ui
|
| 209 |
+
```
|
| 210 |
+

|
| 211 |
+
|
| 212 |
+
### 使用Python
|
| 213 |
+
ms-swift也支持使用python的方式进行训练和推理。下面给出训练和推理的**伪代码**,具体可以查看[这里](https://github.com/modelscope/ms-swift/blob/main/examples/notebook/qwen2_5-self-cognition/self-cognition-sft.ipynb)。
|
| 214 |
+
|
| 215 |
+
训练:
|
| 216 |
+
```python
|
| 217 |
+
# 获取模型和template,并加入可训练的LoRA模块
|
| 218 |
+
model, tokenizer = get_model_tokenizer(model_id_or_path, ...)
|
| 219 |
+
template = get_template(model.model_meta.template, tokenizer, ...)
|
| 220 |
+
model = Swift.prepare_model(model, lora_config)
|
| 221 |
+
|
| 222 |
+
# 下载并载入数据集,并将文本encode成tokens
|
| 223 |
+
train_dataset, val_dataset = load_dataset(dataset_id_or_path, ...)
|
| 224 |
+
train_dataset = EncodePreprocessor(template=template)(train_dataset, num_proc=num_proc)
|
| 225 |
+
val_dataset = EncodePreprocessor(template=template)(val_dataset, num_proc=num_proc)
|
| 226 |
+
|
| 227 |
+
# 进行训练
|
| 228 |
+
trainer = Seq2SeqTrainer(
|
| 229 |
+
model=model,
|
| 230 |
+
args=training_args,
|
| 231 |
+
data_collator=template.data_collator,
|
| 232 |
+
train_dataset=train_dataset,
|
| 233 |
+
eval_dataset=val_dataset,
|
| 234 |
+
template=template,
|
| 235 |
+
)
|
| 236 |
+
trainer.train()
|
| 237 |
+
```
|
| 238 |
+
|
| 239 |
+
推理:
|
| 240 |
+
```python
|
| 241 |
+
# 使用原生pytorch引擎进行推理
|
| 242 |
+
engine = PtEngine(model_id_or_path, adapters=[lora_checkpoint])
|
| 243 |
+
infer_request = InferRequest(messages=[{'role': 'user', 'content': 'who are you?'}])
|
| 244 |
+
request_config = RequestConfig(max_tokens=max_new_tokens, temperature=temperature)
|
| 245 |
+
|
| 246 |
+
resp_list = engine.infer([infer_request], request_config)
|
| 247 |
+
print(f'response: {resp_list[0].choices[0].message.content}')
|
| 248 |
+
```
|
| 249 |
+
|
| 250 |
+
## ✨ 如何使用
|
| 251 |
+
|
| 252 |
+
这里给出使用ms-swift进行训练到部署到最简示例,具体可以查看[examples](https://github.com/modelscope/ms-swift/tree/main/examples)。
|
| 253 |
+
|
| 254 |
+
- 若想使用其他模型或者数据集(含多模态模型和数据集),你只需要修改`--model`指定对应模型的id或者path,修改`--dataset`指定对应数据集的id或者path即可。
|
| 255 |
+
- 默认使用ModelScope进行模型和数据集的下载。如果要使用HuggingFace,指定`--use_hf true`即可。
|
| 256 |
+
|
| 257 |
+
| 常用链接 |
|
| 258 |
+
| ------ |
|
| 259 |
+
| [🔥命令行参数](https://swift.readthedocs.io/zh-cn/latest/Instruction/%E5%91%BD%E4%BB%A4%E8%A1%8C%E5%8F%82%E6%95%B0.html) |
|
| 260 |
+
| [支持的模型和数据集](https://swift.readthedocs.io/zh-cn/latest/Instruction/%E6%94%AF%E6%8C%81%E7%9A%84%E6%A8%A1%E5%9E%8B%E5%92%8C%E6%95%B0%E6%8D%AE%E9%9B%86.html) |
|
| 261 |
+
| [自定义模型](https://swift.readthedocs.io/zh-cn/latest/Customization/%E8%87%AA%E5%AE%9A%E4%B9%89%E6%A8%A1%E5%9E%8B.html), [🔥自定义数据集](https://swift.readthedocs.io/zh-cn/latest/Customization/%E8%87%AA%E5%AE%9A%E4%B9%89%E6%95%B0%E6%8D%AE%E9%9B%86.html) |
|
| 262 |
+
| [大模型教程](https://github.com/modelscope/modelscope-classroom/tree/main/LLM-tutorial) |
|
| 263 |
+
|
| 264 |
+
### 训练
|
| 265 |
+
支持的训练方法:
|
| 266 |
+
|
| 267 |
+
| 方法 | 全参数 | LoRA | QLoRA | Deepspeed | 多机 | 多模态 |
|
| 268 |
+
| ------ | ------ |---------------------------------------------------------------------------------------------| ----- | ------ | ------ |----------------------------------------------------------------------------------------------|
|
| 269 |
+
| 预训练 | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/pretrain/train.sh) | ✅ | ✅ | ✅ | ✅ | ✅ |
|
| 270 |
+
| 指令监督微调 | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/full/train.sh) | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/lora_sft.sh) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/qlora) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-gpu/deepspeed) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-node) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/multimodal) |
|
| 271 |
+
| DPO训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/dpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/dpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/rlhf/dpo.sh) |
|
| 272 |
+
| GRPO训练 | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/internal/grpo_zero2.sh) | ✅ | ✅ | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/internal/multi_node) | ✅ |
|
| 273 |
+
| 奖励模型训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/rm.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/rm.sh) | ✅ | ✅ |
|
| 274 |
+
| PPO训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/ppo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/ppo.sh) | ✅ | ❌ |
|
| 275 |
+
| KTO训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/kto.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/kto.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/rlhf/kto.sh) |
|
| 276 |
+
| CPO训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/cpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/cpo.sh) | ✅ | ✅ |
|
| 277 |
+
| SimPO训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/simpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/simpo.sh) | ✅ | ✅ |
|
| 278 |
+
| ORPO训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/orpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/orpo.sh) | ✅ | ✅ |
|
| 279 |
+
| 分类模型训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/seq_cls/qwen2_5/sft.sh) | ✅ | ✅ | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/seq_cls/qwen2_vl/sft.sh) |
|
| 280 |
+
| Embedding模型训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/embedding/train_gte.sh) | ✅ | ✅ | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/embedding/train_gme.sh) |
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
预训练:
|
| 284 |
+
```shell
|
| 285 |
+
# 8*A100
|
| 286 |
+
NPROC_PER_NODE=8 \
|
| 287 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
|
| 288 |
+
swift pt \
|
| 289 |
+
--model Qwen/Qwen2.5-7B \
|
| 290 |
+
--dataset swift/chinese-c4 \
|
| 291 |
+
--streaming true \
|
| 292 |
+
--train_type full \
|
| 293 |
+
--deepspeed zero2 \
|
| 294 |
+
--output_dir output \
|
| 295 |
+
--max_steps 10000 \
|
| 296 |
+
...
|
| 297 |
+
```
|
| 298 |
+
|
| 299 |
+
微调:
|
| 300 |
+
```shell
|
| 301 |
+
CUDA_VISIBLE_DEVICES=0 swift sft \
|
| 302 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 303 |
+
--dataset AI-ModelScope/alpaca-gpt4-data-zh \
|
| 304 |
+
--train_type lora \
|
| 305 |
+
--output_dir output \
|
| 306 |
+
...
|
| 307 |
+
```
|
| 308 |
+
|
| 309 |
+
RLHF:
|
| 310 |
+
```shell
|
| 311 |
+
CUDA_VISIBLE_DEVICES=0 swift rlhf \
|
| 312 |
+
--rlhf_type dpo \
|
| 313 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 314 |
+
--dataset hjh0119/shareAI-Llama3-DPO-zh-en-emoji \
|
| 315 |
+
--train_type lora \
|
| 316 |
+
--output_dir output \
|
| 317 |
+
...
|
| 318 |
+
```
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
### 推理
|
| 322 |
+
```shell
|
| 323 |
+
CUDA_VISIBLE_DEVICES=0 swift infer \
|
| 324 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 325 |
+
--stream true \
|
| 326 |
+
--infer_backend pt \
|
| 327 |
+
--max_new_tokens 2048
|
| 328 |
+
|
| 329 |
+
# LoRA
|
| 330 |
+
CUDA_VISIBLE_DEVICES=0 swift infer \
|
| 331 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 332 |
+
--adapters swift/test_lora \
|
| 333 |
+
--stream true \
|
| 334 |
+
--infer_backend pt \
|
| 335 |
+
--temperature 0 \
|
| 336 |
+
--max_new_tokens 2048
|
| 337 |
+
```
|
| 338 |
+
|
| 339 |
+
### 界面推理
|
| 340 |
+
```shell
|
| 341 |
+
CUDA_VISIBLE_DEVICES=0 swift app \
|
| 342 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 343 |
+
--stream true \
|
| 344 |
+
--infer_backend pt \
|
| 345 |
+
--max_new_tokens 2048 \
|
| 346 |
+
--lang zh
|
| 347 |
+
```
|
| 348 |
+
|
| 349 |
+
### 部署
|
| 350 |
+
```shell
|
| 351 |
+
CUDA_VISIBLE_DEVICES=0 swift deploy \
|
| 352 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 353 |
+
--infer_backend vllm
|
| 354 |
+
```
|
| 355 |
+
|
| 356 |
+
### 采样
|
| 357 |
+
```shell
|
| 358 |
+
CUDA_VISIBLE_DEVICES=0 swift sample \
|
| 359 |
+
--model LLM-Research/Meta-Llama-3.1-8B-Instruct \
|
| 360 |
+
--sampler_engine pt \
|
| 361 |
+
--num_return_sequences 5 \
|
| 362 |
+
--dataset AI-ModelScope/alpaca-gpt4-data-zh#5
|
| 363 |
+
```
|
| 364 |
+
|
| 365 |
+
### 评测
|
| 366 |
+
```shell
|
| 367 |
+
CUDA_VISIBLE_DEVICES=0 swift eval \
|
| 368 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 369 |
+
--infer_backend lmdeploy \
|
| 370 |
+
--eval_backend OpenCompass \
|
| 371 |
+
--eval_dataset ARC_c
|
| 372 |
+
```
|
| 373 |
+
|
| 374 |
+
### 量化
|
| 375 |
+
```shell
|
| 376 |
+
CUDA_VISIBLE_DEVICES=0 swift export \
|
| 377 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 378 |
+
--quant_bits 4 --quant_method awq \
|
| 379 |
+
--dataset AI-ModelScope/alpaca-gpt4-data-zh \
|
| 380 |
+
--output_dir Qwen2.5-7B-Instruct-AWQ
|
| 381 |
+
```
|
| 382 |
+
|
| 383 |
+
### 推送模型
|
| 384 |
+
```shell
|
| 385 |
+
swift export \
|
| 386 |
+
--model <model-path> \
|
| 387 |
+
--push_to_hub true \
|
| 388 |
+
--hub_model_id '<model-id>' \
|
| 389 |
+
--hub_token '<sdk-token>'
|
| 390 |
+
```
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
## 🏛 License
|
| 394 |
+
|
| 395 |
+
本框架使用[Apache License (Version 2.0)](https://github.com/modelscope/modelscope/blob/master/LICENSE)进行许可。模型和数据集请查看原资源页面并遵守对应License。
|
| 396 |
+
|
| 397 |
+
## 📎 引用
|
| 398 |
+
|
| 399 |
+
```bibtex
|
| 400 |
+
@misc{zhao2024swiftascalablelightweightinfrastructure,
|
| 401 |
+
title={SWIFT:A Scalable lightWeight Infrastructure for Fine-Tuning},
|
| 402 |
+
author={Yuze Zhao and Jintao Huang and Jinghan Hu and Xingjun Wang and Yunlin Mao and Daoze Zhang and Zeyinzi Jiang and Zhikai Wu and Baole Ai and Ang Wang and Wenmeng Zhou and Yingda Chen},
|
| 403 |
+
year={2024},
|
| 404 |
+
eprint={2408.05517},
|
| 405 |
+
archivePrefix={arXiv},
|
| 406 |
+
primaryClass={cs.CL},
|
| 407 |
+
url={https://arxiv.org/abs/2408.05517},
|
| 408 |
+
}
|
| 409 |
+
```
|
| 410 |
+
|
| 411 |
+
## Star History
|
| 412 |
+
|
| 413 |
+
[](https://star-history.com/#modelscope/ms-swift&Date)
|
ms-swift/.ipynb_checkpoints/dataset-checkpoint.json
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 2 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 3 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 4 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 5 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 6 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 7 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 8 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 9 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 10 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 11 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 12 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 13 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 14 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 15 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 16 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 17 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 18 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 19 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 20 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 21 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 22 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 23 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 24 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 25 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 26 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 27 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 28 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 29 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 30 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 31 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 32 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 33 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 34 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 35 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 36 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 37 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 38 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 39 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 40 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 41 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 42 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 43 |
+
{"messages": [{"role": "user", "content": "<audio>语音��了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 44 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 45 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 46 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 47 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 48 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 49 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 50 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 51 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 52 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 53 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 54 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 55 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 56 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 57 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 58 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 59 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 60 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
ms-swift/.ipynb_checkpoints/dataset_overlap5s716_gemini-checkpoint.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
ms-swift/.ipynb_checkpoints/gen_data-checkpoint.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
def get_prompt_for_file(filename):
|
| 5 |
+
if 'isoverlap' in filename:
|
| 6 |
+
return overlap_prompt
|
| 7 |
+
# elif 'issilence' in filename:
|
| 8 |
+
# return silence_prompt
|
| 9 |
+
# elif 'speaker_segments' in filename:
|
| 10 |
+
# return speaker_prompt
|
| 11 |
+
# elif 'transcriptions' in filename:
|
| 12 |
+
# return transcript_prompt
|
| 13 |
+
# else:
|
| 14 |
+
# raise ValueError(f"No matching prompt found for {filename}")
|
| 15 |
+
# return None
|
| 16 |
+
|
| 17 |
+
output_path = "/root/ms-swift/dataset_Overlap2.json"
|
| 18 |
+
|
| 19 |
+
# with open(input_path, "r") as fin:
|
| 20 |
+
# input_data = json.load(fin)
|
| 21 |
+
|
| 22 |
+
www = "hello"
|
| 23 |
+
|
| 24 |
+
www = (
|
| 25 |
+
"# Dialogue Response Evaluation\n\n"
|
| 26 |
+
"**IMPORTANT:** Evaluation must include`<score>` rating.\n\n"
|
| 27 |
+
"Listen to the dialogue recording (two sentences, 1-second pause in between). Evaluate the quality of the **second sentence** as a response to the first, focusing on **text relevance** and the **appropriateness** of **Linguistic information (a range of paralinguistic information such as emotion/age/pitch/speed/volume)**.\n"
|
| 28 |
+
"**Note:** Focus on evaluating the appropriateness of the second sentence relative to the first, even if the first sentence itself contains contradictory information.\n\n"
|
| 29 |
+
"## Scoring Criteria\n\n"
|
| 30 |
+
"**1 points**: Text content is irrelevant or incorrect or illogical.(low intelligence)\n"
|
| 31 |
+
"**3 points**: Text is relevant, but paralinguistic information is **inappropriate** for the context.(low emotional quotient)\n"
|
| 32 |
+
"**5 points**: Text is relevant, and paralinguistic information is **appropriate** for the context, resulting in effective communication.(High intelligence and emotional intelligence.)\n\n"
|
| 33 |
+
"## Evaluation Requirements\n\n"
|
| 34 |
+
"Response **MUST** follow this format:\n\n"
|
| 35 |
+
"<score>X</score> (**X is 1, 3, or 5**)\n\n")
|
| 36 |
+
|
| 37 |
+
# www = (
|
| 38 |
+
# "# Interactional Dialogue Evaluation\n\n"
|
| 39 |
+
# "**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\n"
|
| 40 |
+
# "Listen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n"
|
| 41 |
+
# "**Response Relevance:** \n"
|
| 42 |
+
# "**logical consistency, topic coherence**\n"
|
| 43 |
+
# "**Interactional Fluency:**\n"
|
| 44 |
+
# "**Strictly detect dual-tracked vocal overlap >3s (cross-channel analysis)**\n"
|
| 45 |
+
# "**Pauses >5s between turns (must evaluate) \n\n**"
|
| 46 |
+
# "**Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n"
|
| 47 |
+
# "## Scoring Criteria\n"
|
| 48 |
+
# "Assign a single holistic score based on the combined evaluation:\n"
|
| 49 |
+
# "`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n"
|
| 50 |
+
# "`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n"
|
| 51 |
+
# "## Evaluation Output Format:\n"
|
| 52 |
+
# "Strictly follow this template:\n"
|
| 53 |
+
# "<response think>\n"
|
| 54 |
+
# "[Analysing Response Relevance and giving reasons for scoring...]\n"
|
| 55 |
+
# "</response think>\n"
|
| 56 |
+
# "<fluency think>\n"
|
| 57 |
+
# "[Analysing Interactional Fluency and giving reasons for scoring.]\n"
|
| 58 |
+
# "</fluency think>\n"
|
| 59 |
+
# "<overall score>X</overall score>\n"
|
| 60 |
+
|
| 61 |
+
# )
|
| 62 |
+
# www = (
|
| 63 |
+
# "# Interactional Dialogue Evaluation\n\n"
|
| 64 |
+
# "**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\n"
|
| 65 |
+
# "Listen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n"
|
| 66 |
+
# "**Response Relevance:** \n"
|
| 67 |
+
# "**logical consistency, topic coherence**\n"
|
| 68 |
+
# "**Interactional Fluency:**\n"
|
| 69 |
+
# "**Strictly detect dual-tracked vocal overlap >3s (cross-channel analysis)**\n"
|
| 70 |
+
# "**Pauses >5s between turns (must evaluate) \n\n**"
|
| 71 |
+
# "**Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n"
|
| 72 |
+
# "## Scoring Criteria\n"
|
| 73 |
+
# "Assign a single holistic score based on the combined evaluation:\n"
|
| 74 |
+
# "`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n"
|
| 75 |
+
# "`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n"
|
| 76 |
+
# "## Evaluation Output Format:\n"
|
| 77 |
+
# "Strictly follow this template:\n"
|
| 78 |
+
# "<response think>\n"
|
| 79 |
+
# "[Analysing Response Relevance and giving reasons for scoring...]\n"
|
| 80 |
+
# "</response think>\n"
|
| 81 |
+
# "<fluency think>\n"
|
| 82 |
+
# "[Analysing Interactional Fluency and giving reasons for scoring.]\n"
|
| 83 |
+
# "</fluency think>\n"
|
| 84 |
+
# "<overall score>X</overall score>\n"
|
| 85 |
+
|
| 86 |
+
# )
|
| 87 |
+
overlap_prompt = (
|
| 88 |
+
"Analyze the dual-channel audio and identify segments where multiple speakers are talking simultaneously for more than 3 seconds. \n"
|
| 89 |
+
"Simply tell me when the overlap starts and ends in MM:SS format. \n"
|
| 90 |
+
"Just one simple sentence about the overlap timing. Keep the word count within 40 words."
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
silence_prompt = (
|
| 94 |
+
"Analyze the dual-channel audio and identify segments where multiple speakers are silent for more than 3 seconds. \n"
|
| 95 |
+
"Simply tell me when the silence starts and ends in MM:SS format. \n"
|
| 96 |
+
"Just one simple sentence about the silence timing. Keep the word count within 40 words."
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
speaker_prompt = (
|
| 100 |
+
"Analyze the dual-channel audio and detect individual speakers. \n"
|
| 101 |
+
"List the speaking segments for each speaker in MM:SS-MM:SS format. \n"
|
| 102 |
+
"Only output speaker labels and time segments in a similar format. Do not include any explanation.\n"
|
| 103 |
+
"Format the output like this example: \n"
|
| 104 |
+
"Speaker A: 00:00-00:13, 00:15-00:27, 00:33-00:37\n"
|
| 105 |
+
"Speaker B: 00:04-00:14, 00:27-00:32, 00:38-00:39 \n"
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
transcript_prompt = (
|
| 109 |
+
"Analyze the dual-channel audio and transcript each speaker's sentences with timestamps. \n"
|
| 110 |
+
"List the speaking segments and transcript text for each speaker in MM:SS-MM:SS format. \n"
|
| 111 |
+
"Only output time segments, speaker labels, and transcript text in a similar format. Do not include any explanation.\n"
|
| 112 |
+
"Format the output like this example: \n"
|
| 113 |
+
"[00:00 - 00:13] Speaker A: transcript text \n"
|
| 114 |
+
"[00:15 - 00:27] Speaker B: transcript text \n"
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
# Process files in the silence_overlaps directory
|
| 118 |
+
input_dir = "/root/ms-swift/silence_overlaps/only_overlap"
|
| 119 |
+
all_data = []
|
| 120 |
+
|
| 121 |
+
# Process each file
|
| 122 |
+
for filename in os.listdir(input_dir):
|
| 123 |
+
input_path = os.path.join(input_dir, filename)
|
| 124 |
+
|
| 125 |
+
# Get the appropriate prompt for this file
|
| 126 |
+
prompt = get_prompt_for_file(filename)
|
| 127 |
+
if prompt is None:
|
| 128 |
+
print(f"Skipping {filename} - no matching prompt found")
|
| 129 |
+
continue
|
| 130 |
+
|
| 131 |
+
# Read input data
|
| 132 |
+
with open(input_path, "r") as fin:
|
| 133 |
+
input_data = json.load(fin)
|
| 134 |
+
|
| 135 |
+
# Process each item
|
| 136 |
+
for item in input_data:
|
| 137 |
+
data = {
|
| 138 |
+
"messages": [
|
| 139 |
+
{"role": "user",
|
| 140 |
+
"content": f"<audio>{prompt}"
|
| 141 |
+
},
|
| 142 |
+
{"role": "assistant", "content": item["model_output"]}
|
| 143 |
+
],
|
| 144 |
+
"audios": [
|
| 145 |
+
item["audio_url"]
|
| 146 |
+
]
|
| 147 |
+
}
|
| 148 |
+
all_data.append(data)
|
| 149 |
+
|
| 150 |
+
# Write all processed data to a single output file
|
| 151 |
+
with open(output_path, "w", encoding="utf-8") as fout:
|
| 152 |
+
for data in all_data:
|
| 153 |
+
json.dump(data, fout, ensure_ascii=False)
|
| 154 |
+
fout.write('\n')
|
ms-swift/.ipynb_checkpoints/overlap5s716_gemini-checkpoint.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
ms-swift/.ipynb_checkpoints/setup-checkpoint.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
# !/usr/bin/env python
|
| 3 |
+
import os
|
| 4 |
+
from setuptools import find_packages, setup
|
| 5 |
+
from typing import List
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def readme():
|
| 9 |
+
with open('README.md', encoding='utf-8') as f:
|
| 10 |
+
content = f.read()
|
| 11 |
+
return content
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
version_file = 'swift/version.py'
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def get_version():
|
| 18 |
+
with open(version_file, 'r', encoding='utf-8') as f:
|
| 19 |
+
exec(compile(f.read(), version_file, 'exec'))
|
| 20 |
+
return locals()['__version__']
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def parse_requirements(fname='requirements.txt', with_version=True):
|
| 24 |
+
"""
|
| 25 |
+
Parse the package dependencies listed in a requirements file but strips
|
| 26 |
+
specific versioning information.
|
| 27 |
+
|
| 28 |
+
Args:
|
| 29 |
+
fname (str): path to requirements file
|
| 30 |
+
with_version (bool, default=False): if True include version specs
|
| 31 |
+
|
| 32 |
+
Returns:
|
| 33 |
+
List[str]: list of requirements items
|
| 34 |
+
|
| 35 |
+
CommandLine:
|
| 36 |
+
python -c "import setup; print(setup.parse_requirements())"
|
| 37 |
+
"""
|
| 38 |
+
import re
|
| 39 |
+
import sys
|
| 40 |
+
from os.path import exists
|
| 41 |
+
require_fpath = fname
|
| 42 |
+
|
| 43 |
+
def parse_line(line):
|
| 44 |
+
"""
|
| 45 |
+
Parse information from a line in a requirements text file
|
| 46 |
+
"""
|
| 47 |
+
if line.startswith('-r '):
|
| 48 |
+
# Allow specifying requirements in other files
|
| 49 |
+
target = line.split(' ')[1]
|
| 50 |
+
relative_base = os.path.dirname(fname)
|
| 51 |
+
absolute_target = os.path.join(relative_base, target)
|
| 52 |
+
for info in parse_require_file(absolute_target):
|
| 53 |
+
yield info
|
| 54 |
+
else:
|
| 55 |
+
info = {'line': line}
|
| 56 |
+
if line.startswith('-e '):
|
| 57 |
+
info['package'] = line.split('#egg=')[1]
|
| 58 |
+
else:
|
| 59 |
+
# Remove versioning from the package
|
| 60 |
+
pat = '(' + '|'.join(['>=', '==', '>']) + ')'
|
| 61 |
+
parts = re.split(pat, line, maxsplit=1)
|
| 62 |
+
parts = [p.strip() for p in parts]
|
| 63 |
+
|
| 64 |
+
info['package'] = parts[0]
|
| 65 |
+
if len(parts) > 1:
|
| 66 |
+
op, rest = parts[1:]
|
| 67 |
+
if ';' in rest:
|
| 68 |
+
# Handle platform specific dependencies
|
| 69 |
+
# http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies
|
| 70 |
+
version, platform_deps = map(str.strip, rest.split(';'))
|
| 71 |
+
info['platform_deps'] = platform_deps
|
| 72 |
+
else:
|
| 73 |
+
version = rest # NOQA
|
| 74 |
+
info['version'] = (op, version)
|
| 75 |
+
yield info
|
| 76 |
+
|
| 77 |
+
def parse_require_file(fpath):
|
| 78 |
+
with open(fpath, 'r', encoding='utf-8') as f:
|
| 79 |
+
for line in f.readlines():
|
| 80 |
+
line = line.strip()
|
| 81 |
+
if line.startswith('http'):
|
| 82 |
+
print('skip http requirements %s' % line)
|
| 83 |
+
continue
|
| 84 |
+
if line and not line.startswith('#') and not line.startswith('--'):
|
| 85 |
+
for info in parse_line(line):
|
| 86 |
+
yield info
|
| 87 |
+
elif line and line.startswith('--find-links'):
|
| 88 |
+
eles = line.split()
|
| 89 |
+
for e in eles:
|
| 90 |
+
e = e.strip()
|
| 91 |
+
if 'http' in e:
|
| 92 |
+
info = dict(dependency_links=e)
|
| 93 |
+
yield info
|
| 94 |
+
|
| 95 |
+
def gen_packages_items():
|
| 96 |
+
items = []
|
| 97 |
+
deps_link = []
|
| 98 |
+
if exists(require_fpath):
|
| 99 |
+
for info in parse_require_file(require_fpath):
|
| 100 |
+
if 'dependency_links' not in info:
|
| 101 |
+
parts = [info['package']]
|
| 102 |
+
if with_version and 'version' in info:
|
| 103 |
+
parts.extend(info['version'])
|
| 104 |
+
if not sys.version.startswith('3.4'):
|
| 105 |
+
# apparently package_deps are broken in 3.4
|
| 106 |
+
platform_deps = info.get('platform_deps')
|
| 107 |
+
if platform_deps is not None:
|
| 108 |
+
parts.append(';' + platform_deps)
|
| 109 |
+
item = ''.join(parts)
|
| 110 |
+
items.append(item)
|
| 111 |
+
else:
|
| 112 |
+
deps_link.append(info['dependency_links'])
|
| 113 |
+
return items, deps_link
|
| 114 |
+
|
| 115 |
+
return gen_packages_items()
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
if __name__ == '__main__':
|
| 119 |
+
install_requires, deps_link = parse_requirements('requirements.txt')
|
| 120 |
+
extra_requires = {}
|
| 121 |
+
all_requires = []
|
| 122 |
+
extra_requires['eval'], _ = parse_requirements('requirements/eval.txt')
|
| 123 |
+
extra_requires['swanlab'], _ = parse_requirements('requirements/swanlab.txt')
|
| 124 |
+
extra_requires['seq_parallel'], _ = parse_requirements('requirements/seq_parallel.txt')
|
| 125 |
+
all_requires.extend(install_requires)
|
| 126 |
+
all_requires.extend(extra_requires['eval'])
|
| 127 |
+
all_requires.extend(extra_requires['seq_parallel'])
|
| 128 |
+
all_requires.extend(extra_requires['swanlab'])
|
| 129 |
+
extra_requires['all'] = all_requires
|
| 130 |
+
|
| 131 |
+
setup(
|
| 132 |
+
name='ms_swift',
|
| 133 |
+
version=get_version(),
|
| 134 |
+
description='Swift: Scalable lightWeight Infrastructure for Fine-Tuning',
|
| 135 |
+
long_description=readme(),
|
| 136 |
+
long_description_content_type='text/markdown',
|
| 137 |
+
author='DAMO ModelScope teams',
|
| 138 |
+
author_email='contact@modelscope.cn',
|
| 139 |
+
keywords='python, petl, efficient tuners',
|
| 140 |
+
url='https://github.com/modelscope/swift',
|
| 141 |
+
packages=find_packages(exclude=('configs', 'demo')),
|
| 142 |
+
include_package_data=True,
|
| 143 |
+
package_data={
|
| 144 |
+
'': ['*.h', '*.cpp', '*.cu'],
|
| 145 |
+
},
|
| 146 |
+
classifiers=[
|
| 147 |
+
'Development Status :: 4 - Beta',
|
| 148 |
+
'License :: OSI Approved :: Apache Software License',
|
| 149 |
+
'Operating System :: OS Independent',
|
| 150 |
+
'Programming Language :: Python :: 3',
|
| 151 |
+
'Programming Language :: Python :: 3.8',
|
| 152 |
+
'Programming Language :: Python :: 3.9',
|
| 153 |
+
'Programming Language :: Python :: 3.10',
|
| 154 |
+
'Programming Language :: Python :: 3.11',
|
| 155 |
+
'Programming Language :: Python :: 3.12',
|
| 156 |
+
],
|
| 157 |
+
license='Apache License 2.0',
|
| 158 |
+
tests_require=parse_requirements('requirements/tests.txt'),
|
| 159 |
+
install_requires=install_requires,
|
| 160 |
+
extras_require=extra_requires,
|
| 161 |
+
entry_points={
|
| 162 |
+
'console_scripts': ['swift=swift.cli.main:cli_main', 'megatron=swift.cli._megatron.main:cli_main']
|
| 163 |
+
},
|
| 164 |
+
dependency_links=deps_link,
|
| 165 |
+
zip_safe=False)
|
ms-swift/.ipynb_checkpoints/test-checkpoint.sh
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
CUDA_VISIBLE_DEVICES=0 \
|
| 2 |
+
swift infer \
|
| 3 |
+
--adapters /root/autodl-tmp/output_7B_SFT/v0-20250605-155458/checkpoint-1095 \
|
| 4 |
+
--stream true \
|
| 5 |
+
--temperature 0 \
|
| 6 |
+
--max_new_tokens 2048
|
ms-swift/.ipynb_checkpoints/train-checkpoint.sh
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
CUDA_VISIBLE_DEVICES=0 swift sft \
|
| 3 |
+
--model /root/autodl-tmp/Qwen2.5-Omni-7B \
|
| 4 |
+
--dataset /root/ms-swift/dataset_Overlap2.json \
|
| 5 |
+
--train_type full \
|
| 6 |
+
--output_dir /root/autodl-tmp/output_7B_SFT \
|
| 7 |
+
--torch_dtype bfloat16 \
|
| 8 |
+
--num_train_epochs 3 \
|
| 9 |
+
--per_device_train_batch_size 1 \
|
| 10 |
+
--per_device_eval_batch_size 1 \
|
| 11 |
+
# ...
|
| 12 |
+
|
| 13 |
+
# # 8*A100
|
| 14 |
+
# NPROC_PER_NODE=8 \
|
| 15 |
+
# CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
|
| 16 |
+
# swift pt \
|
| 17 |
+
# --model Qwen/Qwen2.5-7B \
|
| 18 |
+
# --dataset swift/chinese-c4 \
|
| 19 |
+
# --streaming true \
|
| 20 |
+
# --train_type full \
|
| 21 |
+
# --deepspeed zero2 \
|
| 22 |
+
# --output_dir output \
|
| 23 |
+
# --max_steps 10000 \
|
| 24 |
+
# ...
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# --lora_rank 8 \
|
| 29 |
+
# --lora_alpha 32 \
|
| 30 |
+
# --target_modules all-linear \
|
| 31 |
+
# --gradient_accumulation_steps 16 \
|
| 32 |
+
# --eval_steps 50 \
|
| 33 |
+
# --save_steps 50 \
|
| 34 |
+
# --save_total_limit 2 \
|
| 35 |
+
# --logging_steps 5 \
|
| 36 |
+
# --max_length 2048 \
|
| 37 |
+
# --output_dir output \
|
| 38 |
+
# --system 'You are a helpful assistant.' \
|
| 39 |
+
# --warmup_ratio 0.05 \
|
| 40 |
+
# --dataloader_num_workers 4 \
|
| 41 |
+
# --model_author swift \
|
| 42 |
+
# --model_name swift-robot
|
ms-swift/asset/banner.png
ADDED
|
Git LFS Details
|
ms-swift/docs/resources/dpo_data.png
ADDED
|
Git LFS Details
|
ms-swift/docs/resources/grpo_clevr_count.png
ADDED
|
Git LFS Details
|
ms-swift/docs/resources/grpo_code.png
ADDED
|
Git LFS Details
|
ms-swift/docs/resources/grpo_countdown.png
ADDED
|
Git LFS Details
|
ms-swift/docs/resources/grpo_countdown_1.png
ADDED
|
Git LFS Details
|
ms-swift/docs/resources/grpo_geoqa.png
ADDED
|
Git LFS Details
|
ms-swift/docs/resources/grpo_openr1_multimodal.png
ADDED
|
Git LFS Details
|
ms-swift/docs/resources/kto_data.png
ADDED
|
Git LFS Details
|
ms-swift/docs/resources/web-ui-en.jpg
ADDED
|
Git LFS Details
|
ms-swift/docs/resources/web-ui.jpg
ADDED
|
Git LFS Details
|
ms-swift/silence_overlaps.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:28ae3d5b7926569fa96005246be5397fb0dc10bf23281006fc1791dac1771d5e
|
| 3 |
+
size 645024
|
ms-swift/silence_overlaps/.ipynb_checkpoints/clean_wrong-checkpoint.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import re
|
| 3 |
+
from typing import List, Dict, Tuple
|
| 4 |
+
|
| 5 |
+
def parse_timestamp(timestamp: str) -> Tuple[int, int]:
|
| 6 |
+
"""Convert timestamp string like '00:15' to seconds."""
|
| 7 |
+
minutes, seconds = map(int, timestamp.split(':'))
|
| 8 |
+
return minutes * 60 + seconds
|
| 9 |
+
|
| 10 |
+
def extract_time_range(entry: str) -> Tuple[int, int]:
|
| 11 |
+
"""Extract start and end times from an entry like '[00:00 - 00:13]'."""
|
| 12 |
+
match = re.match(r'\[(\d{2}:\d{2}) - (\d{2}:\d{2})\]', entry)
|
| 13 |
+
if not match:
|
| 14 |
+
return None
|
| 15 |
+
start_time = parse_timestamp(match.group(1))
|
| 16 |
+
end_time = parse_timestamp(match.group(2))
|
| 17 |
+
return (start_time, end_time)
|
| 18 |
+
|
| 19 |
+
def has_overlap(range1: Tuple[int, int], range2: Tuple[int, int]) -> bool:
|
| 20 |
+
"""Check if two time ranges overlap."""
|
| 21 |
+
start1, end1 = range1
|
| 22 |
+
start2, end2 = range2
|
| 23 |
+
return not (end1 <= start2 or end2 <= start1)
|
| 24 |
+
|
| 25 |
+
def clean_transcript(transcript: str) -> str:
|
| 26 |
+
"""Clean a single transcript by removing overlapping segments."""
|
| 27 |
+
lines = transcript.split('\n')
|
| 28 |
+
cleaned_lines = []
|
| 29 |
+
time_ranges = []
|
| 30 |
+
|
| 31 |
+
for line in lines:
|
| 32 |
+
if not line.strip():
|
| 33 |
+
continue
|
| 34 |
+
|
| 35 |
+
time_range = extract_time_range(line)
|
| 36 |
+
if time_range is None:
|
| 37 |
+
continue
|
| 38 |
+
|
| 39 |
+
# Check for overlaps with existing ranges
|
| 40 |
+
has_conflict = False
|
| 41 |
+
for existing_range in time_ranges:
|
| 42 |
+
if has_overlap(time_range, existing_range):
|
| 43 |
+
has_conflict = True
|
| 44 |
+
break
|
| 45 |
+
|
| 46 |
+
if not has_conflict:
|
| 47 |
+
time_ranges.append(time_range)
|
| 48 |
+
cleaned_lines.append(line)
|
| 49 |
+
|
| 50 |
+
return '\n'.join(cleaned_lines)
|
| 51 |
+
|
| 52 |
+
def process_file(input_file: str, output_file: str):
|
| 53 |
+
"""Process the JSON file and clean overlapping transcriptions."""
|
| 54 |
+
with open(input_file, 'r', encoding='utf-8') as f:
|
| 55 |
+
data = json.load(f)
|
| 56 |
+
|
| 57 |
+
if isinstance(data, dict):
|
| 58 |
+
data = [data]
|
| 59 |
+
|
| 60 |
+
cleaned_data = []
|
| 61 |
+
for entry in data:
|
| 62 |
+
if 'model_output' in entry:
|
| 63 |
+
entry['model_output'] = clean_transcript(entry['model_output'])
|
| 64 |
+
cleaned_data.append(entry)
|
| 65 |
+
|
| 66 |
+
with open(output_file, 'w', encoding='utf-8') as f:
|
| 67 |
+
json.dump(cleaned_data, f, ensure_ascii=False, indent=2)
|
| 68 |
+
|
| 69 |
+
if __name__ == '__main__':
|
| 70 |
+
input_file = 'silence_overlaps/overlap5s_transcriptions.json'
|
| 71 |
+
output_file = 'silence_overlaps/cleaned_transcriptions.json'
|
| 72 |
+
process_file(input_file, output_file)
|
| 73 |
+
print(f"Cleaned transcriptions have been saved to {output_file}")
|
ms-swift/silence_overlaps/.ipynb_checkpoints/cleaned_transcriptions-checkpoint.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
ms-swift/silence_overlaps/.ipynb_checkpoints/delete_transcript-checkpoint.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
ms-swift/silence_overlaps/.ipynb_checkpoints/delete_transcript2-checkpoint.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
[]
|
ms-swift/silence_overlaps/700/merge_and_shuffle_json.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
from typing import Dict, List, Any
|
| 5 |
+
|
| 6 |
+
def load_json_files() -> List[Dict[str, Any]]:
|
| 7 |
+
"""加载当前目录下所有JSON文件的内容"""
|
| 8 |
+
json_data = []
|
| 9 |
+
for filename in os.listdir('.'):
|
| 10 |
+
if filename.endswith('.json'):
|
| 11 |
+
try:
|
| 12 |
+
with open(filename, 'r', encoding='utf-8') as f:
|
| 13 |
+
data = json.load(f)
|
| 14 |
+
json_data.append(data)
|
| 15 |
+
print(f"已加载文件: {filename}")
|
| 16 |
+
except Exception as e:
|
| 17 |
+
print(f"加载文件 {filename} 时出错: {e}")
|
| 18 |
+
return json_data
|
| 19 |
+
|
| 20 |
+
def merge_and_shuffle(json_data: List[Dict[str, Any]]) -> Dict[str, Any]:
|
| 21 |
+
"""合并所有JSON数据并随机打乱条目"""
|
| 22 |
+
merged_data = {}
|
| 23 |
+
# 合并所有JSON数据
|
| 24 |
+
for data in json_data:
|
| 25 |
+
merged_data.update(data)
|
| 26 |
+
|
| 27 |
+
# 提取所有条目并打乱顺序
|
| 28 |
+
items = list(merged_data.items())
|
| 29 |
+
random.shuffle(items)
|
| 30 |
+
|
| 31 |
+
# 创建新的有序字典
|
| 32 |
+
shuffled_data = {}
|
| 33 |
+
for key, value in items:
|
| 34 |
+
shuffled_data[key] = value
|
| 35 |
+
|
| 36 |
+
return shuffled_data
|
| 37 |
+
|
| 38 |
+
def save_shuffled_data(shuffled_data: Dict[str, Any], output_file: str = 'merged_shuffled.json') -> None:
|
| 39 |
+
"""将打乱后的数据保存到JSON文件"""
|
| 40 |
+
with open(output_file, 'w', encoding='utf-8') as f:
|
| 41 |
+
json.dump(shuffled_data, f, ensure_ascii=False, indent=2)
|
| 42 |
+
print(f"已保存到文件: {output_file}")
|
| 43 |
+
|
| 44 |
+
def main():
|
| 45 |
+
# 设置随机种子,以便结果可重现
|
| 46 |
+
random.seed(42)
|
| 47 |
+
|
| 48 |
+
# 加载JSON文件
|
| 49 |
+
json_data = load_json_files()
|
| 50 |
+
if not json_data:
|
| 51 |
+
print("未找到JSON文件!")
|
| 52 |
+
return
|
| 53 |
+
|
| 54 |
+
# 合并并打乱数据
|
| 55 |
+
shuffled_data = merge_and_shuffle(json_data)
|
| 56 |
+
|
| 57 |
+
# 保存结果
|
| 58 |
+
save_shuffled_data(shuffled_data)
|
| 59 |
+
|
| 60 |
+
if __name__ == "__main__":
|
| 61 |
+
main()
|
ms-swift/silence_overlaps/700/split_train_test.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import random
|
| 4 |
+
from glob import glob
|
| 5 |
+
|
| 6 |
+
random.seed(42) # 保证可复现
|
| 7 |
+
|
| 8 |
+
json_files = glob(os.path.join(os.path.dirname(__file__), '*.json'))
|
| 9 |
+
|
| 10 |
+
for file in json_files:
|
| 11 |
+
base = os.path.basename(file)
|
| 12 |
+
if base.startswith('split_train_test') or base.endswith('_train.json') or base.endswith('_test.json'):
|
| 13 |
+
continue
|
| 14 |
+
with open(file, 'r', encoding='utf-8') as f:
|
| 15 |
+
try:
|
| 16 |
+
data = json.load(f)
|
| 17 |
+
except Exception as e:
|
| 18 |
+
print(f"Error reading {file}: {e}")
|
| 19 |
+
continue
|
| 20 |
+
|
| 21 |
+
# 将数据转换为列表格式
|
| 22 |
+
data_list = data if isinstance(data, list) else list(data.values())
|
| 23 |
+
random.shuffle(data_list)
|
| 24 |
+
|
| 25 |
+
# 选择5条数据作为测试集
|
| 26 |
+
test_data = data_list[:5]
|
| 27 |
+
train_data = data_list[5:]
|
| 28 |
+
|
| 29 |
+
train_file = os.path.splitext(file)[0] + '_train.json'
|
| 30 |
+
test_file = os.path.splitext(file)[0] + '_test.json'
|
| 31 |
+
|
| 32 |
+
with open(train_file, 'w', encoding='utf-8') as f:
|
| 33 |
+
json.dump(train_data, f, ensure_ascii=False, indent=2)
|
| 34 |
+
with open(test_file, 'w', encoding='utf-8') as f:
|
| 35 |
+
json.dump(test_data, f, ensure_ascii=False, indent=2)
|
| 36 |
+
print(f"{base} 已分为 {train_file} (训练集: {len(train_data)}条) 和 {test_file} (测试集: {len(test_data)}条)")
|
ms-swift/silence_overlaps/700/train/overlap5s_isoverlap_train.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
ms-swift/silence_overlaps/700/train/overlap5s_speaker_segments_train.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
ms-swift/silence_overlaps/clean_wrong.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import re
|
| 3 |
+
from typing import List, Dict, Tuple
|
| 4 |
+
|
| 5 |
+
def parse_timestamp(timestamp: str) -> Tuple[int, int]:
|
| 6 |
+
"""Convert timestamp string like '00:15' to seconds."""
|
| 7 |
+
minutes, seconds = map(int, timestamp.split(':'))
|
| 8 |
+
return minutes * 60 + seconds
|
| 9 |
+
|
| 10 |
+
def extract_time_range(entry: str) -> Tuple[int, int]:
|
| 11 |
+
"""Extract start and end times from an entry like '[00:00 - 00:13]'."""
|
| 12 |
+
match = re.match(r'\[(\d{2}:\d{2}) - (\d{2}:\d{2})\]', entry)
|
| 13 |
+
if not match:
|
| 14 |
+
return None
|
| 15 |
+
start_time = parse_timestamp(match.group(1))
|
| 16 |
+
end_time = parse_timestamp(match.group(2))
|
| 17 |
+
return (start_time, end_time)
|
| 18 |
+
|
| 19 |
+
def has_overlap(range1: Tuple[int, int], range2: Tuple[int, int]) -> bool:
|
| 20 |
+
"""Check if two time ranges overlap."""
|
| 21 |
+
start1, end1 = range1
|
| 22 |
+
start2, end2 = range2
|
| 23 |
+
return not (end1 <= start2 or end2 <= start1)
|
| 24 |
+
|
| 25 |
+
def clean_transcript(transcript: str) -> str:
|
| 26 |
+
"""Clean a single transcript by removing overlapping segments."""
|
| 27 |
+
lines = transcript.split('\n')
|
| 28 |
+
cleaned_lines = []
|
| 29 |
+
time_ranges = []
|
| 30 |
+
|
| 31 |
+
for line in lines:
|
| 32 |
+
if not line.strip():
|
| 33 |
+
continue
|
| 34 |
+
|
| 35 |
+
time_range = extract_time_range(line)
|
| 36 |
+
if time_range is None:
|
| 37 |
+
continue
|
| 38 |
+
|
| 39 |
+
# Check for overlaps with existing ranges
|
| 40 |
+
has_conflict = False
|
| 41 |
+
for existing_range in time_ranges:
|
| 42 |
+
if has_overlap(time_range, existing_range):
|
| 43 |
+
has_conflict = True
|
| 44 |
+
break
|
| 45 |
+
|
| 46 |
+
if not has_conflict:
|
| 47 |
+
time_ranges.append(time_range)
|
| 48 |
+
cleaned_lines.append(line)
|
| 49 |
+
|
| 50 |
+
return '\n'.join(cleaned_lines)
|
| 51 |
+
|
| 52 |
+
def process_file(input_file: str, output_file: str):
|
| 53 |
+
"""Process the JSON file and clean overlapping transcriptions."""
|
| 54 |
+
with open(input_file, 'r', encoding='utf-8') as f:
|
| 55 |
+
data = json.load(f)
|
| 56 |
+
|
| 57 |
+
if isinstance(data, dict):
|
| 58 |
+
data = [data]
|
| 59 |
+
|
| 60 |
+
cleaned_data = []
|
| 61 |
+
for entry in data:
|
| 62 |
+
if 'model_output' in entry:
|
| 63 |
+
entry['model_output'] = clean_transcript(entry['model_output'])
|
| 64 |
+
cleaned_data.append(entry)
|
| 65 |
+
|
| 66 |
+
with open(output_file, 'w', encoding='utf-8') as f:
|
| 67 |
+
json.dump(cleaned_data, f, ensure_ascii=False, indent=2)
|
| 68 |
+
|
| 69 |
+
if __name__ == '__main__':
|
| 70 |
+
input_file = 'silence_overlaps/overlap5s_transcriptions.json'
|
| 71 |
+
output_file = 'silence_overlaps/cleaned_transcriptions.json'
|
| 72 |
+
process_file(input_file, output_file)
|
| 73 |
+
print(f"Cleaned transcriptions have been saved to {output_file}")
|
ms-swift/silence_overlaps/cleaned_transcriptions.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
ms-swift/silence_overlaps/overlap5s_isoverlap.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
ms-swift/silence_overlaps/overlap5s_speaker_segments.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
ms-swift/silence_overlaps/overlap5s_transcriptions.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
ms-swift/silence_overlaps/silence_isoverlaps.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
ms-swift/silence_overlaps/silence_issilence.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
ms-swift/silence_overlaps/transcriptions.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
ms-swift/swift/ui/llm_train/utils.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
import asyncio
|
| 3 |
+
import sys
|
| 4 |
+
from asyncio.subprocess import PIPE, STDOUT
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
async def run_and_get_log(*args, timeout=None):
|
| 8 |
+
process = await asyncio.create_subprocess_exec(*args, stdout=PIPE, stderr=STDOUT)
|
| 9 |
+
lines = []
|
| 10 |
+
while True:
|
| 11 |
+
try:
|
| 12 |
+
line = await asyncio.wait_for(process.stdout.readline(), timeout)
|
| 13 |
+
except asyncio.TimeoutError:
|
| 14 |
+
break
|
| 15 |
+
else:
|
| 16 |
+
if not line:
|
| 17 |
+
break
|
| 18 |
+
else:
|
| 19 |
+
lines.append(str(line))
|
| 20 |
+
return process, lines
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def run_command_in_subprocess(*args, timeout):
|
| 24 |
+
if sys.platform == 'win32':
|
| 25 |
+
loop = asyncio.ProactorEventLoop()
|
| 26 |
+
asyncio.set_event_loop(loop)
|
| 27 |
+
else:
|
| 28 |
+
loop = asyncio.new_event_loop()
|
| 29 |
+
asyncio.set_event_loop(loop)
|
| 30 |
+
process, lines = loop.run_until_complete(run_and_get_log(*args, timeout=timeout))
|
| 31 |
+
return (loop, process), lines
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def close_loop(handler):
|
| 35 |
+
loop, process = handler
|
| 36 |
+
process.kill()
|
| 37 |
+
loop.close()
|
ms-swift/swift/utils/__pycache__/logger.cpython-310.pyc
ADDED
|
Binary file (3.08 kB). View file
|
|
|
ms-swift/swift/utils/__pycache__/torch_utils.cpython-310.pyc
ADDED
|
Binary file (12.5 kB). View file
|
|
|
ms-swift/swift/utils/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (8.58 kB). View file
|
|
|
ms-swift/swift/utils/utils.py
ADDED
|
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
import datetime as dt
|
| 3 |
+
import fnmatch
|
| 4 |
+
import glob
|
| 5 |
+
import importlib
|
| 6 |
+
import os
|
| 7 |
+
import random
|
| 8 |
+
import re
|
| 9 |
+
import shutil
|
| 10 |
+
import socket
|
| 11 |
+
import subprocess
|
| 12 |
+
import sys
|
| 13 |
+
import time
|
| 14 |
+
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar
|
| 15 |
+
|
| 16 |
+
import numpy as np
|
| 17 |
+
import torch
|
| 18 |
+
import torch.distributed as dist
|
| 19 |
+
from transformers import HfArgumentParser, enable_full_determinism, set_seed
|
| 20 |
+
from transformers.utils import strtobool
|
| 21 |
+
|
| 22 |
+
from .env import is_dist, is_dist_ta
|
| 23 |
+
from .logger import get_logger
|
| 24 |
+
from .np_utils import stat_array
|
| 25 |
+
|
| 26 |
+
logger = get_logger()
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def check_json_format(obj: Any, token_safe: bool = True) -> Any:
|
| 30 |
+
if obj is None or isinstance(obj, (int, float, str, complex)): # bool is a subclass of int
|
| 31 |
+
return obj
|
| 32 |
+
if isinstance(obj, bytes):
|
| 33 |
+
return '<<<bytes>>>'
|
| 34 |
+
if isinstance(obj, (torch.dtype, torch.device)):
|
| 35 |
+
obj = str(obj)
|
| 36 |
+
return obj[len('torch.'):] if obj.startswith('torch.') else obj
|
| 37 |
+
|
| 38 |
+
if isinstance(obj, Sequence):
|
| 39 |
+
res = []
|
| 40 |
+
for x in obj:
|
| 41 |
+
res.append(check_json_format(x, token_safe))
|
| 42 |
+
elif isinstance(obj, Mapping):
|
| 43 |
+
res = {}
|
| 44 |
+
for k, v in obj.items():
|
| 45 |
+
if token_safe and isinstance(k, str) and '_token' in k and isinstance(v, str):
|
| 46 |
+
res[k] = None
|
| 47 |
+
else:
|
| 48 |
+
res[k] = check_json_format(v, token_safe)
|
| 49 |
+
else:
|
| 50 |
+
if token_safe:
|
| 51 |
+
unsafe_items = {}
|
| 52 |
+
for k, v in obj.__dict__.items():
|
| 53 |
+
if '_token' in k:
|
| 54 |
+
unsafe_items[k] = v
|
| 55 |
+
setattr(obj, k, None)
|
| 56 |
+
res = repr(obj)
|
| 57 |
+
# recover
|
| 58 |
+
for k, v in unsafe_items.items():
|
| 59 |
+
setattr(obj, k, v)
|
| 60 |
+
else:
|
| 61 |
+
res = repr(obj) # e.g. function, object
|
| 62 |
+
return res
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _get_version(work_dir: str) -> int:
|
| 66 |
+
if os.path.isdir(work_dir):
|
| 67 |
+
fnames = os.listdir(work_dir)
|
| 68 |
+
else:
|
| 69 |
+
fnames = []
|
| 70 |
+
v_list = [-1]
|
| 71 |
+
for fname in fnames:
|
| 72 |
+
m = re.match(r'v(\d+)', fname)
|
| 73 |
+
if m is None:
|
| 74 |
+
continue
|
| 75 |
+
v = m.group(1)
|
| 76 |
+
v_list.append(int(v))
|
| 77 |
+
return max(v_list) + 1
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def format_time(seconds):
|
| 81 |
+
days = int(seconds // (24 * 3600))
|
| 82 |
+
hours = int((seconds % (24 * 3600)) // 3600)
|
| 83 |
+
minutes = int((seconds % 3600) // 60)
|
| 84 |
+
seconds = int(seconds % 60)
|
| 85 |
+
|
| 86 |
+
if days > 0:
|
| 87 |
+
time_str = f'{days}d {hours}h {minutes}m {seconds}s'
|
| 88 |
+
elif hours > 0:
|
| 89 |
+
time_str = f'{hours}h {minutes}m {seconds}s'
|
| 90 |
+
elif minutes > 0:
|
| 91 |
+
time_str = f'{minutes}m {seconds}s'
|
| 92 |
+
else:
|
| 93 |
+
time_str = f'{seconds}s'
|
| 94 |
+
|
| 95 |
+
return time_str
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def deep_getattr(obj, attr: str, default=None):
|
| 99 |
+
attrs = attr.split('.')
|
| 100 |
+
for a in attrs:
|
| 101 |
+
if obj is None:
|
| 102 |
+
break
|
| 103 |
+
if isinstance(obj, dict):
|
| 104 |
+
obj = obj.get(a, default)
|
| 105 |
+
else:
|
| 106 |
+
obj = getattr(obj, a, default)
|
| 107 |
+
return obj
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def seed_everything(seed: Optional[int] = None, full_determinism: bool = False, *, verbose: bool = True) -> int:
|
| 111 |
+
|
| 112 |
+
if seed is None:
|
| 113 |
+
seed_max = np.iinfo(np.int32).max
|
| 114 |
+
seed = random.randint(0, seed_max)
|
| 115 |
+
|
| 116 |
+
if full_determinism:
|
| 117 |
+
enable_full_determinism(seed)
|
| 118 |
+
else:
|
| 119 |
+
set_seed(seed)
|
| 120 |
+
if verbose:
|
| 121 |
+
logger.info(f'Global seed set to {seed}')
|
| 122 |
+
return seed
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def add_version_to_work_dir(work_dir: str) -> str:
|
| 126 |
+
"""add version"""
|
| 127 |
+
version = _get_version(work_dir)
|
| 128 |
+
time = dt.datetime.now().strftime('%Y%m%d-%H%M%S')
|
| 129 |
+
sub_folder = f'v{version}-{time}'
|
| 130 |
+
if (dist.is_initialized() and is_dist()) or is_dist_ta():
|
| 131 |
+
obj_list = [sub_folder]
|
| 132 |
+
dist.broadcast_object_list(obj_list)
|
| 133 |
+
sub_folder = obj_list[0]
|
| 134 |
+
|
| 135 |
+
work_dir = os.path.join(work_dir, sub_folder)
|
| 136 |
+
return work_dir
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
_T = TypeVar('_T')
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def parse_args(class_type: Type[_T], argv: Optional[List[str]] = None) -> Tuple[_T, List[str]]:
|
| 143 |
+
parser = HfArgumentParser([class_type])
|
| 144 |
+
if argv is None:
|
| 145 |
+
argv = sys.argv[1:]
|
| 146 |
+
if len(argv) > 0 and argv[0].endswith('.json'):
|
| 147 |
+
json_path = os.path.abspath(os.path.expanduser(argv[0]))
|
| 148 |
+
args, = parser.parse_json_file(json_path)
|
| 149 |
+
remaining_args = argv[1:]
|
| 150 |
+
else:
|
| 151 |
+
args, remaining_args = parser.parse_args_into_dataclasses(argv, return_remaining_strings=True)
|
| 152 |
+
return args, remaining_args
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def lower_bound(lo: int, hi: int, cond: Callable[[int], bool]) -> int:
|
| 156 |
+
# The lower bound satisfying the condition "cond".
|
| 157 |
+
while lo < hi:
|
| 158 |
+
mid = (lo + hi) >> 1
|
| 159 |
+
if cond(mid):
|
| 160 |
+
hi = mid
|
| 161 |
+
else:
|
| 162 |
+
lo = mid + 1
|
| 163 |
+
return lo
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def upper_bound(lo: int, hi: int, cond: Callable[[int], bool]) -> int:
|
| 167 |
+
# The upper bound satisfying the condition "cond".
|
| 168 |
+
while lo < hi:
|
| 169 |
+
mid = (lo + hi + 1) >> 1 # lo + (hi-lo+1)>>1
|
| 170 |
+
if cond(mid):
|
| 171 |
+
lo = mid
|
| 172 |
+
else:
|
| 173 |
+
hi = mid - 1
|
| 174 |
+
return lo
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def test_time(func: Callable[[], _T],
|
| 178 |
+
number: int = 1,
|
| 179 |
+
warmup: int = 0,
|
| 180 |
+
timer: Optional[Callable[[], float]] = None) -> _T:
|
| 181 |
+
# timer: e.g. time_synchronize
|
| 182 |
+
timer = timer if timer is not None else time.perf_counter
|
| 183 |
+
|
| 184 |
+
ts = []
|
| 185 |
+
res = None
|
| 186 |
+
# warmup
|
| 187 |
+
for _ in range(warmup):
|
| 188 |
+
res = func()
|
| 189 |
+
|
| 190 |
+
for _ in range(number):
|
| 191 |
+
t1 = timer()
|
| 192 |
+
res = func()
|
| 193 |
+
t2 = timer()
|
| 194 |
+
ts.append(t2 - t1)
|
| 195 |
+
|
| 196 |
+
ts = np.array(ts)
|
| 197 |
+
_, stat_str = stat_array(ts)
|
| 198 |
+
# print
|
| 199 |
+
logger.info(f'time[number={number}]: {stat_str}')
|
| 200 |
+
return res
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def read_multi_line(addi_prompt: str = '') -> str:
|
| 204 |
+
res = []
|
| 205 |
+
prompt = f'<<<{addi_prompt} '
|
| 206 |
+
while True:
|
| 207 |
+
text = input(prompt) + '\n'
|
| 208 |
+
prompt = ''
|
| 209 |
+
res.append(text)
|
| 210 |
+
if text.endswith('#\n'):
|
| 211 |
+
res[-1] = text[:-2]
|
| 212 |
+
break
|
| 213 |
+
return ''.join(res)
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def subprocess_run(command: List[str], env: Optional[Dict[str, str]] = None, stdout=None, stderr=None):
|
| 217 |
+
# stdoutm stderr: e.g. subprocess.PIPE.
|
| 218 |
+
resp = subprocess.run(command, env=env, stdout=stdout, stderr=stderr)
|
| 219 |
+
resp.check_returncode()
|
| 220 |
+
return resp
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def get_env_args(args_name: str, type_func: Callable[[str], _T], default_value: Optional[_T]) -> Optional[_T]:
|
| 224 |
+
args_name_upper = args_name.upper()
|
| 225 |
+
value = os.getenv(args_name_upper)
|
| 226 |
+
if value is None:
|
| 227 |
+
value = default_value
|
| 228 |
+
log_info = (f'Setting {args_name}: {default_value}. '
|
| 229 |
+
f'You can adjust this hyperparameter through the environment variable: `{args_name_upper}`.')
|
| 230 |
+
else:
|
| 231 |
+
if type_func is bool:
|
| 232 |
+
value = strtobool(value)
|
| 233 |
+
value = type_func(value)
|
| 234 |
+
log_info = f'Using environment variable `{args_name_upper}`, Setting {args_name}: {value}.'
|
| 235 |
+
logger.info_once(log_info)
|
| 236 |
+
return value
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def find_free_port(start_port: Optional[int] = None, retry: int = 100) -> int:
|
| 240 |
+
if start_port is None:
|
| 241 |
+
start_port = 0
|
| 242 |
+
for port in range(start_port, start_port + retry):
|
| 243 |
+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
| 244 |
+
try:
|
| 245 |
+
sock.bind(('', port))
|
| 246 |
+
port = sock.getsockname()[1]
|
| 247 |
+
break
|
| 248 |
+
except OSError:
|
| 249 |
+
pass
|
| 250 |
+
return port
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def copy_files_by_pattern(source_dir, dest_dir, patterns):
|
| 254 |
+
if not os.path.exists(dest_dir):
|
| 255 |
+
os.makedirs(dest_dir)
|
| 256 |
+
|
| 257 |
+
if isinstance(patterns, str):
|
| 258 |
+
patterns = [patterns]
|
| 259 |
+
|
| 260 |
+
for pattern in patterns:
|
| 261 |
+
pattern_parts = pattern.split(os.path.sep)
|
| 262 |
+
if len(pattern_parts) > 1:
|
| 263 |
+
subdir_pattern = os.path.sep.join(pattern_parts[:-1])
|
| 264 |
+
file_pattern = pattern_parts[-1]
|
| 265 |
+
|
| 266 |
+
for root, dirs, files in os.walk(source_dir):
|
| 267 |
+
rel_path = os.path.relpath(root, source_dir)
|
| 268 |
+
if rel_path == '.' or (rel_path != '.' and not fnmatch.fnmatch(rel_path, subdir_pattern)):
|
| 269 |
+
continue
|
| 270 |
+
|
| 271 |
+
for file in files:
|
| 272 |
+
if fnmatch.fnmatch(file, file_pattern):
|
| 273 |
+
file_path = os.path.join(root, file)
|
| 274 |
+
target_dir = os.path.join(dest_dir, rel_path)
|
| 275 |
+
if not os.path.exists(target_dir):
|
| 276 |
+
os.makedirs(target_dir)
|
| 277 |
+
dest_file = os.path.join(target_dir, file)
|
| 278 |
+
|
| 279 |
+
if not os.path.exists(dest_file):
|
| 280 |
+
shutil.copy2(file_path, dest_file)
|
| 281 |
+
else:
|
| 282 |
+
search_path = os.path.join(source_dir, pattern)
|
| 283 |
+
matched_files = glob.glob(search_path)
|
| 284 |
+
|
| 285 |
+
for file_path in matched_files:
|
| 286 |
+
if os.path.isfile(file_path):
|
| 287 |
+
file_name = os.path.basename(file_path)
|
| 288 |
+
destination = os.path.join(dest_dir, file_name)
|
| 289 |
+
if not os.path.exists(destination):
|
| 290 |
+
shutil.copy2(file_path, destination)
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def split_list(ori_list, num_shards):
|
| 294 |
+
idx_list = np.linspace(0, len(ori_list), num_shards + 1)
|
| 295 |
+
shard = []
|
| 296 |
+
for i in range(len(idx_list) - 1):
|
| 297 |
+
shard.append(ori_list[int(idx_list[i]):int(idx_list[i + 1])])
|
| 298 |
+
return shard
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
def patch_getattr(obj_cls, item_name: str):
|
| 302 |
+
if hasattr(obj_cls, '_patch'): # avoid double patch
|
| 303 |
+
return
|
| 304 |
+
|
| 305 |
+
def __new_getattr__(self, key: str):
|
| 306 |
+
try:
|
| 307 |
+
return super(self.__class__, self).__getattr__(key)
|
| 308 |
+
except AttributeError:
|
| 309 |
+
if item_name in dir(self):
|
| 310 |
+
item = getattr(self, item_name)
|
| 311 |
+
return getattr(item, key)
|
| 312 |
+
raise
|
| 313 |
+
|
| 314 |
+
obj_cls.__getattr__ = __new_getattr__
|
| 315 |
+
obj_cls._patch = True
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def import_external_file(file_path: str):
|
| 319 |
+
file_path = os.path.abspath(os.path.expanduser(file_path))
|
| 320 |
+
py_dir, py_file = os.path.split(file_path)
|
| 321 |
+
assert os.path.isdir(py_dir), f'py_dir: {py_dir}'
|
| 322 |
+
sys.path.insert(0, py_dir)
|
| 323 |
+
return importlib.import_module(py_file.split('.', 1)[0])
|
ms-swift/tests/__init__.py
ADDED
|
File without changes
|
ms-swift/tests/app/test_app.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def test_llm():
|
| 2 |
+
from swift.llm import app_main, AppArguments
|
| 3 |
+
app_main(AppArguments(model='Qwen/Qwen2.5-0.5B-Instruct'))
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def test_lora():
|
| 7 |
+
from swift.llm import app_main, AppArguments
|
| 8 |
+
app_main(AppArguments(adapters='swift/test_lora', lang='en', studio_title='小黄'))
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def test_mllm():
|
| 12 |
+
from swift.llm import app_main, AppArguments
|
| 13 |
+
app_main(AppArguments(model='Qwen/Qwen2-VL-7B-Instruct', stream=True))
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def test_audio():
|
| 17 |
+
from swift.llm import AppArguments, app_main, DeployArguments, run_deploy
|
| 18 |
+
deploy_args = DeployArguments(model='Qwen/Qwen2-Audio-7B-Instruct', infer_backend='pt', verbose=False)
|
| 19 |
+
|
| 20 |
+
with run_deploy(deploy_args, return_url=True) as url:
|
| 21 |
+
app_main(AppArguments(model='Qwen2-Audio-7B-Instruct', base_url=url, stream=True))
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
if __name__ == '__main__':
|
| 25 |
+
test_mllm()
|
ms-swift/tests/llm/data/multi_modal_1.jsonl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"query": "<img>https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg</img>55555", "response": "66666"}
|
| 2 |
+
{"query": "<img>https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg</img><img>https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg</img>eeeee", "response": "fffff", "history": [["hello", "123"]]}
|
| 3 |
+
{"query": "EEEEE", "response": "FFFFF", "history": [["AAAAA", "BBBBB"], ["CCCCC", "DDDDD"]]}
|
ms-swift/tests/models/test_flash_attn.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from swift.llm import get_model_tokenizer
|
| 2 |
+
|
| 3 |
+
if __name__ == '__main__':
|
| 4 |
+
# model, tokenizer = get_model_tokenizer('Qwen/Qwen2-7B-Instruct', attn_impl='flash_attn')
|
| 5 |
+
# model, tokenizer = get_model_tokenizer('AIDC-AI/Ovis2-2B', attn_impl='flash_attn')
|
| 6 |
+
# model, tokenizer = get_model_tokenizer('OpenGVLab/InternVL2-2B', attn_impl='flash_attn')
|
| 7 |
+
model, tokenizer = get_model_tokenizer('Shanghai_AI_Laboratory/internlm3-8b-instruct', attn_impl='flash_attn')
|
| 8 |
+
print(model)
|
ms-swift/tests/test_align/test_rlhf_loss.py
ADDED
|
File without changes
|
ms-swift/tests/test_align/test_template/test_agent.py
ADDED
|
@@ -0,0 +1,325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
os.environ['SWIFT_DEBUG'] = '1'
|
| 4 |
+
|
| 5 |
+
system = 'You are a helpful assistant.'
|
| 6 |
+
|
| 7 |
+
tools = [{
|
| 8 |
+
'type': 'function',
|
| 9 |
+
'function': {
|
| 10 |
+
'name': 'get_current_weather',
|
| 11 |
+
'description': 'Get the current weather in a given location',
|
| 12 |
+
'parameters': {
|
| 13 |
+
'type': 'object',
|
| 14 |
+
'properties': {
|
| 15 |
+
'location': {
|
| 16 |
+
'type': 'string',
|
| 17 |
+
'description': 'The city and state, e.g. San Francisco, CA'
|
| 18 |
+
},
|
| 19 |
+
'unit': {
|
| 20 |
+
'type': 'string',
|
| 21 |
+
'enum': ['celsius', 'fahrenheit']
|
| 22 |
+
}
|
| 23 |
+
},
|
| 24 |
+
'required': ['location']
|
| 25 |
+
}
|
| 26 |
+
}
|
| 27 |
+
}, {
|
| 28 |
+
'name_for_model': 'tool2',
|
| 29 |
+
'name_for_human': '工具2',
|
| 30 |
+
'description': 'Tool2的描述',
|
| 31 |
+
}]
|
| 32 |
+
|
| 33 |
+
glm4_tools = [{
|
| 34 |
+
'type': 'function',
|
| 35 |
+
'function': {
|
| 36 |
+
'name': 'realtime_aqi',
|
| 37 |
+
'description': '天气预报。获取实时空气质量。当前空气质量,PM2.5,PM10信息',
|
| 38 |
+
'parameters': {
|
| 39 |
+
'type': 'object',
|
| 40 |
+
'properties': {
|
| 41 |
+
'city': {
|
| 42 |
+
'description': '城市名'
|
| 43 |
+
}
|
| 44 |
+
},
|
| 45 |
+
'required': ['city']
|
| 46 |
+
}
|
| 47 |
+
}
|
| 48 |
+
}]
|
| 49 |
+
glm4_tool_messasges = [
|
| 50 |
+
{
|
| 51 |
+
'role': 'tool',
|
| 52 |
+
'content': '{"city": "北京", "aqi": "10", "unit": "celsius"}'
|
| 53 |
+
},
|
| 54 |
+
{
|
| 55 |
+
'role': 'tool',
|
| 56 |
+
'content': '{"city": "上海", "aqi": "72", "unit": "fahrenheit"}'
|
| 57 |
+
},
|
| 58 |
+
]
|
| 59 |
+
glm4_query = '北京和上海今天的天气情况'
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def _infer(engine, num_tools: int = 1, agent_tools=None, tool_messages=None, query=None):
|
| 63 |
+
if agent_tools is None:
|
| 64 |
+
agent_tools = tools
|
| 65 |
+
if tool_messages is None:
|
| 66 |
+
tool_messages = []
|
| 67 |
+
for _ in range(num_tools):
|
| 68 |
+
tool_messages.append({
|
| 69 |
+
'role': 'tool',
|
| 70 |
+
'content': '{"temperature": 32, "condition": "Sunny", "humidity": 50}'
|
| 71 |
+
})
|
| 72 |
+
stop = [engine.default_template.agent_template.keyword.observation]
|
| 73 |
+
query = query or "How's the weather in Beijing today?"
|
| 74 |
+
infer_request = InferRequest([{'role': 'user', 'content': query}], tools=agent_tools)
|
| 75 |
+
request_config = RequestConfig(max_tokens=512, stop=stop, temperature=0)
|
| 76 |
+
resp_list = engine.infer([infer_request], request_config=request_config)
|
| 77 |
+
response = resp_list[0].choices[0].message.content
|
| 78 |
+
toolcall = resp_list[0].choices[0].message.tool_calls[0].function
|
| 79 |
+
print(f'response: {response}')
|
| 80 |
+
print(f'toolcall: {toolcall}')
|
| 81 |
+
assert toolcall is not None
|
| 82 |
+
infer_request.messages.append({'role': 'assistant', 'content': response})
|
| 83 |
+
infer_request.messages += tool_messages
|
| 84 |
+
resp_list = engine.infer([infer_request], request_config=request_config)
|
| 85 |
+
response2 = resp_list[0].choices[0].message.content
|
| 86 |
+
print(f'response2: {response2}')
|
| 87 |
+
infer_request.messages.append({'role': 'assistant', 'content': response2})
|
| 88 |
+
return infer_request.messages
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def test_react_en():
|
| 92 |
+
agent_template = agent_templates['react_en']()
|
| 93 |
+
new_system = agent_template._format_tools(tools, system)
|
| 94 |
+
assert len(new_system) == 1144
|
| 95 |
+
engine = PtEngine('Qwen/Qwen2.5-7B-Instruct')
|
| 96 |
+
template = engine.default_template
|
| 97 |
+
template.agent_template = agent_template
|
| 98 |
+
messages = _infer(engine)
|
| 99 |
+
assert messages[-1]['content'] == (
|
| 100 |
+
'Thought: The current temperature in Beijing is 32 degrees Celsius, and the condition is sunny '
|
| 101 |
+
'with a humidity of 50%.\nFinal Answer: The current temperature in Beijing is 32 degrees Celsius,'
|
| 102 |
+
' and the condition is sunny with a humidity of 50%.')
|
| 103 |
+
template.set_mode('train')
|
| 104 |
+
encoded = template.encode({'messages': messages})
|
| 105 |
+
print(f'input_ids: {template.safe_decode(encoded["input_ids"])}')
|
| 106 |
+
print(f'labels: {template.safe_decode(encoded["labels"])}')
|
| 107 |
+
|
| 108 |
+
dataset = load_dataset('AI-ModelScope/function-calling-chatml')[0]
|
| 109 |
+
data = dataset[6]
|
| 110 |
+
data['messages'].insert(1, data['messages'][1])
|
| 111 |
+
data['messages'].insert(3, data['messages'][3])
|
| 112 |
+
template.template_backend = 'swift'
|
| 113 |
+
encoded = template.encode(data)
|
| 114 |
+
print(f'input_ids: {template.safe_decode(encoded["input_ids"])}')
|
| 115 |
+
print(f'labels: {template.safe_decode(encoded["labels"])}')
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def test_react_zh():
|
| 119 |
+
agent_template = agent_templates['react_zh']()
|
| 120 |
+
new_system = agent_template._format_tools(tools, system)
|
| 121 |
+
assert len(new_system) == 712
|
| 122 |
+
engine = PtEngine('Qwen/Qwen2.5-7B-Instruct')
|
| 123 |
+
template = engine.default_template
|
| 124 |
+
template.agent_template = agent_template
|
| 125 |
+
_infer(engine)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def test_qwen_en():
|
| 129 |
+
agent_template = agent_templates['qwen_en']()
|
| 130 |
+
new_system = agent_template._format_tools(tools, system)
|
| 131 |
+
assert len(new_system) == 879
|
| 132 |
+
engine = PtEngine('Qwen/Qwen2.5-7B-Instruct')
|
| 133 |
+
template = engine.default_template
|
| 134 |
+
template.agent_template = agent_template
|
| 135 |
+
messages = _infer(engine)
|
| 136 |
+
assert messages[-1]['content'] == (
|
| 137 |
+
'✿RETURN✿: Today in Beijing, the temperature is 32°C with sunny conditions and the humidity '
|
| 138 |
+
'is at 50%. Enjoy the nice weather!')
|
| 139 |
+
template.set_mode('train')
|
| 140 |
+
encoded = template.encode({'messages': messages})
|
| 141 |
+
print(f'input_ids: {template.safe_decode(encoded["input_ids"])}')
|
| 142 |
+
print(f'labels: {template.safe_decode(encoded["labels"])}')
|
| 143 |
+
|
| 144 |
+
dataset = load_dataset('AI-ModelScope/function-calling-chatml')[0]
|
| 145 |
+
data = dataset[6]
|
| 146 |
+
data['messages'].insert(1, data['messages'][1])
|
| 147 |
+
data['messages'].insert(3, data['messages'][3])
|
| 148 |
+
template.template_backend = 'swift'
|
| 149 |
+
encoded = template.encode(data)
|
| 150 |
+
print(f'input_ids: {template.safe_decode(encoded["input_ids"])}')
|
| 151 |
+
print(f'labels: {template.safe_decode(encoded["labels"])}')
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def test_qwen_zh():
|
| 155 |
+
agent_template = agent_templates['qwen_zh']()
|
| 156 |
+
new_system = agent_template._format_tools(tools, system)
|
| 157 |
+
assert len(new_system) == 577
|
| 158 |
+
engine = PtEngine('Qwen/Qwen2.5-7B-Instruct')
|
| 159 |
+
template = engine.default_template
|
| 160 |
+
template.agent_template = agent_template
|
| 161 |
+
_infer(engine)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def test_qwen_en_parallel():
|
| 165 |
+
agent_template = agent_templates['qwen_en_parallel']()
|
| 166 |
+
new_system = agent_template._format_tools(tools, system)
|
| 167 |
+
assert len(new_system) == 1012
|
| 168 |
+
engine = PtEngine('Qwen/Qwen2.5-7B-Instruct')
|
| 169 |
+
template = engine.default_template
|
| 170 |
+
template.agent_template = agent_template
|
| 171 |
+
messages = _infer(engine, num_tools=2)
|
| 172 |
+
assert messages[-1]['content'] == (
|
| 173 |
+
'✿RETURN✿: Today in Beijing, the temperature is 32 degrees Celsius with sunny conditions '
|
| 174 |
+
'and the humidity is at 50%. Enjoy the nice weather!')
|
| 175 |
+
template.set_mode('train')
|
| 176 |
+
encoded = template.encode({'messages': messages})
|
| 177 |
+
print(f'input_ids: {template.safe_decode(encoded["input_ids"])}')
|
| 178 |
+
print(f'labels: {template.safe_decode(encoded["labels"])}')
|
| 179 |
+
|
| 180 |
+
dataset = load_dataset('AI-ModelScope/function-calling-chatml')[0]
|
| 181 |
+
data = dataset[6]
|
| 182 |
+
data['messages'].insert(1, data['messages'][1])
|
| 183 |
+
data['messages'].insert(3, data['messages'][3])
|
| 184 |
+
template.template_backend = 'swift'
|
| 185 |
+
encoded = template.encode(data)
|
| 186 |
+
print(f'input_ids: {template.safe_decode(encoded["input_ids"])}')
|
| 187 |
+
print(f'labels: {template.safe_decode(encoded["labels"])}')
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def test_qwen_zh_parallel():
|
| 191 |
+
agent_template = agent_templates['qwen_zh_parallel']()
|
| 192 |
+
new_system = agent_template._format_tools(tools, system)
|
| 193 |
+
assert len(new_system) == 688
|
| 194 |
+
engine = PtEngine('Qwen/Qwen2.5-7B-Instruct')
|
| 195 |
+
template = engine.default_template
|
| 196 |
+
template.agent_template = agent_template
|
| 197 |
+
_infer(engine, num_tools=2)
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def test_hermes():
|
| 201 |
+
agent_template = agent_templates['hermes']()
|
| 202 |
+
new_system = agent_template._format_tools(tools, system)
|
| 203 |
+
assert len(new_system) == 875
|
| 204 |
+
engine = PtEngine('Qwen/Qwen2.5-7B-Instruct')
|
| 205 |
+
template = engine.default_template
|
| 206 |
+
template.agent_template = agent_template
|
| 207 |
+
messages = _infer(engine, num_tools=2)
|
| 208 |
+
template.template_backend = 'jinja'
|
| 209 |
+
messages2 = _infer(engine, num_tools=2)
|
| 210 |
+
assert messages[-1]['content'] == messages2[-1]['content'] == (
|
| 211 |
+
'Today in Beijing, the temperature is 32 degrees Celsius with sunny conditions '
|
| 212 |
+
'and the humidity is at 50%. Enjoy the nice weather!')
|
| 213 |
+
template.set_mode('train')
|
| 214 |
+
encoded = template.encode({'messages': messages})
|
| 215 |
+
print(f'input_ids: {template.safe_decode(encoded["input_ids"])}')
|
| 216 |
+
print(f'labels: {template.safe_decode(encoded["labels"])}')
|
| 217 |
+
|
| 218 |
+
dataset = load_dataset('AI-ModelScope/function-calling-chatml')[0]
|
| 219 |
+
data = dataset[6]
|
| 220 |
+
data['messages'].insert(1, data['messages'][1])
|
| 221 |
+
data['messages'].insert(3, data['messages'][3])
|
| 222 |
+
template.template_backend = 'swift'
|
| 223 |
+
encoded = template.encode(data)
|
| 224 |
+
print(f'input_ids: {template.safe_decode(encoded["input_ids"])}')
|
| 225 |
+
print(f'labels: {template.safe_decode(encoded["labels"])}')
|
| 226 |
+
template.template_backend = 'jinja'
|
| 227 |
+
encoded2 = template.encode(data)
|
| 228 |
+
print(f'input_ids: {template.safe_decode(encoded2["input_ids"])}')
|
| 229 |
+
print(f'labels: {template.safe_decode(encoded2["labels"])}')
|
| 230 |
+
assert encoded['input_ids'] == encoded2['input_ids'][:-1]
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def test_toolbench():
|
| 234 |
+
agent_template = agent_templates['toolbench']()
|
| 235 |
+
new_system = agent_template._format_tools(tools, system)
|
| 236 |
+
assert len(new_system) == 1833
|
| 237 |
+
engine = PtEngine('Qwen/Qwen2.5-7B-Instruct')
|
| 238 |
+
template = engine.default_template
|
| 239 |
+
template.agent_template = agent_template
|
| 240 |
+
_infer(engine)
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def test_glm4():
|
| 244 |
+
agent_template = agent_templates['glm4']()
|
| 245 |
+
new_system = agent_template._format_tools(tools, system)
|
| 246 |
+
assert len(new_system) == 846
|
| 247 |
+
engine = PtEngine('ZhipuAI/glm-4-9b-chat')
|
| 248 |
+
template = engine.default_template
|
| 249 |
+
template.agent_template = agent_template
|
| 250 |
+
_infer(engine, agent_tools=glm4_tools, tool_messages=glm4_tool_messasges, query=glm4_query)
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def test_glm4_0414():
|
| 254 |
+
agent_template = agent_templates['glm4_0414']()
|
| 255 |
+
new_system = agent_template._format_tools(tools, system)
|
| 256 |
+
assert len(new_system) == 769
|
| 257 |
+
engine = PtEngine('ZhipuAI/GLM-4-9B-0414')
|
| 258 |
+
template = engine.default_template
|
| 259 |
+
template.agent_template = agent_template
|
| 260 |
+
messages = _infer(engine, agent_tools=glm4_tools, tool_messages=glm4_tool_messasges, query=glm4_query)
|
| 261 |
+
assert messages[-1]['content'] == '根据天气预报工具,北京今天的空气质量指数为10,属于良好水平;上海今天的空气质量指数为72,属于轻度污染水平。'
|
| 262 |
+
template.set_mode('train')
|
| 263 |
+
encoded = template.encode({'messages': messages})
|
| 264 |
+
print(f'input_ids: {template.safe_decode(encoded["input_ids"])}')
|
| 265 |
+
print(f'labels: {template.safe_decode(encoded["labels"])}')
|
| 266 |
+
|
| 267 |
+
dataset = load_dataset('AI-ModelScope/function-calling-chatml')[0]
|
| 268 |
+
data = dataset[6]
|
| 269 |
+
data['messages'].insert(1, data['messages'][1])
|
| 270 |
+
data['messages'].insert(3, data['messages'][3])
|
| 271 |
+
template.template_backend = 'swift'
|
| 272 |
+
encoded = template.encode(data)
|
| 273 |
+
print(f'input_ids: {template.safe_decode(encoded["input_ids"])}')
|
| 274 |
+
print(f'labels: {template.safe_decode(encoded["labels"])}')
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def test_llama3():
|
| 278 |
+
agent_template = agent_templates['llama3']()
|
| 279 |
+
engine = PtEngine('LLM-Research/Llama-3.2-3B-Instruct')
|
| 280 |
+
template = engine.default_template
|
| 281 |
+
template.agent_template = agent_template
|
| 282 |
+
messages = _infer(engine)
|
| 283 |
+
|
| 284 |
+
template.set_mode('train')
|
| 285 |
+
encoded = template.encode({'messages': messages})
|
| 286 |
+
print(f'input_ids: {template.safe_decode(encoded["input_ids"])}')
|
| 287 |
+
print(f'labels: {template.safe_decode(encoded["labels"])}')
|
| 288 |
+
|
| 289 |
+
dataset = load_dataset('AI-ModelScope/function-calling-chatml')[0]
|
| 290 |
+
data = dataset[6]
|
| 291 |
+
data['messages'].insert(1, data['messages'][1])
|
| 292 |
+
data['messages'].insert(3, data['messages'][3])
|
| 293 |
+
template.template_backend = 'swift'
|
| 294 |
+
encoded = template.encode(data)
|
| 295 |
+
print(f'input_ids: {template.safe_decode(encoded["input_ids"])}')
|
| 296 |
+
print(f'labels: {template.safe_decode(encoded["labels"])}')
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def test_llama4():
|
| 300 |
+
agent_template = agent_templates['llama4']()
|
| 301 |
+
engine = PtEngine('LLM-Research/Llama-4-Scout-17B-16E-Instruct')
|
| 302 |
+
template = engine.default_template
|
| 303 |
+
template.agent_template = agent_template
|
| 304 |
+
messages = _infer(engine)
|
| 305 |
+
template.set_mode('train')
|
| 306 |
+
encoded = template.encode({'messages': messages})
|
| 307 |
+
print(f'input_ids: {template.safe_decode(encoded["input_ids"])}')
|
| 308 |
+
print(f'labels: {template.safe_decode(encoded["labels"])}')
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
if __name__ == '__main__':
|
| 312 |
+
from swift.plugin import agent_templates
|
| 313 |
+
from swift.llm import PtEngine, InferRequest, RequestConfig, load_dataset
|
| 314 |
+
# test_react_en()
|
| 315 |
+
# test_react_zh()
|
| 316 |
+
# test_qwen_en()
|
| 317 |
+
# test_qwen_zh()
|
| 318 |
+
# test_qwen_en_parallel()
|
| 319 |
+
# test_qwen_zh_parallel()
|
| 320 |
+
test_hermes()
|
| 321 |
+
# test_toolbench()
|
| 322 |
+
# test_glm4()
|
| 323 |
+
# test_glm4_0414()
|
| 324 |
+
# test_llama3()
|
| 325 |
+
# test_llama4()
|
ms-swift/tests/test_align/test_template/test_audio.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def _infer_model(pt_engine, system=None, messages=None, audios=None):
|
| 7 |
+
seed_everything(42)
|
| 8 |
+
request_config = RequestConfig(max_tokens=128, temperature=0)
|
| 9 |
+
if messages is None:
|
| 10 |
+
messages = []
|
| 11 |
+
if system is not None:
|
| 12 |
+
messages += [{'role': 'system', 'content': system}]
|
| 13 |
+
messages += [{'role': 'user', 'content': '你好'}]
|
| 14 |
+
resp = pt_engine.infer([{'messages': messages}], request_config=request_config)
|
| 15 |
+
response = resp[0].choices[0].message.content
|
| 16 |
+
messages += [{'role': 'assistant', 'content': response}]
|
| 17 |
+
messages += [{'role': 'user', 'content': '<audio>这段语音说了什么'}]
|
| 18 |
+
else:
|
| 19 |
+
messages = messages.copy()
|
| 20 |
+
if audios is None:
|
| 21 |
+
audios = ['http://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/weather.wav']
|
| 22 |
+
resp = pt_engine.infer([{'messages': messages, 'audios': audios}], request_config=request_config)
|
| 23 |
+
response = resp[0].choices[0].message.content
|
| 24 |
+
messages += [{'role': 'assistant', 'content': response}]
|
| 25 |
+
logger.info(f'model: {pt_engine.model_info.model_name}, messages: {messages}')
|
| 26 |
+
return response
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def test_qwen_audio():
|
| 30 |
+
pt_engine = PtEngine('Qwen/Qwen-Audio-Chat')
|
| 31 |
+
_infer_model(pt_engine)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def test_qwen2_audio():
|
| 35 |
+
# transformers==4.48.3
|
| 36 |
+
pt_engine = PtEngine('Qwen/Qwen2-Audio-7B-Instruct')
|
| 37 |
+
messages = [{'role': 'user', 'content': '<audio>'}]
|
| 38 |
+
audios = ['https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/guess_age_gender.wav']
|
| 39 |
+
response = _infer_model(pt_engine, messages=messages, audios=audios)
|
| 40 |
+
pt_engine.default_template.template_backend = 'jinja'
|
| 41 |
+
response2 = _infer_model(pt_engine, messages=messages, audios=audios)
|
| 42 |
+
assert response == response2 == 'Yes, the speaker is female and in her twenties.'
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def test_xcomposer2d5_ol():
|
| 46 |
+
pt_engine = PtEngine('Shanghai_AI_Laboratory/internlm-xcomposer2d5-ol-7b:audio')
|
| 47 |
+
_infer_model(pt_engine)
|
| 48 |
+
pt_engine.default_template.template_backend = 'jinja'
|
| 49 |
+
_infer_model(pt_engine)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def test_step_audio_chat():
|
| 53 |
+
pt_engine = PtEngine('stepfun-ai/Step-Audio-Chat')
|
| 54 |
+
response = _infer_model(pt_engine, messages=[{'role': 'user', 'content': '<audio>'}])
|
| 55 |
+
assert response == ('是的呢,今天天气晴朗,阳光明媚,微风和煦,非常适合外出活动。天空湛蓝,白云朵朵,让人心情愉悦。希望你能好好享受这美好的一天!')
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def test_qwen2_5_omni():
|
| 59 |
+
USE_AUDIO_IN_VIDEO = True
|
| 60 |
+
os.environ['USE_AUDIO_IN_VIDEO'] = str(USE_AUDIO_IN_VIDEO)
|
| 61 |
+
pt_engine = PtEngine('Qwen/Qwen2.5-Omni-7B')
|
| 62 |
+
response = _infer_model(pt_engine)
|
| 63 |
+
pt_engine.default_template.template_backend = 'jinja'
|
| 64 |
+
response2 = _infer_model(pt_engine)
|
| 65 |
+
assert response == response2
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
if __name__ == '__main__':
|
| 69 |
+
from swift.llm import PtEngine, RequestConfig
|
| 70 |
+
from swift.utils import get_logger, seed_everything
|
| 71 |
+
logger = get_logger()
|
| 72 |
+
# test_qwen_audio()
|
| 73 |
+
# test_qwen2_audio()
|
| 74 |
+
# test_xcomposer2d5_ol()
|
| 75 |
+
# test_step_audio_chat()
|
| 76 |
+
test_qwen2_5_omni()
|