diff --git a/SpecForge-ext/.devcontainer/Dockerfile b/SpecForge-ext/.devcontainer/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..8ffb0d0328f12064b311869d19aa60df32cd7484 --- /dev/null +++ b/SpecForge-ext/.devcontainer/Dockerfile @@ -0,0 +1,32 @@ +FROM lmsysorg/sglang:dev + +# Create non-root user with specified UID and GID +# NOTE: Replace with your own UID and GID. This is a workaround from https://github.com/microsoft/vscode-remote-release/issues/49#issuecomment-489060908. +ARG HOST_UID=1003 +ARG HOST_GID=1003 +RUN groupadd -g $HOST_GID devuser && \ + useradd -m -u $HOST_UID -g $HOST_GID -s /bin/zsh devuser + +# Give devuser sudo access +RUN apt-get update && apt-get install -y sudo && \ + echo "devuser ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/devuser && \ + rm -rf /var/lib/apt/lists/* && \ + apt-get clean + +# Set up oh-my-zsh for devuser +RUN cp -r /root/.oh-my-zsh /home/devuser/.oh-my-zsh && \ + cp /root/.zshrc /home/devuser/.zshrc && \ + cp /root/.vimrc /home/devuser/.vimrc && \ + cp /root/.tmux.conf /home/devuser/.tmux.conf && \ + sed -i 's|/root/.oh-my-zsh|/home/devuser/.oh-my-zsh|g' /home/devuser/.zshrc && \ + chown -R devuser:devuser /home/devuser/ + +# Set workspace directory and ownership +WORKDIR /sgl-workspace/sglang +RUN chown -R devuser:devuser /sgl-workspace + +# Switch to devuser +USER devuser + +# Install rust +RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y diff --git a/SpecForge-ext/.devcontainer/devcontainer.json b/SpecForge-ext/.devcontainer/devcontainer.json new file mode 100644 index 0000000000000000000000000000000000000000..b2dbad2a745763b273af79b742640461f18b7894 --- /dev/null +++ b/SpecForge-ext/.devcontainer/devcontainer.json @@ -0,0 +1,30 @@ +{ + "name": "sglang", + "build": { + "dockerfile": "Dockerfile" + }, + "remoteUser": "devuser", + "customizations": { + "vscode": { + "extensions": [ + // Python development + "ms-python.python", + "charliermarsh.ruff", + // Rust 
development + "rust-lang.rust-analyzer", + "tamasfe.even-better-toml" + ] + } + }, + "forwardPorts": [], + "runArgs": [ + "--gpus", + "all" + ], + // The two lines below ensure that your local changes in the sglang + // repo are automatically synced to the sglang pip package installed + // in the dev docker container. You can remove / comment out these + // two lines if you prefer to sync code changes manually. + "workspaceMount": "source=${localWorkspaceFolder},target=/sgl-workspace/specforge,type=bind", + "workspaceFolder": "/sgl-workspace/specforge" +} diff --git a/SpecForge-ext/.editorconfig b/SpecForge-ext/.editorconfig new file mode 100644 index 0000000000000000000000000000000000000000..030a7293dcb6294d1ac26f262c761a9ac0a91052 --- /dev/null +++ b/SpecForge-ext/.editorconfig @@ -0,0 +1,25 @@ +# https://editorconfig.org/ + +root = true + +[*] +charset = utf-8 +end_of_line = lf +indent_style = space +indent_size = 4 +trim_trailing_whitespace = true +insert_final_newline = true + +[*.{json,yaml,yml}] +indent_size = 2 + +[*.md] +indent_size = 2 +x-soft-wrap-text = true + +[*.rst] +indent_size = 4 +x-soft-wrap-text = true + +[Makefile] +indent_style = tab diff --git a/SpecForge-ext/.github/CODEOWNERS b/SpecForge-ext/.github/CODEOWNERS new file mode 100644 index 0000000000000000000000000000000000000000..e4dbc44f0f9b24da1ad6a96eff14abe45f184255 --- /dev/null +++ b/SpecForge-ext/.github/CODEOWNERS @@ -0,0 +1,11 @@ +.github @FrankLeeeee +/specforge/core @FrankLeeeee +/specforge/data @zyksir @sleepcoo @shuaills +/specforge/layers @FrankLeeeee @FlamingoPg @sleepcoo @shuaills +/specforge/modeling @FlamingoPg @sleepcoo @shuaills @FrankLeeeee +/tests @FrankLeeeee +/assets @FrankLeeeee @zhyncs +/examples @shuaills @sleepcoo @FlamingoPg +/configs @FrankLeeeee @FlamingoPg +/benchmarks @FrankLeeeee +/scripts @shuaills @sleepcoo @FlamingoPg diff --git a/SpecForge-ext/.github/pull_request_template.md b/SpecForge-ext/.github/pull_request_template.md new file mode 100644 index 
0000000000000000000000000000000000000000..296468dfb8c84c38784759283db598959572a91f --- /dev/null +++ b/SpecForge-ext/.github/pull_request_template.md @@ -0,0 +1,30 @@ + + +## Motivation + + + +## Modifications + + + +## Related Issues + + + +## Accuracy Test + + + +## Benchmark & Profiling + + + +## Checklist + +- [ ] Format your code according to the [Code Formatting with Pre-Commit](https://docs.sglang.ai/references/contribution_guide.html#code-formatting-with-pre-commit). +- [ ] Add unit tests as outlined in the [Running Unit Tests](https://docs.sglang.ai/references/contribution_guide.html#running-unit-tests-adding-to-ci). +- [ ] Update documentation / docstrings / example tutorials as needed, according to [Writing Documentation](https://docs.sglang.ai/references/contribution_guide.html#writing-documentation-running-docs-ci). +- [ ] Provide throughput / latency benchmark results and accuracy evaluation results as needed, according to [Benchmark and Profiling](https://docs.sglang.ai/references/benchmark_and_profiling.html) and [Accuracy Results](https://docs.sglang.ai/references/accuracy_evaluation.html). +- [ ] For reviewers: If you haven't made any contributions to this PR and are only assisting with merging the main branch, please remove yourself as a co-author when merging the PR. +- [ ] Please feel free to join our Slack channel at https://sgl-fru7574.slack.com/archives/C09784E3EN6 to discuss your PR. 
diff --git a/SpecForge-ext/.isort.cfg b/SpecForge-ext/.isort.cfg new file mode 100644 index 0000000000000000000000000000000000000000..82a27d81c14cfbef583d39fe8a51bb635437b35e --- /dev/null +++ b/SpecForge-ext/.isort.cfg @@ -0,0 +1,3 @@ +[settings] +profile=black +known_first_party=sgl-eagle diff --git a/SpecForge-ext/LICENSE b/SpecForge-ext/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..909b8ff34ce3ff391ec5ecd1d2388d0c5b1cd4b3 --- /dev/null +++ b/SpecForge-ext/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 sgl-project + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/SpecForge-ext/MANIFEST.in b/SpecForge-ext/MANIFEST.in new file mode 100644 index 0000000000000000000000000000000000000000..7e3c8f05614505dc88691fa12babee86f8d1995e --- /dev/null +++ b/SpecForge-ext/MANIFEST.in @@ -0,0 +1,2 @@ +include requirements.txt +include version.txt diff --git a/SpecForge-ext/README.md b/SpecForge-ext/README.md new file mode 100644 index 0000000000000000000000000000000000000000..141963e25880130165ca60532f9f0b4292adea41 --- /dev/null +++ b/SpecForge-ext/README.md @@ -0,0 +1,70 @@ +
+logo + +[![documentation](https://img.shields.io/badge/๐Ÿ“–-Documentation-red.svg?style=flat)](https://docs.sglang.ai/SpecForge/) +[![SpecBundle](https://img.shields.io/badge/๐Ÿค—%20SpecBundle-yellow.svg?style=flat)](https://huggingface.co/collections/lmsys/specbundle) +[![DeepWiki](https://img.shields.io/badge/DeepWiki-SpecForge-blue.svg?logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAACwAAAAyCAYAAAAnWDnqAAAAAXNSR0IArs4c6QAAA05JREFUaEPtmUtyEzEQhtWTQyQLHNak2AB7ZnyXZMEjXMGeK/AIi+QuHrMnbChYY7MIh8g01fJoopFb0uhhEqqcbWTp06/uv1saEDv4O3n3dV60RfP947Mm9/SQc0ICFQgzfc4CYZoTPAswgSJCCUJUnAAoRHOAUOcATwbmVLWdGoH//PB8mnKqScAhsD0kYP3j/Yt5LPQe2KvcXmGvRHcDnpxfL2zOYJ1mFwrryWTz0advv1Ut4CJgf5uhDuDj5eUcAUoahrdY/56ebRWeraTjMt/00Sh3UDtjgHtQNHwcRGOC98BJEAEymycmYcWwOprTgcB6VZ5JK5TAJ+fXGLBm3FDAmn6oPPjR4rKCAoJCal2eAiQp2x0vxTPB3ALO2CRkwmDy5WohzBDwSEFKRwPbknEggCPB/imwrycgxX2NzoMCHhPkDwqYMr9tRcP5qNrMZHkVnOjRMWwLCcr8ohBVb1OMjxLwGCvjTikrsBOiA6fNyCrm8V1rP93iVPpwaE+gO0SsWmPiXB+jikdf6SizrT5qKasx5j8ABbHpFTx+vFXp9EnYQmLx02h1QTTrl6eDqxLnGjporxl3NL3agEvXdT0WmEost648sQOYAeJS9Q7bfUVoMGnjo4AZdUMQku50McDcMWcBPvr0SzbTAFDfvJqwLzgxwATnCgnp4wDl6Aa+Ax283gghmj+vj7feE2KBBRMW3FzOpLOADl0Isb5587h/U4gGvkt5v60Z1VLG8BhYjbzRwyQZemwAd6cCR5/XFWLYZRIMpX39AR0tjaGGiGzLVyhse5C9RKC6ai42ppWPKiBagOvaYk8lO7DajerabOZP46Lby5wKjw1HCRx7p9sVMOWGzb/vA1hwiWc6jm3MvQDTogQkiqIhJV0nBQBTU+3okKCFDy9WwferkHjtxib7t3xIUQtHxnIwtx4mpg26/HfwVNVDb4oI9RHmx5WGelRVlrtiw43zboCLaxv46AZeB3IlTkwouebTr1y2NjSpHz68WNFjHvupy3q8TFn3Hos2IAk4Ju5dCo8B3wP7VPr/FGaKiG+T+v+TQqIrOqMTL1VdWV1DdmcbO8KXBz6esmYWYKPwDL5b5FA1a0hwapHiom0r/cKaoqr+27/XcrS5UwSMbQAAAABJRU5ErkJggg==)](https://deepwiki.com/sgl-project/SpecForge) + +[![github badge](https://img.shields.io/badge/๐Ÿ“ƒ%20LMSYS-Blog-black.svg?style=flat)](https://lmsys.org/blog/2025-07-25-spec-forge/) +[![slack badge](https://img.shields.io/badge/Slack-join-blueviolet?logo=slack&)](https://sgl-fru7574.slack.com/archives/C09784E3EN6) +[![license](https://img.shields.io/badge/License-MIT%202.0-blue)](./LICENSE) + +
+ +## ๐Ÿ“ Overview + +SpecForge is an ecosystem project developed by the SGLang team. It is a framework for training speculative decoding models so that you can smoothly port them over to the SGLang serving framework to speed up your inference. + +We have seen many open-source projects for speculative decoding, but most of them are not well-maintained or not directly compatible with SGLang. We prepared this project because we wish that the open-source community can enjoy a speculative decoding framework that is +- regularly maintained by the SpecForge team: the code is runnable out-of-the-box +- directly compatible with SGLang: there is no additional efforts for porting to SGLang +- provide performant training capabilities: we provided online/offline/tensor-parallel/FSDP to suit your needs + + +Check out [**our documentation**](https://docs.sglang.ai/SpecForge/) to get started. + + +## ๐Ÿš€ Accelerate with SpecBundle + +SpecBundle is a collection of production-grade speculative decoding models that are released by the SpecForge team and our industry partners. They provide higher acceptance rate compared to the existing open-source checkpoints over a wide range of domains. Together with SGLang, you can experience up to 4x speedup for inference. Check out our resources below: + + +| Item | Link | +| --- | --- | +| ๐Ÿ“ Documentation | [Link](https://docs.sglang.io/SpecForge/community_resources/specbundle.html) | +| ๐Ÿ“Š Performance Dashboard | [Link](https://docs.sglang.io/SpecForge/SpecBundle/index.html) | +| ๐Ÿค— Hugging Face Collection | [Link](https://huggingface.co/collections/lmsys/specbundle) | + + +## ๐ŸŽ‰ News + +- [2025-12] ๐ŸŽ‰ Released SpecBundle (phase 1) and SpecForge v0.2. Check out our blog at [LMSYS.org](https://lmsys.org/blog/2025-12-23-spec-bundle-phase-1/) +- [2025-12] ๐Ÿ”” Released the roadmap for 2026 Q1. +- [2025-08] ๐Ÿ”” SpecForge is listed as a [flagship project](https://lmsys.org/about/) in LMSYS. Congratulations to the SpecForge team! 
+- [2025-08] ๐Ÿ”ฅ SpecForge powered the Eagle3 draft model for GPT-OSS. Check out the blog at [LMSYS.org](https://lmsys.org/blog/2025-08-27-gpt-oss/) +- [2025-07] ๐Ÿ”ฅ SpecForge is released together with Llama4-Eagle3 checkpoints. Check out our blog at [LMSYS.org](https://lmsys.org/blog/2025-07-25-spec-forge/) + +## โœจ Acknowledgements + +acknowledgements + +We would like to express our sincere gratitude to the official EAGLE team, especially Hongyang Zhang and Yuhui Li, for their invaluable contributions and support. Our thanks also go to the NVIDIA teamโ€”particularly Avery H and Izzy Puttermanโ€”and to the Google team, especially Ying Wang, for their insightful discussions and generous assistance throughout the project. + +We are especially grateful to Meituan for their strong backing and meaningful contributions, which played a vital role in driving this project forward. + +This project has also been inspired by many outstanding open-source projects from the LLM community, including [EAGLE](https://github.com/SafeAILab/EAGLE), [BaldEagle](https://github.com/NickL77/BaldEagle), and [TensorRT-Model-Optimizer](https://github.com/NVIDIA/TensorRT-Model-Optimizer) and others. Their contributions and shared knowledge have greatly benefited our work. + +## ๐Ÿ’ก Special Thanks to Voltage Park + +We would like to extend our sincere thanks to [Voltage Park](https://www.voltagepark.com/), our official infrastructure partner. As part of a formal collaboration with the SGLang team, Voltage Park provided critical GPU resources that empowered us to train and evaluate large-scale speculative decoding models efficiently and reliably. This partnership was instrumental in making SpecForge possible. We deeply appreciate Voltage Parkโ€™s mission to make cutting-edge AI infrastructure more accessible, and we look forward to continued collaboration as we push the boundaries of open-source LLM serving and optimization. 
+ +## ๐Ÿ“ƒ Citation + +```bibtex +@misc{specforge2025, + title={SpecForge: Train speculative decoding models effortlessly}, + author={Shenggui Li, Yikai Zhu, Chao Wang, Fan Yin, Shuai Shi, Yubo Wang, Yi Zhang, Yingyi Huang, Haoshuai Zheng, Yineng Zhang}, + year={2025}, + publisher={GitHub}, + howpublished={\url{https://github.com/sgl-project/specforge}}, +} diff --git a/SpecForge-ext/analyze_accept_length.sh b/SpecForge-ext/analyze_accept_length.sh new file mode 100644 index 0000000000000000000000000000000000000000..1c8ac756f5a974716c4b152755ac37f226e78153 --- /dev/null +++ b/SpecForge-ext/analyze_accept_length.sh @@ -0,0 +1,91 @@ +#!/bin/bash + +# ๅˆ†ๆžaccept length็š„่„šๆœฌ + +echo "==========================================" +echo "Accept Length Analysis" +echo "==========================================" +echo "" + +# ๆฃ€ๆŸฅresults็›ฎๅฝ• +if [ ! -d "results" ]; then + echo "Error: results directory not found" + exit 1 +fi + +# ๆŸฅๆ‰พๆ‰€ๆœ‰็ป“ๆžœๆ–‡ไปถ +result_files=$(ls results/*.jsonl 2>/dev/null) + +if [ -z "$result_files" ]; then + echo "No result files found in results/ directory" + echo "" + echo "Please run the benchmark first:" + echo " python benchmarks/bench_eagle3.py ..." 
+ exit 1 +fi + +echo "Found result files:" +ls -lh results/*.jsonl +echo "" +echo "==========================================" +echo "" + +# ๅˆ†ๆžๆฏไธช็ป“ๆžœๆ–‡ไปถ +for file in $result_files; do + filename=$(basename "$file") + echo "File: $filename" + echo "----------------------------------------" + + # ๆฃ€ๆŸฅๆ–‡ไปถๆ˜ฏๅฆๅŒ…ๅซmtbench็ป“ๆžœ + if grep -q "mtbench" "$file"; then + # ๆๅ–accept_length + echo "Accept lengths:" + cat "$file" | jq -r '.mtbench[0].metrics[] | " Sample \(.sample_id): accept_length=\(.accept_length // "N/A"), output_tokens=\(.output_tokens // "N/A")"' 2>/dev/null + + echo "" + echo "Statistics:" + # ่ฎก็ฎ—ๅนณๅ‡ๅ€ผ + avg_accept=$(cat "$file" | jq -r '.mtbench[0].metrics[] | .accept_length' 2>/dev/null | awk '{sum+=$1; count++} END {if(count>0) printf " Average accept_length: %.4f\n", sum/count; else print " No data"}') + echo "$avg_accept" + + # ่ฎก็ฎ—ๆœ€ๅฐๅ€ผๅ’Œๆœ€ๅคงๅ€ผ + min_accept=$(cat "$file" | jq -r '.mtbench[0].metrics[] | .accept_length' 2>/dev/null | sort -n | head -1) + max_accept=$(cat "$file" | jq -r '.mtbench[0].metrics[] | .accept_length' 2>/dev/null | sort -n | tail -1) + echo " Min accept_length: $min_accept" + echo " Max accept_length: $max_accept" + + # ๆ ทๆœฌๆ•ฐ้‡ + sample_count=$(cat "$file" | jq -r '.mtbench[0].metrics | length' 2>/dev/null) + echo " Total samples: $sample_count" + else + echo " No mtbench results found in this file" + fi + + echo "" + echo "==========================================" + echo "" +done + +# ๅฆ‚ๆžœๆœ‰baselineๅ’Œtrained็š„็ป“ๆžœ๏ผŒ่ฟ›่กŒๅฏนๆฏ” +baseline_file=$(ls results/baseline*.jsonl 2>/dev/null | head -1) +trained_file=$(ls results/trained*.jsonl 2>/dev/null | head -1) + +if [ -n "$baseline_file" ] && [ -n "$trained_file" ]; then + echo "Comparison: Baseline vs Trained" + echo "----------------------------------------" + + baseline_avg=$(cat "$baseline_file" | jq -r '.mtbench[0].metrics[] | .accept_length' 2>/dev/null | awk '{sum+=$1; count++} END {if(count>0) print 
sum/count}') + trained_avg=$(cat "$trained_file" | jq -r '.mtbench[0].metrics[] | .accept_length' 2>/dev/null | awk '{sum+=$1; count++} END {if(count>0) print sum/count}') + + if [ -n "$baseline_avg" ] && [ -n "$trained_avg" ]; then + echo "Baseline average: $baseline_avg" + echo "Trained average: $trained_avg" + + # ่ฎก็ฎ—ๆๅ‡็™พๅˆ†ๆฏ” + improvement=$(echo "$baseline_avg $trained_avg" | awk '{printf "%.2f%%", ($2-$1)/$1*100}') + echo "Improvement: $improvement" + fi + echo "" +fi + +echo "Done!" diff --git a/SpecForge-ext/assets/logo.svg b/SpecForge-ext/assets/logo.svg new file mode 100644 index 0000000000000000000000000000000000000000..7f619f50a0be61ade41e82599a40db2a45b3c376 --- /dev/null +++ b/SpecForge-ext/assets/logo.svg @@ -0,0 +1,938 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/SpecForge-ext/configs/deepseek-v2-lite-eagle3.json b/SpecForge-ext/configs/deepseek-v2-lite-eagle3.json new file mode 100644 index 0000000000000000000000000000000000000000..da12c0fb4444a55773ac0f84f4360f3476a39d09 --- /dev/null +++ b/SpecForge-ext/configs/deepseek-v2-lite-eagle3.json @@ -0,0 +1,39 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 100000, + "eos_token_id": 100001, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 2048, + "initializer_range": 0.02, + 
"intermediate_size": 10944, + "max_position_embeddings": 163840, + "max_window_layers": 64, + "model_type": "llama", + "num_attention_heads": 16, + "num_hidden_layers": 1, + "num_key_value_heads": 16, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "beta_fast": 32.0, + "beta_slow": 1.0, + "factor": 40.0, + "mscale": 0.707, + "mscale_all_dim": 0.707, + "original_max_position_embeddings": 4096, + "rope_type": "yarn" + }, + "rope_theta": 10000, + "sliding_window": null, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.33.1", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 102400, + "draft_vocab_size": 32000 +} diff --git a/SpecForge-ext/configs/deepseek-v3-671b-eagle3.json b/SpecForge-ext/configs/deepseek-v3-671b-eagle3.json new file mode 100644 index 0000000000000000000000000000000000000000..147a5fdcd32c7ccd83248eec16dc709ed34e8bce --- /dev/null +++ b/SpecForge-ext/configs/deepseek-v3-671b-eagle3.json @@ -0,0 +1,32 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "eagle_config": { + "eagle_aux_hidden_state_layer_ids": [ + 1, + 29, + 57 + ], + "use_aux_hidden_state": true + }, + "bos_token_id": 151643, + "eos_token_id": 151645, + "hidden_act": "silu", + "hidden_size": 7168, + "initializer_range": 0.02, + "intermediate_size": 40960, + "max_position_embeddings": 163840, + "model_type": "llama", + "num_attention_heads": 56, + "num_key_value_heads": 8, + "num_hidden_layers": 1, + "pad_token_id": 0, + "rms_norm_eps": 1e-05, + "tie_word_embeddings": false, + "torch_dtype": "float16", + "transformers_version": "4.51.0", + "use_cache": true, + "vocab_size": 129280, + "draft_vocab_size": 32000 +} diff --git a/SpecForge-ext/configs/gemma3-1b-eagle3.json b/SpecForge-ext/configs/gemma3-1b-eagle3.json new file mode 100644 index 0000000000000000000000000000000000000000..e5e74eb16a3e47ac9ff4357106ff7c2afe4186da --- /dev/null +++ b/SpecForge-ext/configs/gemma3-1b-eagle3.json @@ -0,0 +1,32 @@ +{ + "architectures": 
[ + "LlamaForCausalLMEagle3" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 2, + "eos_token_id": 1, + "pad_token_id": 0, + "head_dim": 256, + "hidden_act": "silu", + "hidden_size": 1152, + "initializer_range": 0.02, + "intermediate_size": 6912, + "max_position_embeddings": 32768, + "model_type": "llama", + "num_attention_heads": 4, + "num_hidden_layers": 1, + "num_key_value_heads": 1, + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 1000000, + "sliding_window": 512, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.50.0", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 262145, + "draft_vocab_size": 32000, + "target_model_type": "gemma3_text" +} diff --git a/SpecForge-ext/configs/gpt-oss-120B-eagle3.json b/SpecForge-ext/configs/gpt-oss-120B-eagle3.json new file mode 100644 index 0000000000000000000000000000000000000000..f4b36c7687620c95e90b4ec43ee8a53763826954 --- /dev/null +++ b/SpecForge-ext/configs/gpt-oss-120B-eagle3.json @@ -0,0 +1,30 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "eagle_config": { + "eagle_aux_hidden_state_layer_ids": [ + 1, + 17, + 33 + ] + }, + "head_dim": 64, + "hidden_act": "silu", + "hidden_size": 2880, + "initializer_range": 0.02, + "intermediate_size": 17280, + "max_position_embeddings": 4096, + "model_type": "llama", + "num_attention_heads": 64, + "num_key_value_heads": 8, + "num_hidden_layers": 1, + "pad_token_id": 0, + "rms_norm_eps": 1e-05, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.52.3", + "use_cache": true, + "vocab_size": 201088, + "draft_vocab_size": 32000 +} diff --git a/SpecForge-ext/configs/gpt-oss-20B-eagle3.json b/SpecForge-ext/configs/gpt-oss-20B-eagle3.json new file mode 100644 index 0000000000000000000000000000000000000000..e1d4b257d9644032488a31a67aca8719ffdbe33e --- /dev/null +++ b/SpecForge-ext/configs/gpt-oss-20B-eagle3.json @@ -0,0 +1,30 @@ +{ + 
"architectures": [ + "LlamaForCausalLMEagle3" + ], + "eagle_config": { + "eagle_aux_hidden_state_layer_ids": [ + 1, + 11, + 21 + ] + }, + "head_dim": 64, + "hidden_act": "silu", + "hidden_size": 2880, + "initializer_range": 0.02, + "intermediate_size": 17280, + "max_position_embeddings": 4096, + "model_type": "llama", + "num_attention_heads": 64, + "num_key_value_heads": 8, + "num_hidden_layers": 1, + "pad_token_id": 0, + "rms_norm_eps": 1e-05, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.52.3", + "use_cache": true, + "vocab_size": 201088, + "draft_vocab_size": 32000 +} diff --git a/SpecForge-ext/configs/ling-flash-2.0-eagle3.json b/SpecForge-ext/configs/ling-flash-2.0-eagle3.json new file mode 100644 index 0000000000000000000000000000000000000000..0a9bea37c06ae29010eade7cd4b70cdf4e9e0316 --- /dev/null +++ b/SpecForge-ext/configs/ling-flash-2.0-eagle3.json @@ -0,0 +1,24 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "bos_token_id": 163584, + "eos_token_id": 163585, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 14336, + "max_position_embeddings": 32768, + "model_type": "llama", + "num_attention_heads": 32, + "num_key_value_heads": 8, + "num_hidden_layers": 1, + "pad_token_id": 0, + "rms_norm_eps": 1e-05, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.57.1", + "use_cache": true, + "vocab_size": 157184, + "draft_vocab_size": 32000 +} diff --git a/SpecForge-ext/configs/llama3-70B-ealge3.json b/SpecForge-ext/configs/llama3-70B-ealge3.json new file mode 100644 index 0000000000000000000000000000000000000000..20d04f4d0dc09fe2894a7a35673b3a8afdaa8e32 --- /dev/null +++ b/SpecForge-ext/configs/llama3-70B-ealge3.json @@ -0,0 +1,37 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "bos_token_id": 128000, + "eos_token_id": [ + 128001, + 128008, + 128009 + ], + "head_dim": 128, + "hidden_act": "silu", + 
"hidden_size": 8192, + "initializer_range": 0.02, + "intermediate_size": 28672, + "max_position_embeddings": 4096, + "model_type": "llama", + "num_attention_heads": 64, + "num_key_value_heads": 8, + "num_hidden_layers": 1, + "pad_token_id": 0, + "rms_norm_eps": 1e-05, + "rope_scaling": { + "factor": 8.0, + "high_freq_factor": 4.0, + "low_freq_factor": 1.0, + "original_max_position_embeddings": 4096, + "rope_type": "llama3" + }, + "rope_theta": 500000.0, + "tie_word_embeddings": false, + "torch_dtype": "float16", + "transformers_version": "4.28.1", + "use_cache": true, + "vocab_size": 128256, + "draft_vocab_size": 32000 +} diff --git a/SpecForge-ext/configs/llama3-8B-eagle3.json b/SpecForge-ext/configs/llama3-8B-eagle3.json new file mode 100644 index 0000000000000000000000000000000000000000..775ad6afee3c43946742b823b8f4e3d48af68b3c --- /dev/null +++ b/SpecForge-ext/configs/llama3-8B-eagle3.json @@ -0,0 +1,24 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "bos_token_id": 128000, + "eos_token_id": 128001, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 14336, + "max_position_embeddings": 2048, + "model_type": "llama", + "num_attention_heads": 32, + "num_key_value_heads": 8, + "num_hidden_layers": 1, + "pad_token_id": 0, + "rms_norm_eps": 1e-05, + "tie_word_embeddings": false, + "torch_dtype": "float16", + "transformers_version": "4.28.1", + "use_cache": true, + "vocab_size": 128256, + "draft_vocab_size": 32000 +} diff --git a/SpecForge-ext/configs/llama4-scout-17B-16E-eagle3.json b/SpecForge-ext/configs/llama4-scout-17B-16E-eagle3.json new file mode 100644 index 0000000000000000000000000000000000000000..9c2bb5a81a3b5452836b0c6dcf1ba29e4ecc64e5 --- /dev/null +++ b/SpecForge-ext/configs/llama4-scout-17B-16E-eagle3.json @@ -0,0 +1,22 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "hidden_act": "silu", + "hidden_size": 5120, + "initializer_range": 0.02, + "intermediate_size": 32768, + 
"max_position_embeddings": 2048, + "model_type": "llama", + "num_attention_heads": 40, + "num_key_value_heads": 8, + "num_hidden_layers": 1, + "pad_token_id": 0, + "rms_norm_eps": 1e-05, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.52.3", + "use_cache": true, + "vocab_size": 202048, + "draft_vocab_size": 32000 +} diff --git a/SpecForge-ext/configs/longcat-flash-dflash.json b/SpecForge-ext/configs/longcat-flash-dflash.json new file mode 100644 index 0000000000000000000000000000000000000000..6b3d34d78bd5ffe5663b0025c8949284620887a8 --- /dev/null +++ b/SpecForge-ext/configs/longcat-flash-dflash.json @@ -0,0 +1,41 @@ +{ + "architectures": [ + "DFlashDraftModel" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "auto_map": { + "AutoModel": "modeling_dflash.DFlashDraftModel" + }, + "block_size": 16, + "bos_token_id": 1, + "dtype": "bfloat16", + "eos_token_id": 2, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 6144, + "initializer_range": 0.02, + "intermediate_size": 12288, + "layer_types": [ + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention" + ], + "max_position_embeddings": 40960, + "max_window_layers": 5, + "model_type": "qwen3", + "num_attention_heads": 32, + "num_hidden_layers": 5, + "num_key_value_heads": 8, + "num_target_layers": 28, + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 1000000, + "sliding_window": null, + "tie_word_embeddings": false, + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 131072 + } diff --git a/SpecForge-ext/configs/longcat-flash-eagle3.json b/SpecForge-ext/configs/longcat-flash-eagle3.json new file mode 100644 index 0000000000000000000000000000000000000000..7b3b921a22378353f010d1ee1ba03ec44610eb75 --- /dev/null +++ b/SpecForge-ext/configs/longcat-flash-eagle3.json @@ -0,0 +1,31 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "attention_bias": false, + "attention_dropout": 0.0, 
+ "bos_token_id": 1, + "eos_token_id": 2, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 6144, + "initializer_range": 0.02, + "intermediate_size": 12288, + "max_position_embeddings": 131072, + "max_window_layers": 48, + "model_type": "llama", + "num_attention_heads": 64, + "num_hidden_layers": 1, + "num_key_value_heads":16, + "rms_norm_eps": 1e-05, + "rope_scaling": null, + "rope_theta": 10000000.0, + "sliding_window": null, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.53.2", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 131072, + "draft_vocab_size": 131072 + } diff --git a/SpecForge-ext/configs/phi4-eagle3.json b/SpecForge-ext/configs/phi4-eagle3.json new file mode 100644 index 0000000000000000000000000000000000000000..05456a0d239653cdc898413860c6822d8a7cdec5 --- /dev/null +++ b/SpecForge-ext/configs/phi4-eagle3.json @@ -0,0 +1,27 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 100257, + "eos_token_id": 100257, + "pad_token_id": 100257, + "hidden_act": "silu", + "hidden_size": 5120, + "initializer_range": 0.02, + "intermediate_size": 17920, + "max_position_embeddings": 16384, + "model_type": "phi3", + "num_attention_heads": 40, + "num_hidden_layers": 1, + "num_key_value_heads": 10, + "rms_norm_eps": 1e-05, + "rope_theta": 250000.0, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.47.0", + "use_cache": true, + "vocab_size": 100352, + "draft_vocab_size": 32000 +} diff --git a/SpecForge-ext/configs/qwen2.5-7b-eagle3.json b/SpecForge-ext/configs/qwen2.5-7b-eagle3.json new file mode 100644 index 0000000000000000000000000000000000000000..f16f6b8d07b120734f1eafd8c2e7881e424a57a1 --- /dev/null +++ b/SpecForge-ext/configs/qwen2.5-7b-eagle3.json @@ -0,0 +1,30 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "attention_bias": false, + "attention_dropout": 0.0, + 
"bos_token_id": 151643, + "eos_token_id": 151645, + "hidden_act": "silu", + "hidden_size": 3584, + "initializer_range": 0.02, + "intermediate_size": 18944, + "max_position_embeddings": 32768, + "max_window_layers": 28, + "model_type": "llama", + "num_attention_heads": 28, + "num_hidden_layers": 1, + "num_key_value_heads": 4, + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 1000000.0, + "sliding_window": 131072, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.51.0", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 152064, + "draft_vocab_size": 16000 +} diff --git a/SpecForge-ext/configs/qwen2.5-vl-32b-eagle3.json b/SpecForge-ext/configs/qwen2.5-vl-32b-eagle3.json new file mode 100644 index 0000000000000000000000000000000000000000..76aa04cdf7cdf706443308f72f5e487cf6f510ff --- /dev/null +++ b/SpecForge-ext/configs/qwen2.5-vl-32b-eagle3.json @@ -0,0 +1,40 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 5120, + "initializer_range": 0.02, + "intermediate_size": 18944, + "max_position_embeddings": 8192, + "max_window_layers": 28, + "model_type": "llama", + "target_model_type": "qwen2_5_vl", + "num_attention_heads": 28, + "num_hidden_layers": 1, + "num_key_value_heads": 4, + "rms_norm_eps": 1e-06, + "pretraining_tp": 1, + "rope_scaling": { + "type": "mrope", + "mrope_section": [ + 16, + 24, + 24 + ] + }, + "rope_theta": 1000000, + "sliding_window": 32768, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.51.0", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 152064, + "draft_vocab_size": 32000 + } diff --git a/SpecForge-ext/configs/qwen3-235B-A22B-eagle3.json b/SpecForge-ext/configs/qwen3-235B-A22B-eagle3.json new file mode 100644 index 
0000000000000000000000000000000000000000..8e28c04a18a851c968252b1691b89dcdcff598b9 --- /dev/null +++ b/SpecForge-ext/configs/qwen3-235B-A22B-eagle3.json @@ -0,0 +1,36 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "eagle_config": { + "eagle_aux_hidden_state_layer_ids": [ + 1, + 46, + 90 + ], + "use_aux_hidden_state": true + }, + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 151643, + "draft_vocab_size": 32000, + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 24576, + "max_position_embeddings": 40960, + "model_type": "llama", + "num_attention_heads": 64, + "num_hidden_layers": 1, + "num_key_value_heads": 4, + "rms_norm_eps": 1e-06, + "rope_theta": 1000000.0, + "rope_scaling": null, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.51.0", + "use_cache": true, + "vocab_size": 151936 +} diff --git a/SpecForge-ext/configs/qwen3-30B-A3B-eagle3.json b/SpecForge-ext/configs/qwen3-30B-A3B-eagle3.json new file mode 100644 index 0000000000000000000000000000000000000000..558cb18043a5bd182497536203de90a4a7672f35 --- /dev/null +++ b/SpecForge-ext/configs/qwen3-30B-A3B-eagle3.json @@ -0,0 +1,31 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 2048, + "initializer_range": 0.02, + "intermediate_size": 12288, + "max_position_embeddings": 2048, + "max_window_layers": 48, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 1, + "num_key_value_heads":4, + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 1000000.0, + "sliding_window": null, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.53.2", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 
151936, + "draft_vocab_size": 32000 +} diff --git a/SpecForge-ext/configs/qwen3-32b-eagle3.json b/SpecForge-ext/configs/qwen3-32b-eagle3.json new file mode 100644 index 0000000000000000000000000000000000000000..cf128d9fb451833207c0a4293554357f324aea8c --- /dev/null +++ b/SpecForge-ext/configs/qwen3-32b-eagle3.json @@ -0,0 +1,31 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 5120, + "initializer_range": 0.02, + "intermediate_size": 25600, + "max_position_embeddings": 40960, + "max_window_layers": 64, + "model_type": "llama", + "num_attention_heads": 64, + "num_hidden_layers": 1, + "num_key_value_heads": 8, + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 1000000, + "sliding_window": null, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.51.0", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 151936, + "draft_vocab_size": 32000 +} diff --git a/SpecForge-ext/configs/qwen3-4b-eagle3.json b/SpecForge-ext/configs/qwen3-4b-eagle3.json new file mode 100644 index 0000000000000000000000000000000000000000..41ae128fdcd532f1e31c6251819d29aedfa9d3e6 --- /dev/null +++ b/SpecForge-ext/configs/qwen3-4b-eagle3.json @@ -0,0 +1,31 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 2560, + "initializer_range": 0.02, + "intermediate_size": 9728, + "max_position_embeddings": 40960, + "max_window_layers": 36, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 1, + "num_key_value_heads": 8, + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 1000000, + "sliding_window": null, + "tie_word_embeddings": false, + "torch_dtype": 
"bfloat16", + "transformers_version": "4.51.0", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 151936, + "draft_vocab_size": 32000 +} diff --git a/SpecForge-ext/configs/qwen3-8b-dflash.json b/SpecForge-ext/configs/qwen3-8b-dflash.json new file mode 100644 index 0000000000000000000000000000000000000000..8e75aa2148481b08625a9ba98f192d6cfcb8cd33 --- /dev/null +++ b/SpecForge-ext/configs/qwen3-8b-dflash.json @@ -0,0 +1,41 @@ +{ + "architectures": [ + "DFlashDraftModel" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "auto_map": { + "AutoModel": "modeling_dflash.DFlashDraftModel" + }, + "block_size": 16, + "bos_token_id": 151643, + "dtype": "bfloat16", + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 12288, + "layer_types": [ + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention" + ], + "max_position_embeddings": 40960, + "max_window_layers": 5, + "model_type": "qwen3", + "num_attention_heads": 32, + "num_hidden_layers": 5, + "num_key_value_heads": 8, + "num_target_layers": 36, + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 1000000, + "sliding_window": null, + "tie_word_embeddings": false, + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 151936 +} diff --git a/SpecForge-ext/configs/qwen3-8b-eagle3.json b/SpecForge-ext/configs/qwen3-8b-eagle3.json new file mode 100644 index 0000000000000000000000000000000000000000..b1fa44906d6decad8ccee5c8296699b1db5750f1 --- /dev/null +++ b/SpecForge-ext/configs/qwen3-8b-eagle3.json @@ -0,0 +1,31 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 12288, + "max_position_embeddings": 40960, + 
"max_window_layers": 36, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 1, + "num_key_value_heads":8 , + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 1000000, + "sliding_window": null, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.51.0", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 151936, + "draft_vocab_size": 32000 +} diff --git a/SpecForge-ext/configs/qwen3-8b-qwen3eagle-5layer.json b/SpecForge-ext/configs/qwen3-8b-qwen3eagle-5layer.json new file mode 100644 index 0000000000000000000000000000000000000000..0f88baa3e88573155958c48af98d8e2620ea6e02 --- /dev/null +++ b/SpecForge-ext/configs/qwen3-8b-qwen3eagle-5layer.json @@ -0,0 +1,31 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 12288, + "max_position_embeddings": 40960, + "max_window_layers": 36, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 5, + "num_key_value_heads": 8, + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 1000000, + "sliding_window": null, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.51.0", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 151936, + "draft_vocab_size": 32000 +} diff --git a/SpecForge-ext/configs/qwen3-coder-30B-A3B-instruct-eagle3.json b/SpecForge-ext/configs/qwen3-coder-30B-A3B-instruct-eagle3.json new file mode 100644 index 0000000000000000000000000000000000000000..f296c237973a83f40f4540a97bbc193e2593bb44 --- /dev/null +++ b/SpecForge-ext/configs/qwen3-coder-30B-A3B-instruct-eagle3.json @@ -0,0 +1,31 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "attention_bias": false, + "attention_dropout": 0.0, + 
"bos_token_id": 151643, + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 2048, + "initializer_range": 0.02, + "intermediate_size": 12288, + "max_position_embeddings": 2048, + "max_window_layers": 48, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 1, + "num_key_value_heads": 4, + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 1000000.0, + "sliding_window": null, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.53.2", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 151936, + "draft_vocab_size": 32000 +} diff --git a/SpecForge-ext/configs/qwen3-coder-480B-A35B-instruct-eagle3.json b/SpecForge-ext/configs/qwen3-coder-480B-A35B-instruct-eagle3.json new file mode 100644 index 0000000000000000000000000000000000000000..2f27c80cc017e811f8846f2161a977725e669086 --- /dev/null +++ b/SpecForge-ext/configs/qwen3-coder-480B-A35B-instruct-eagle3.json @@ -0,0 +1,31 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 6144, + "initializer_range": 0.02, + "intermediate_size": 16384, + "max_position_embeddings": 262144, + "max_window_layers": 62, + "model_type": "llama", + "num_attention_heads": 96, + "num_hidden_layers": 1, + "num_key_value_heads":8, + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 1000000, + "sliding_window": null, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.51.0", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 151936, + "draft_vocab_size": 32000 +} diff --git a/SpecForge-ext/configs/qwen3-next-80b-a3b-eagle3.json b/SpecForge-ext/configs/qwen3-next-80b-a3b-eagle3.json new file mode 100644 index 
0000000000000000000000000000000000000000..e94a2ea3407d784ee9fbd4b6a15b96cd7cadfec8 --- /dev/null +++ b/SpecForge-ext/configs/qwen3-next-80b-a3b-eagle3.json @@ -0,0 +1,29 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "attention_dropout": 0.0, + "bos_token_id": 151643, + "decoder_sparse_step": 1, + "eos_token_id": 151645, + "head_dim": 256, + "hidden_act": "silu", + "hidden_size": 2048, + "initializer_range": 0.02, + "intermediate_size": 16384, + "max_position_embeddings": 262144, + "model_type": "llama", + "num_attention_heads": 16, + "num_hidden_layers": 1, + "num_key_value_heads": 2, + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 10000000, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.57.0.dev0", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 151936, + "draft_vocab_size": 32000 + } diff --git a/SpecForge-ext/configs/qwq-32B-eagle3.json b/SpecForge-ext/configs/qwq-32B-eagle3.json new file mode 100644 index 0000000000000000000000000000000000000000..8f7d7908d5433c886a1725c1ec456f032ba80202 --- /dev/null +++ b/SpecForge-ext/configs/qwq-32B-eagle3.json @@ -0,0 +1,28 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "hidden_act": "silu", + "hidden_size": 5120, + "initializer_range": 0.02, + "intermediate_size": 27648, + "max_position_embeddings": 40960, + "max_window_layers": 64, + "model_type": "qwen2", + "num_attention_heads": 40, + "num_hidden_layers": 1, + "num_key_value_heads": 8, + "rms_norm_eps": 1e-05, + "rope_theta": 1000000.0, + "sliding_window": 32768, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.43.1", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 152064, + "draft_vocab_size": 32000 +} diff --git a/SpecForge-ext/datasets/README.md b/SpecForge-ext/datasets/README.md new file mode 100644 index 
0000000000000000000000000000000000000000..8ddbef6d72d759dc06d8e59c15ef73c0ec29c204 --- /dev/null +++ b/SpecForge-ext/datasets/README.md @@ -0,0 +1,5 @@ +## Store Comprehensive Datasets Download Scripts + +| DatasetName | Github | Huggingface | command | +| -------- | -------- | -------- | -------- | +| ALLaVA-4V | [link](https://github.com/FreedomIntelligence/ALLaVA) | [link](https://huggingface.co/datasets/FreedomIntelligence/ALLaVA-4V) | download_laion.sh | diff --git a/SpecForge-ext/datasets/download_laion.sh b/SpecForge-ext/datasets/download_laion.sh new file mode 100644 index 0000000000000000000000000000000000000000..a64d061ebb5de06b2e87cfc3bcd2b38508b7009e --- /dev/null +++ b/SpecForge-ext/datasets/download_laion.sh @@ -0,0 +1,36 @@ + + +laion_root="allava_laion" + +mkdir $laion_root +cd $laion_root + + +# 1. download annotation files +## 1.1 caption +wget -c -O ALLaVA-Caption-LAION-4V.json https://huggingface.co/datasets/FreedomIntelligence/ALLaVA-4V/resolve/main/allava_laion/ALLaVA-Caption-LAION-4V.json?download=true + +## 1.2 instruction +wget -c -O ALLaVA-Instruct-LAION-4V.json https://huggingface.co/datasets/FreedomIntelligence/ALLaVA-4V/resolve/main/allava_laion/ALLaVA-Instruct-LAION-4V.json?download=true + + +# 2. download and upzip images +mkdir image_chunks + +## 2.1 download +for ((i=0; i<10; i++)) +do + wget -c -O image_chunks/images_$i.zip https://huggingface.co/datasets/FreedomIntelligence/ALLaVA-4V/resolve/main/allava_laion/image_chunks/images_$i.zip?download=true & +done + +mkdir -p images/ +wait + +## 2.2 unzip +for ((i=0; i<10; i++)) +do + unzip -j -o image_chunks/images_$i.zip -d images/ & # wait patiently, it takes a while... +done + +wait +echo "All done!" 
diff --git a/SpecForge-ext/docs/Makefile b/SpecForge-ext/docs/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..6b8792c428564ace773add1f751f7c2471a8fe83 --- /dev/null +++ b/SpecForge-ext/docs/Makefile @@ -0,0 +1,58 @@ +# Minimal Makefile for Sphinx documentation +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SPHINXAUTOBUILD ?= sphinx-autobuild +SOURCEDIR = . +BUILDDIR = _build +PORT ?= 8003 + +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + @echo "" + @echo "Additional targets:" + @echo " serve to build and serve documentation with auto-build and live reload" + +# Compile Notebook files and record execution time +compile: + @set -e; \ + echo "Starting Notebook compilation..."; \ + mkdir -p logs; \ + echo "Notebook execution timings:" > logs/timing.log; \ + START_TOTAL=$$(date +%s); \ + find $(SOURCEDIR) -path "*/_build/*" -prune -o -name "*.ipynb" -print0 | \ + parallel -0 -j3 --halt soon,fail=1 ' \ + NB_NAME=$$(basename {}); \ + START_TIME=$$(date +%s); \ + retry --delay=0 --times=2 -- \ + jupyter nbconvert --to notebook --execute --inplace "{}" \ + --ExecutePreprocessor.timeout=600 \ + --ExecutePreprocessor.kernel_name=python3; \ + RET_CODE=$$?; \ + END_TIME=$$(date +%s); \ + ELAPSED_TIME=$$((END_TIME - START_TIME)); \ + echo "$${NB_NAME}: $${ELAPSED_TIME}s" >> logs/timing.log; \ + exit $$RET_CODE' || exit 1; \ + END_TOTAL=$$(date +%s); \ + TOTAL_ELAPSED=$$((END_TOTAL - START_TOTAL)); \ + echo "---------------------------------" >> logs/timing.log; \ + echo "Total execution time: $${TOTAL_ELAPSED}s" >> logs/timing.log; \ + echo "All Notebook execution timings:" && cat logs/timing.log + +# Serve documentation with auto-build and live reload +serve: + @echo "Starting auto-build server at http://0.0.0.0:$(PORT)" + @$(SPHINXAUTOBUILD) "$(SOURCEDIR)" "$(BUILDDIR)/html" \ + --host 0.0.0.0 \ + --port $(PORT) \ + --watch $(SOURCEDIR) \ + --re-ignore ".*\.(ipynb_checkpoints|pyc|pyo|pyd|git)" + +.PHONY: help 
Makefile compile clean serve + +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +clean: + find . -name "*.ipynb" -exec nbstripout {} \; + rm -rf $(BUILDDIR) + rm -rf logs diff --git a/SpecForge-ext/docs/README.md b/SpecForge-ext/docs/README.md new file mode 100644 index 0000000000000000000000000000000000000000..592f0e51a0f9be1b4aa959867fb526ed4003c149 --- /dev/null +++ b/SpecForge-ext/docs/README.md @@ -0,0 +1,55 @@ +# SpecForge Documentation + +We recommend new contributors to start from writing documentation, which helps you quickly understand the SpecForge codebase. +Most documentation files are located under the `docs/` folder. + +## Docs Workflow + +### Install Dependency + +```bash +apt-get update && apt-get install -y pandoc parallel retry +pip install -r requirements.txt +``` + +### Update Documentation + +Update your Jupyter notebooks in the appropriate subdirectories under `docs/`. If you add new files, remember to update `index.rst` (or relevant `.rst` files) accordingly. + +- **`pre-commit run --all-files`** manually runs all configured checks, applying fixes if possible. If it fails the first time, re-run it to ensure lint errors are fully resolved. Make sure your code passes all checks **before** creating a Pull Request. + +```bash +# 1) Compile all Jupyter notebooks +make compile # This step can take a long time (10+ mins). You can consider skipping this step if you can make sure your added files are correct. +make html + +# 2) Compile and Preview documentation locally with auto-build +# This will automatically rebuild docs when files change +# Open your browser at the displayed port to view the docs +bash serve.sh + +# 2a) Alternative ways to serve documentation +# Directly use make serve +make serve +# With custom port +PORT=8080 make serve + +# 3) Clean notebook outputs +# nbstripout removes notebook outputs so your PR stays clean +pip install nbstripout +find . 
-name '*.ipynb' -exec nbstripout {} \; + +# 4) Pre-commit checks and create a PR +# After these checks pass, push your changes and open a PR on your branch +pre-commit run --all-files +``` +--- + +## Documentation Style Guidelines + +- For common functionalities, we prefer **Jupyter Notebooks** over Markdown so that all examples can be executed and validated by our docs CI pipeline. For complex features (e.g., distributed serving), Markdown is preferred. +- Keep in mind the documentation execution time when writing interactive Jupyter notebooks. Each interactive notebook will be run and compiled against every commit to ensure they are runnable, so it is important to apply some tips to reduce the documentation compilation time: + - Use small models (e.g., `qwen/qwen2.5-0.5b-instruct`) for most cases to reduce server launch time. + - Reuse the launched server as much as possible to reduce server launch time. +- Do not use absolute links (e.g., `https://docs.sglang.ai/get_started/install.html`). Always prefer relative links (e.g., `../get_started/install.md`). +- Follow the existing examples to learn how to launch a server, send a query and other common styles. 
diff --git a/SpecForge-ext/docs/conf.py b/SpecForge-ext/docs/conf.py new file mode 100644 index 0000000000000000000000000000000000000000..f1fef2396e931693259e82aee2e78cdb77d6c256 --- /dev/null +++ b/SpecForge-ext/docs/conf.py @@ -0,0 +1,188 @@ +import os +import sys +from datetime import datetime +from pathlib import Path + +sys.path.insert(0, os.path.abspath("../..")) + +DOCS_PATH = Path(__file__).parent +ROOT_PATH = DOCS_PATH.parent + +version_file = ROOT_PATH.joinpath("version.txt") +with open(version_file, "r") as f: + __version__ = f.read().strip() + +project = "SGLang" +copyright = f"2025-{datetime.now().year}, SpecForge" +author = "SpecForge Team" + +version = __version__ +release = __version__ + +extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "sphinx.ext.napoleon", + "sphinx.ext.viewcode", + "sphinx.ext.autosectionlabel", + "sphinx.ext.intersphinx", + "sphinx_tabs.tabs", + "myst_parser", + "sphinx_copybutton", + "sphinxcontrib.mermaid", + "nbsphinx", + "sphinx.ext.mathjax", +] + +nbsphinx_allow_errors = True +nbsphinx_execute = "never" + +autosectionlabel_prefix_document = True +nbsphinx_allow_directives = True + + +myst_enable_extensions = [ + "dollarmath", + "amsmath", + "deflist", + "colon_fence", + "html_image", + "substitution", +] + +myst_heading_anchors = 5 + +nbsphinx_kernel_name = "python3" +nbsphinx_execute_arguments = [ + "--InlineBackend.figure_formats={'svg', 'pdf'}", + "--InlineBackend.rc={'figure.dpi': 96}", +] + + +nb_render_priority = { + "html": ( + "application/vnd.jupyter.widget-view+json", + "application/javascript", + "text/html", + "image/svg+xml", + "image/png", + "image/jpeg", + "text/markdown", + "text/latex", + "text/plain", + ) +} + +myst_ref_domains = ["std", "py"] + +templates_path = ["_templates"] + +source_suffix = { + ".rst": "restructuredtext", + ".md": "markdown", +} + +master_doc = "index" + +language = "en" + +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] + +pygments_style = "sphinx" + 
+html_theme = "sphinx_book_theme" +html_logo = ROOT_PATH.joinpath("assets/logo.png").as_posix() +html_favicon = ROOT_PATH.joinpath("assets/logo.ico").as_posix() +html_title = project +html_copy_source = True +html_last_updated_fmt = "" + +html_theme_options = { + "repository_url": "https://github.com/sgl-project/sgl-project.github.io", + "repository_branch": "main", + "show_navbar_depth": 3, + "max_navbar_depth": 4, + "collapse_navbar": True, + "use_edit_page_button": True, + "use_source_button": True, + "use_issues_button": True, + "use_repository_button": True, + "use_download_button": True, + "use_sidenotes": True, + "show_toc_level": 2, +} + +html_context = { + "display_github": True, + "github_user": "sgl-project", + "github_repo": "sgl-project.github.io", + "github_version": "main", + "conf_py_path": "/docs/", +} + +html_static_path = ["_static", "spec_bundle/public"] +html_css_files = ["css/custom_log.css"] + + +def setup(app): + app.add_css_file("css/custom_log.css") + + +htmlhelp_basename = "sglangdoc" + +latex_elements = {} + +latex_documents = [ + (master_doc, "sglang.tex", "sglang Documentation", "SGLang Team", "manual"), +] + +man_pages = [(master_doc, "sglang", "sglang Documentation", [author], 1)] + +texinfo_documents = [ + ( + master_doc, + "sglang", + "sglang Documentation", + author, + "sglang", + "One line description of project.", + "Miscellaneous", + ), +] + +epub_title = project + +epub_exclude_files = ["search.html"] + +copybutton_prompt_text = r">>> |\.\.\. 
" +copybutton_prompt_is_regexp = True + +autodoc_preserve_defaults = True +navigation_with_keys = False + +autodoc_mock_imports = [ + "torch", + "transformers", + "triton", +] + +intersphinx_mapping = { + "python": ("https://docs.python.org/3.12", None), + "typing_extensions": ("https://typing-extensions.readthedocs.io/en/latest", None), + "pillow": ("https://pillow.readthedocs.io/en/stable", None), + "numpy": ("https://numpy.org/doc/stable", None), + "torch": ("https://pytorch.org/docs/stable", None), +} + +html_theme = "sphinx_book_theme" + + +nbsphinx_prolog = """ +.. raw:: html + + +""" diff --git a/SpecForge-ext/docs/deploy.py b/SpecForge-ext/docs/deploy.py new file mode 100644 index 0000000000000000000000000000000000000000..75b7ea7f23dce0a5deb17c28d78b5cc59833a4d6 --- /dev/null +++ b/SpecForge-ext/docs/deploy.py @@ -0,0 +1,22 @@ +# Deploy the documents + +import os +from datetime import datetime + + +def run_cmd(cmd): + print(cmd) + os.system(cmd) + + +run_cmd("cd $DOC_SITE_PATH; git pull") + +# (Optional) Remove old files +# run_cmd("rm -rf $ALPA_SITE_PATH/*") + +run_cmd("cp -r _build/html/* $DOC_SITE_PATH") + +cmd_message = f"Update {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" +run_cmd( + f"cd $DOC_SITE_PATH; git add .; git commit -m '{cmd_message}'; git push origin main" +) diff --git a/SpecForge-ext/docs/index.rst b/SpecForge-ext/docs/index.rst new file mode 100644 index 0000000000000000000000000000000000000000..bc2c694798793eddd894f5bd94fde539b9fb06b8 --- /dev/null +++ b/SpecForge-ext/docs/index.rst @@ -0,0 +1,53 @@ +SpecForge Documentation +======================= + +SpecForge is an ecosystem project developed by the SGLang team. It is a framework for training speculative decoding models so that you can smoothly port them over to the SGLang serving framework to speed up your inference. + + +.. toctree:: + :maxdepth: 1 + :caption: Get Started + + get_started/installation.md + get_started/about.md + +.. 
toctree:: + :maxdepth: 1 + :caption: Concepts + + concepts/speculative_decoding.md + concepts/EAGLE3.md + + +.. toctree:: + :maxdepth: 1 + :caption: Basic Usage + + basic_usage/data_preparation.md + basic_usage/training.md + +.. toctree:: + :maxdepth: 1 + :caption: Advanced Features + + advanced_features/customization.md + +.. toctree:: + :maxdepth: 1 + :caption: Community Resources + + community_resources/specbundle.md + community_resources/dashboard.md + +.. toctree:: + :maxdepth: 1 + :caption: Examples + + examples/llama3-eagle3-online.md + examples/llama3-eagle3-offline.md + +.. toctree:: + :maxdepth: 1 + :caption: Benchmarks + + benchmarks/benchmark.md diff --git a/SpecForge-ext/docs/requirements.txt b/SpecForge-ext/docs/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..1a7e5d4eba2f265cb2dce4eff31d770eb71125f3 --- /dev/null +++ b/SpecForge-ext/docs/requirements.txt @@ -0,0 +1,20 @@ +ipykernel +ipywidgets +jupyter_client +markdown>=3.4.0 +matplotlib +myst-parser +nbconvert +nbsphinx +pandoc +pillow +pydantic +sphinx +sphinx-book-theme +sphinx-copybutton +sphinx-tabs +nbstripout +sphinxcontrib-mermaid +urllib3<2.0.0 +gguf>=0.10.0 +sphinx-autobuild diff --git a/SpecForge-ext/docs/serve.sh b/SpecForge-ext/docs/serve.sh new file mode 100644 index 0000000000000000000000000000000000000000..049f767cf497a5fd92b1dac0af2fc13fdcf3fa69 --- /dev/null +++ b/SpecForge-ext/docs/serve.sh @@ -0,0 +1,3 @@ +# Clean and serve documentation with auto-build +make clean +make serve diff --git a/SpecForge-ext/examples/run_deepseek_v3_671b_eagle3_online.sh b/SpecForge-ext/examples/run_deepseek_v3_671b_eagle3_online.sh new file mode 100644 index 0000000000000000000000000000000000000000..2eb2769f9b5582a811a83305b72ea67bef5b514b --- /dev/null +++ b/SpecForge-ext/examples/run_deepseek_v3_671b_eagle3_online.sh @@ -0,0 +1,29 @@ + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) + +# train 
eagle3 for deepseek-v3 +NUM_GPUS=${1:-8} +TP_SIZE=${2:-8} +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +# train eagle3 online +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path deepseek-ai/DeepSeek-V3 \ + --draft-model-config $ROOT_DIR/configs/deepseek-v3-671b-eagle3.json \ + --train-data-path $ROOT_DIR/cache/dataset/perfect-blend.jsonl \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir $ROOT_DIR/outputs/deepseek-v3-671B-eagle3-perfect-blend-online \ + --tp-size $TP_SIZE \ + --target-model-backend sglang \ + --num-epochs 10 \ + --batch-size 1 \ + --learning-rate 5e-5 \ + --max-length 2048 \ + --chat-template deepseek-v3 \ + --cache-dir $ROOT_DIR/cache \ + --dist-timeout 60 \ + --sglang-mem-fraction-static 0.75 diff --git a/SpecForge-ext/examples/run_qwen3_30b_a3b_eagle3_online.sh b/SpecForge-ext/examples/run_qwen3_30b_a3b_eagle3_online.sh new file mode 100644 index 0000000000000000000000000000000000000000..29b5ac167b044ea321e8f25a5f6f0f5b088dc90c --- /dev/null +++ b/SpecForge-ext/examples/run_qwen3_30b_a3b_eagle3_online.sh @@ -0,0 +1,29 @@ +#!/bin/bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) +export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels + +# support tp4/tp8 train eagle3 for Qwen3-30B-A3B +NUM_GPUS=${1:-4} +TP_SIZE=${2:-4} +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path Qwen/Qwen3-30B-A3B-Instruct-2507 \ + --draft-model-config $ROOT_DIR/configs/qwen3-30B-A3B-eagle3.json \ + --train-data-path $ROOT_DIR/cache/dataset/sharegpt_train.jsonl \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir $ROOT_DIR/outputs/qwen3-30b-a3b-instruct-eagle3-sharegpt \ + --num-epochs 10 \ + --batch-size 1 \ + --learning-rate 1e-4 \ + --max-length 4096 \ + 
--chat-template qwen \ + --cache-dir $ROOT_DIR/cache \ + --embedding-key model.embed_tokens.weight \ + --tp-size $TP_SIZE \ + --target-model-backend sglang diff --git a/SpecForge-ext/examples/run_qwq_eagle3_online.sh b/SpecForge-ext/examples/run_qwq_eagle3_online.sh new file mode 100644 index 0000000000000000000000000000000000000000..2b2fae6f19bba8a55a93e134f4cc848e18767d02 --- /dev/null +++ b/SpecForge-ext/examples/run_qwq_eagle3_online.sh @@ -0,0 +1,28 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) + +export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels + +# train eagle3 for qwq-32b +NUM_GPUS=${1:-4} +TP_SIZE=${2:-4} +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path Qwen/QwQ-32B \ + --draft-model-config $ROOT_DIR/configs/qwq-32B-eagle3.json \ + --train-data-path $ROOT_DIR/cache/dataset/sharegpt_train.jsonl \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir $ROOT_DIR/outputs/qwq-32b-eagle3-sharegpt \ + --num-epochs 10 \ + --batch-size 1 \ + --learning-rate 1e-4 \ + --max-length 4096 \ + --chat-template qwen \ + --cache-dir $ROOT_DIR/cache \ + --embedding-key model.embed_tokens.weight \ + --tp-size $TP_SIZE \ + --target-model-backend sglang diff --git a/SpecForge-ext/logs/baseline_gsm8k_20260213_100853.log b/SpecForge-ext/logs/baseline_gsm8k_20260213_100853.log new file mode 100644 index 0000000000000000000000000000000000000000..186b22b1934582a35a3c0e5a3f4dfde886d86358 --- /dev/null +++ b/SpecForge-ext/logs/baseline_gsm8k_20260213_100853.log @@ -0,0 +1,5 @@ +WARNING:sglang.srt.server_args:Attention backend not explicitly specified. Use fa3 backend by default. 
+Running benchmark gsm8k with 100 prompts, batch size 1, steps None, topk None, num_draft_tokens None, subset None +Loading GSM8K data from local: /workspace/hanrui/datasets/gsm8k/test.jsonl + 0%| | 0/100 [00:00 None: + parser.add_argument( + "--report-to", + type=str, + default="none", + choices=["wandb", "tensorboard", "swanlab", "mlflow", "none"], + help="The integration to report results and logs to.", + ) + # wandb-specific args + parser.add_argument("--wandb-project", type=str, default=None) + parser.add_argument("--wandb-name", type=str, default=None) + parser.add_argument("--wandb-key", type=str, default=None, help="W&B API key.") + # swanlab-specific args + parser.add_argument( + "--swanlab-project", + type=str, + default=None, + help="The project name for swanlab.", + ) + parser.add_argument( + "--swanlab-name", + type=str, + default=None, + help="The experiment name for swanlab.", + ) + parser.add_argument( + "--swanlab-key", + type=str, + default=None, + help="The API key for swanlab non-interactive login.", + ) + # mlflow-specific args + parser.add_argument( + "--mlflow-tracking-uri", + type=str, + default=None, + help="The MLflow tracking URI. If not set, uses MLFLOW_TRACKING_URI environment variable or defaults to local './mlruns'.", + ) + parser.add_argument( + "--mlflow-experiment-name", + type=str, + default=None, + help="The MLflow experiment name. If not set, uses MLFLOW_EXPERIMENT_NAME environment variable.", + ) + parser.add_argument( + "--mlflow-run-name", + type=str, + default=None, + help="The MLflow run name. 
If not set, MLflow will auto-generate one.", + ) + + +@dataclass +class SGLangBackendArgs: + sglang_attention_backend: str = "fa3" + sglang_mem_fraction_static: float = 0.4 + sglang_context_length: int = None + sglang_enable_nccl_nvls: bool = False + sglang_enable_symm_mem: bool = False + sglang_enable_torch_compile: bool = True + sglang_enable_dp_attention: bool = False + sglang_enable_dp_lm_head: bool = False + sglang_enable_piecewise_cuda_graph: bool = False + sglang_piecewise_cuda_graph_max_tokens: int = 4096 + sglang_piecewise_cuda_graph_tokens: List[int] = None + sglang_ep_size: int = 1 + sglang_max_running_requests: int = None # assign based on batch size + sglang_max_total_tokens: int = None # assign based on batch size and seq length + + @staticmethod + def add_args(parser: argparse.ArgumentParser) -> None: + # sglang arguments + parser.add_argument( + "--sglang-attention-backend", + type=str, + default="flashinfer", + choices=ATTENTION_BACKEND_CHOICES, + help="The attention backend of SGLang backend", + ) + parser.add_argument( + "--sglang-mem-fraction-static", + type=float, + default=0.4, + help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). 
Use a smaller value if you see out-of-memory errors.", + ) + parser.add_argument( + "--sglang-context-length", + type=int, + default=None, + help="The context length of the SGLang backend", + ) + parser.add_argument( + "--sglang-enable-nccl-nvls", + action="store_true", + help="Enable NCCL NVLS for prefill heavy requests when available for SGLang backend", + ) + parser.add_argument( + "--sglang-enable-symm-mem", + action="store_true", + help="Enable NCCL symmetric memory for fast collectives for SGLang backend", + ) + parser.add_argument( + "--sglang-enable-torch-compile", + action="store_true", + help="Optimize the model with torch.compile for SGLang backend", + ) + parser.add_argument( + "--sglang-enable-dp-attention", + action="store_true", + help="Enable DP attention for SGLang backend", + ) + parser.add_argument( + "--sglang-enable-dp-lm-head", + action="store_true", + help="Enable piecewise CUDA graph for SGLang backend", + ) + parser.add_argument( + "--sglang-enable-piecewise-cuda-graph", + action="store_true", + help="Enable piecewise CUDA graph for SGLang backend's prefill", + ) + parser.add_argument( + "--sglang-piecewise-cuda-graph-max-tokens", + type=int, + default=4096, + help="Set the max tokens for piecewise CUDA graph for SGLang backend", + ) + parser.add_argument( + "--sglang-piecewise-cuda-graph-tokens", + type=int, + nargs="+", + default=None, + help="Set the list of tokens when using piecewise cuda graph for SGLang backend", + ) + parser.add_argument( + "--sglang-ep-size", + type=int, + default=1, + help="The ep size of the SGLang backend", + ) + + @staticmethod + def from_args(args: argparse.Namespace) -> "SGLangBackendArgs": + return SGLangBackendArgs( + sglang_attention_backend=args.sglang_attention_backend, + sglang_mem_fraction_static=args.sglang_mem_fraction_static, + sglang_context_length=args.sglang_context_length, + sglang_enable_nccl_nvls=args.sglang_enable_nccl_nvls, + sglang_enable_symm_mem=args.sglang_enable_symm_mem, + 
sglang_enable_torch_compile=args.sglang_enable_torch_compile, + sglang_enable_dp_attention=args.sglang_enable_dp_attention, + sglang_enable_dp_lm_head=args.sglang_enable_dp_lm_head, + sglang_enable_piecewise_cuda_graph=args.sglang_enable_piecewise_cuda_graph, + sglang_piecewise_cuda_graph_max_tokens=args.sglang_piecewise_cuda_graph_max_tokens, + sglang_piecewise_cuda_graph_tokens=args.sglang_piecewise_cuda_graph_tokens, + sglang_ep_size=args.sglang_ep_size, + sglang_max_running_requests=( + args.target_batch_size if hasattr(args, "target_batch_size") else None + ), + sglang_max_total_tokens=( + args.target_batch_size * args.max_length + if hasattr(args, "target_batch_size") and hasattr(args, "max_length") + else None + ), + ) + + def to_kwargs(self) -> Dict[str, Any]: + return dict( + attention_backend=self.sglang_attention_backend, + mem_fraction_static=self.sglang_mem_fraction_static, + context_length=self.sglang_context_length, + enable_nccl_nvls=self.sglang_enable_nccl_nvls, + enable_symm_mem=self.sglang_enable_symm_mem, + enable_torch_compile=self.sglang_enable_torch_compile, + enable_dp_attention=self.sglang_enable_dp_attention, + enable_dp_lm_head=self.sglang_enable_dp_lm_head, + enable_piecewise_cuda_graph=self.sglang_enable_piecewise_cuda_graph, + piecewise_cuda_graph_max_tokens=self.sglang_piecewise_cuda_graph_max_tokens, + piecewise_cuda_graph_tokens=self.sglang_piecewise_cuda_graph_tokens, + ep_size=self.sglang_ep_size, + max_running_requests=self.sglang_max_running_requests, + max_total_tokens=self.sglang_max_total_tokens, + ) diff --git a/SpecForge-ext/specforge/distributed.py b/SpecForge-ext/specforge/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..fb5e882c4d69bc2cf8e03afe4fc05f3d60bdc3c6 --- /dev/null +++ b/SpecForge-ext/specforge/distributed.py @@ -0,0 +1,245 @@ +from datetime import timedelta +from typing import Any, Optional + +import torch +import torch.distributed as dist +from yunchang.globals import 
PROCESS_GROUP, set_seq_parallel_pg + +from specforge.utils import print_with_rank + +_DEVICE_MESH = None +_TP_DEVICE_MESH = None +_TP_GROUP = None +_DP_DEVICE_MESH = None +_DP_GROUP = None +_DRAFT_DP_GROUP = None +_DRAFT_SP_GROUP = None +_SP_ULYSSES_GROUP = None +_SP_RING_GROUP = None + + +def get_tp_group(): + global _TP_GROUP + return _TP_GROUP + + +def get_dp_group(): + global _DP_GROUP + return _DP_GROUP + + +def get_draft_dp_group(): + global _DRAFT_DP_GROUP + return _DRAFT_DP_GROUP + + +def get_draft_sp_group(): + global _DRAFT_SP_GROUP + return _DRAFT_SP_GROUP + + +def get_device_mesh(): + global _DEVICE_MESH + return _DEVICE_MESH + + +def get_tp_device_mesh(): + global _TP_DEVICE_MESH + return _TP_DEVICE_MESH + + +def get_dp_device_mesh(): + global _DP_DEVICE_MESH + return _DP_DEVICE_MESH + + +def get_sp_ulysses_group(): + global _SP_ULYSSES_GROUP + return _SP_ULYSSES_GROUP + + +def get_sp_ring_group(): + global _SP_RING_GROUP + return _SP_RING_GROUP + + +def init_distributed( + timeout: int = 10, tp_size: int = 1, sp_ulysses_size: int = 1, sp_ring_size: int = 1 +): + """Initialize distributed training. 
+ + Args: + timeout(int): Timeout for collective communication in minutes + tp_size(int): The degree of tensor parallelism + """ + dist.init_process_group(backend="nccl", timeout=timedelta(minutes=timeout)) + local_rank = dist.get_rank() % torch.cuda.device_count() + torch.cuda.set_device(local_rank) + print_with_rank(f"bind to device {local_rank}") + + world_size = dist.get_world_size() + dp_size = world_size // tp_size + assert ( + world_size == tp_size * dp_size + ), f"world size must be divisible by tp size, now {world_size=}, {(tp_size * dp_size)=} " + + device_mesh = dist.device_mesh.init_device_mesh( + "cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp") + ) + + assert ( + world_size % (sp_ulysses_size * sp_ring_size) == 0 + ), f"World size ({world_size}) cannot be evenly divided by total SP size ({sp_ulysses_size*sp_ring_size})" + + draft_dp_size = world_size // (sp_ulysses_size * sp_ring_size) + draft_device_mesh = dist.device_mesh.init_device_mesh( + "cuda", + (draft_dp_size, sp_ulysses_size * sp_ring_size), + mesh_dim_names=("draft_dp", "sp"), + ) + set_seq_parallel_pg(sp_ulysses_size, sp_ring_size, dist.get_rank(), world_size) + + print_with_rank(f"device mesh: {device_mesh}") + tp_group = device_mesh.get_group("tp") + dp_group = device_mesh.get_group("dp") + + sp_ulysses_group = PROCESS_GROUP.ULYSSES_PG + sp_ring_group = PROCESS_GROUP.RING_PG + # we need to create a 1D submesh + tp_device_mesh = dist.DeviceMesh.from_group(tp_group, device_type="cuda") + + global _TP_GROUP, _DP_GROUP, _DEVICE_MESH, _TP_DEVICE_MESH, _DP_DEVICE_MESH, _SP_RING_GROUP, _SP_ULYSSES_GROUP, _DRAFT_DP_GROUP, _DRAFT_SP_GROUP + _DEVICE_MESH = device_mesh + _TP_GROUP = tp_group + _TP_DEVICE_MESH = tp_device_mesh + _SP_ULYSSES_GROUP = sp_ulysses_group + _SP_RING_GROUP = sp_ring_group + _DP_GROUP = dp_group + _DRAFT_DP_GROUP = draft_device_mesh.get_group("draft_dp") + _DRAFT_SP_GROUP = draft_device_mesh.get_group("sp") + _DP_DEVICE_MESH = dist.DeviceMesh.from_group(dp_group, 
device_type="cuda") + + +def destroy_distributed(): + global _TP_GROUP, _DP_GROUP, _SP_ULYSSES_GROUP, _SP_RING_GROUP, _DRAFT_DP_GROUP + dist.destroy_process_group(_TP_GROUP) + dist.destroy_process_group(_DP_GROUP) + dist.destroy_process_group(_SP_ULYSSES_GROUP) + dist.destroy_process_group(_SP_RING_GROUP) + dist.destroy_process_group(_DRAFT_DP_GROUP) + dist.destroy_process_group(_DRAFT_SP_GROUP) + dist.destroy_process_group() + + +def shard_tensor( + tensor: torch.Tensor, process_group: dist.ProcessGroup = None, dim: int = -1 +) -> torch.Tensor: + rank = dist.get_rank(process_group) + size = dist.get_world_size(process_group) + return tensor.chunk(size, dim=dim)[rank].contiguous() + + +def gather_tensor( + tensor: torch.Tensor, process_group: dist.ProcessGroup = None, dim: int = -1 +) -> torch.Tensor: + size = dist.get_world_size(process_group) + obj_list = [torch.empty_like(tensor) for _ in range(size)] + dist.all_gather(obj_list, tensor, group=process_group) + gather_tensor = torch.cat(obj_list, dim=dim) + return gather_tensor + + +def all_gather_tensor( + local_tensor: torch.Tensor, + group: Optional[dist.ProcessGroup] = None, + async_op: bool = False, +): + sp_world_size = dist.get_world_size(group=group) + output_shape = list(local_tensor.shape) + output_shape[0] = output_shape[0] * sp_world_size + output = torch.empty( + output_shape, dtype=local_tensor.dtype, device=local_tensor.device + ) + dist.all_gather_into_tensor(output, local_tensor, group=group, async_op=async_op) + return output + + +# Adapted from https://github.com/volcengine/verl/blob/a0e8e4472b8b472409defb0c8fcc5162301450af/verl/utils/ulysses.py#L194 +class Gather(torch.autograd.Function): + @staticmethod + def forward( + ctx: Any, + group: dist.ProcessGroup, + local_tensor: torch.Tensor, + gather_dim: int, + grad_scaler: bool = True, + async_op=False, + ) -> torch.Tensor: + ctx.group = group + ctx.gather_dim = gather_dim + ctx.grad_scaler = grad_scaler + ctx.async_op = async_op + + 
sp_world_size = dist.get_world_size(group=group) + ctx.sp_world_size = sp_world_size + + sp_rank = dist.get_rank(group=group) + ctx.sp_rank = sp_rank + + local_shape = list(local_tensor.size()) + split_size = local_shape[0] + part_size = local_shape[gather_dim] # store original size + ctx.part_size = part_size + + output = all_gather_tensor(local_tensor, group, async_op) + return torch.cat(output.split(split_size, dim=0), dim=gather_dim) + + @staticmethod + def backward(ctx: Any, grad_output: torch.Tensor) -> Any: + if ctx.grad_scaler: + grad_output = grad_output * ctx.sp_world_size + return ( + None, + grad_output.split(ctx.part_size, dim=ctx.gather_dim)[ + ctx.sp_rank + ].contiguous(), + None, + None, + None, + None, + ) + + +def gather_outputs_and_unpad( + x: torch.Tensor, + gather_dim: int, + grad_scaler: bool = True, + group: Optional[dist.ProcessGroup] = None, +): + """ + Gather a tensor across a process group and optionally unpad its padded elements. + + Args: + x (Tensor): Input tensor to gather. + gather_dim (int): Dimension along which to gather across ranks. + grad_scaler (bool): Whether to apply gradient scaling during gather. Defaults to True. + group (ProcessGroup, optional): Process group for gathering. If None, uses + `get_ulysses_sequence_parallel_group()`. If still None, returns `x` unchanged. + + Returns: + Tensor: The gathered tensor, with padding removed if requested. 
+ """ + if not group: + group = get_draft_sp_group() + if torch.distributed.get_world_size(group) == 1: + return x + x = Gather.apply(group, x, gather_dim, grad_scaler) + return x + + +def is_tp_rank_0(): + """Return True if current process is rank 0 in its TP group.""" + tp_group = get_tp_group() + if tp_group is None: + return True + return dist.get_rank(group=tp_group) == 0 diff --git a/SpecForge-ext/specforge/lr_scheduler.py b/SpecForge-ext/specforge/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..4d3276c79806a13c5223f0042ffc3abf6cad52ce --- /dev/null +++ b/SpecForge-ext/specforge/lr_scheduler.py @@ -0,0 +1,260 @@ +from warnings import warn + +from torch.optim.lr_scheduler import CosineAnnealingLR as _CosineAnnealingLR +from torch.optim.lr_scheduler import LRScheduler as _LRScheduler + + +class _enable_get_lr_call: + def __init__(self, o): + self.o = o + + def __enter__(self): + self.o._get_lr_called_within_step = True + return self + + def __exit__(self, type, value, traceback): + self.o._get_lr_called_within_step = False + + +class TwoStageScheduler(_LRScheduler): + def __init__(self, optimizer, after_scheduler: _LRScheduler, last_epoch=-1): + self.after_scheduler = after_scheduler + self.finished = False + super().__init__(optimizer, last_epoch) + + def state_dict(self): + state_dict = { + key: value for key, value in self.__dict__.items() if key not in "optimizer" + } + if isinstance(state_dict["after_scheduler"], _LRScheduler): + state_dict["after_scheduler_type"] = type( + state_dict["after_scheduler"] + ).__name__ + state_dict["after_scheduler_dict"] = state_dict[ + "after_scheduler" + ].state_dict() + del state_dict["after_scheduler"] + else: + raise NotImplementedError() + return state_dict + + def load_state_dict(self, state_dict): + if "after_scheduler_dict" not in state_dict: + warn( + "after_scheduler_dict is not found, skip loading after_scheduler. This may cause unexpected behavior." 
+ ) + else: + self.after_scheduler.load_state_dict(state_dict["after_scheduler_dict"]) + state_dict = { + key: value + for key, value in state_dict.items() + if key not in ("after_scheduler_type", "after_scheduler_dict") + } + super().load_state_dict(state_dict) + + +class DelayerScheduler(TwoStageScheduler): + """Starts with a flat lr schedule until it reaches N epochs then applies + the specific scheduler (For example: ReduceLROnPlateau) + + Args: + optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer. + delay_epochs (int): Number of epochs to keep the initial lr until starting applying the scheduler. + after_scheduler (:class:`torch.optim.lr_scheduler`): After target_epoch, use this scheduler. + last_epoch (int, optional): The index of last epoch, defaults to -1. When last_epoch=-1, + the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. + """ + + def __init__(self, optimizer, delay_epochs, after_scheduler, last_epoch=-1): + if delay_epochs < 0: + raise ValueError(f"delay_epochs must >= 0, got {delay_epochs}") + self.delay_epochs = delay_epochs + super().__init__(optimizer, after_scheduler, last_epoch) + + def get_lr(self): + if self.last_epoch >= self.delay_epochs: + if not self.finished: + self.after_scheduler.base_lrs = self.base_lrs + self.finished = True + with _enable_get_lr_call(self.after_scheduler): + return self.after_scheduler.get_lr() + + return self.base_lrs + + def step(self, epoch=None): + if self.finished: + if epoch is None: + self.after_scheduler.step(None) + self._last_lr = self.after_scheduler.get_last_lr() + else: + self.after_scheduler.step(epoch - self.delay_epochs) + self._last_lr = self.after_scheduler.get_last_lr() + else: + return super(DelayerScheduler, self).step(epoch) + + +class WarmupScheduler(TwoStageScheduler): + """Starts with a linear warmup lr schedule until it reaches N epochs then applies + the specific scheduler (For example: ReduceLROnPlateau). 
+ + Args: + optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer. + warmup_epochs (int): Number of epochs to linearly warmup lr until starting applying the scheduler. + after_scheduler (:class:`torch.optim.lr_scheduler`): After target_epoch, use this scheduler. + last_epoch (int, optional): The index of last epoch, defaults to -1. When last_epoch=-1, + the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. + """ + + def __init__(self, optimizer, warmup_epochs, after_scheduler, last_epoch=-1): + self.warmup_epochs = int(warmup_epochs) + super().__init__(optimizer, after_scheduler, last_epoch) + + def get_lr(self): + if self.last_epoch >= self.warmup_epochs: + if not self.finished: + self.after_scheduler.base_lrs = self.base_lrs + self.finished = True + return self.after_scheduler.get_lr() + + return [(self.last_epoch + 1) / self.warmup_epochs * lr for lr in self.base_lrs] + + def step(self, epoch=None): + if self.finished: + if epoch is None: + self.after_scheduler.step(None) + self._last_lr = self.after_scheduler.get_last_lr() + else: + self.after_scheduler.step(epoch - self.warmup_epochs) + self._last_lr = self.after_scheduler.get_last_lr() + else: + return super().step(epoch) + + +class WarmupDelayerScheduler(TwoStageScheduler): + """Starts with a linear warmup lr schedule until it reaches N epochs and a flat lr schedule + until it reaches M epochs then applies the specific scheduler (For example: ReduceLROnPlateau). + + Args: + optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer. + warmup_epochs (int): Number of epochs to linearly warmup lr until starting applying the scheduler. + delay_epochs (int): Number of epochs to keep the initial lr until starting applying the scheduler. + after_scheduler (:class:`torch.optim.lr_scheduler`): After target_epoch, use this scheduler. + last_epoch (int, optional): The index of last epoch, defaults to -1. 
When last_epoch=-1, + the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. + """ + + def __init__( + self, optimizer, warmup_epochs, delay_epochs, after_scheduler, last_epoch=-1 + ): + if delay_epochs < 0: + raise ValueError(f"delay_epochs must >= 0, got {delay_epochs}") + if warmup_epochs < 0: + raise ValueError(f"warmup_epochs must >= 0, got {warmup_epochs}") + self.warmup_epochs = warmup_epochs + self.delay_epochs = delay_epochs + super().__init__(optimizer, after_scheduler, last_epoch) + + def get_lr(self): + if self.last_epoch >= self.warmup_epochs + self.delay_epochs: + if not self.finished: + self.after_scheduler.base_lrs = self.base_lrs + # reset lr to base_lr + for group, base_lr in zip(self.optimizer.param_groups, self.base_lrs): + group["lr"] = base_lr + self.finished = True + with _enable_get_lr_call(self.after_scheduler): + return self.after_scheduler.get_lr() + elif self.last_epoch >= self.warmup_epochs: + return self.base_lrs + + return [(self.last_epoch + 1) / self.warmup_epochs * lr for lr in self.base_lrs] + + def step(self, epoch=None): + if self.finished: + if epoch is None: + self.after_scheduler.step(None) + self._last_lr = self.after_scheduler.get_last_lr() + else: + self.after_scheduler.step(epoch - self.warmup_epochs) + self._last_lr = self.after_scheduler.get_last_lr() + else: + return super().step(epoch) + + +class CosineAnnealingLR(_CosineAnnealingLR): + r"""Set the learning rate of each parameter group using a cosine annealing + schedule, where :math:`\eta_{max}` is set to the initial lr and + :math:`T_{cur}` is the number of epochs since the last restart in SGDR: + + .. 
math:: + \begin{aligned} + \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right), + & T_{cur} \neq (2k+1)T_{max}; \\ + \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min}) + \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right), + & T_{cur} = (2k+1)T_{max}. + \end{aligned} + + When last_epoch=-1, sets initial lr as lr. Notice that because the schedule + is defined recursively, the learning rate can be simultaneously modified + outside this scheduler by other operators. If the learning rate is set + solely by this scheduler, the learning rate at each step becomes: + + .. math:: + \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right) + + It has been proposed in + `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only + implements the cosine annealing part of SGDR, and not the restarts. + + .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: + https://arxiv.org/abs/1608.03983 + + Args: + optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer. + total_steps (int): Number of total training steps. + eta_min (int, optional): Minimum learning rate, defaults to 0. + last_epoch (int, optional): The index of last epoch, defaults to -1. When last_epoch=-1, + the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. + """ + + def __init__( + self, + optimizer, + total_steps: int, + eta_min: int = 0, + last_epoch: int = -1, + **kwargs, + ): + super().__init__(optimizer, total_steps, eta_min=eta_min, last_epoch=last_epoch) + + +class CosineAnnealingWarmupLR(WarmupScheduler): + """Cosine annealing learning rate scheduler with learning rate warmup. A linear warmup schedule will be applied. + + Args: + optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer. + total_steps (int): Number of total training steps. 
+ warmup_steps (int, optional): Number of warmup steps, defaults to 0. + eta_min (int, optional): Minimum learning rate, defaults to 0. + last_epoch (int, optional): The index of last epoch, defaults to -1. When last_epoch=-1, + the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. + """ + + def __init__( + self, + optimizer, + total_steps: int, + warmup_steps: int = 0, + eta_min: float = 0.0, + last_epoch: int = -1, + ): + base_scheduler = _CosineAnnealingLR( + optimizer, + total_steps - warmup_steps, + eta_min=eta_min, + last_epoch=last_epoch, + ) + super().__init__(optimizer, warmup_steps, base_scheduler, last_epoch=last_epoch) diff --git a/SpecForge-ext/specforge/optimizer.py b/SpecForge-ext/specforge/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..7bdd3ab8dd9f2960e3612da50469ba13792df83a --- /dev/null +++ b/SpecForge-ext/specforge/optimizer.py @@ -0,0 +1,66 @@ +import torch + +from specforge.lr_scheduler import CosineAnnealingWarmupLR +from specforge.utils import print_on_rank0 + + +class BF16Optimizer: + def __init__( + self, + model, + lr, + weight_decay=0.0, + max_grad_norm=0.5, + total_steps=800_000, + warmup_ratio=0.015, + ): + # TODO: For now, we only support cosine annealing warmup lr scheduler and AdamW optimizer + # TODO: We should make these parameters configurable + # These magic numbers: weight_decay=0.0, max_grad_norm=0.5, total_steps=800k, warmup_steps=12k are copied from + # https://github.com/SafeAILab/EAGLE/blob/main/eagle/traineagle3/ds_config.json + self.model = model + self.model_params = [p for p in model.parameters() if p.requires_grad] + self.max_grad_norm = max_grad_norm + self.fp32_params = [ + p.detach().clone().to(torch.float32) for p in self.model_params + ] + for mp in self.fp32_params: + mp.requires_grad = True + self.optimizer = torch.optim.AdamW( + self.fp32_params, lr=lr, weight_decay=weight_decay + ) + self.scheduler = CosineAnnealingWarmupLR( + 
self.optimizer, + total_steps=total_steps, + warmup_steps=int(warmup_ratio * total_steps), + ) + + def step(self): + with torch.no_grad(): + for p, mp in zip(self.model_params, self.fp32_params): + mp.grad = ( + p.grad.detach().to(torch.float32) if p.grad is not None else None + ) + torch.nn.utils.clip_grad_norm_(self.fp32_params, self.max_grad_norm) + self.optimizer.step() + self.optimizer.zero_grad() + self.scheduler.step() + with torch.no_grad(): + for p, mp in zip(self.model_params, self.fp32_params): + p.data.copy_(mp.data.to(p.dtype)) + p.grad = None + + def load_state_dict(self, state_dict): + self.optimizer.load_state_dict(state_dict["optimizer_state_dict"]) + print_on_rank0("Successfully loaded optimizer state_dict.") + self.scheduler.load_state_dict(state_dict["scheduler_state_dict"]) + print_on_rank0("Successfully loaded scheduler state_dict.") + + def state_dict(self): + return { + "optimizer_state_dict": self.optimizer.state_dict(), + "scheduler_state_dict": self.scheduler.state_dict(), + } + + def get_learning_rate(self): + return self.optimizer.param_groups[0]["lr"] diff --git a/SpecForge-ext/specforge/tracker.py b/SpecForge-ext/specforge/tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..02b7498c156b3194c7afcef625801a6d33a41364 --- /dev/null +++ b/SpecForge-ext/specforge/tracker.py @@ -0,0 +1,297 @@ +# tracker.py + +import abc +import netrc +import os +from typing import Any, Dict, Optional + +import torch.distributed as dist + +# --- Lazy Imports --- +# These libraries are imported only when their respective trackers are used. +try: + import wandb +except ImportError: + wandb = None + +try: + from torch.utils.tensorboard import SummaryWriter +except ImportError: + SummaryWriter = None + +try: + import swanlab +except ImportError: + swanlab = None + +try: + import mlflow +except ImportError: + mlflow = None + + +# --- End Lazy Imports --- + + +class Tracker(abc.ABC): + """ + Abstract Base Class for experiment trackers. 
+ + Each tracker implementation should handle its own initialization, logging, + and cleanup. It should also provide a class method to validate + command-line arguments before initialization. + """ + + def __init__(self, args, output_dir: str): + self.args = args + self.output_dir = output_dir + self.rank = dist.get_rank() + self.is_initialized = False + + @classmethod + @abc.abstractmethod + def validate_args(cls, parser, args) -> None: + """ + Validate necessary arguments for this tracker. + This method is called during argument parsing. + It should raise an error if required arguments are missing. + """ + + @abc.abstractmethod + def log(self, log_dict: Dict[str, Any], step: Optional[int] = None) -> None: + """ + Log metrics to the tracker. + """ + + @abc.abstractmethod + def close(self) -> None: + """ + Close the tracker and clean up resources. + """ + + +class NoOpTracker(Tracker): + """A tracker that does nothing, for when no tracking is desired.""" + + @classmethod + def validate_args(cls, parser, args): + pass # No arguments to validate + + def __init__(self, args, output_dir: str): + super().__init__(args, output_dir) + self.is_initialized = True # Considered initialized to do nothing + + def log(self, log_dict: Dict[str, Any], step: Optional[int] = None): + pass # Do nothing + + def close(self): + pass # Do nothing + + +class WandbTracker(Tracker): + """Tracks experiments using Weights & Biases.""" + + @classmethod + def validate_args(cls, parser, args): + if wandb is None: + parser.error( + "To use --report-to wandb, you must install wandb: 'pip install wandb'" + ) + + if args.wandb_key is not None: + return + + if "WANDB_API_KEY" in os.environ: + args.wandb_key = os.environ["WANDB_API_KEY"] + return + + try: + netrc_path = os.path.expanduser("~/.netrc") + if os.path.exists(netrc_path): + netrc_file = netrc.netrc(netrc_path) + if "api.wandb.ai" in netrc_file.hosts: + _, _, password = netrc_file.authenticators("api.wandb.ai") + if password: + 
args.wandb_key = password + return + except (FileNotFoundError, netrc.NetrcParseError): + pass + + if args.wandb_key is None: + parser.error( + "When --report-to is 'wandb', you must provide a wandb API key via one of:\n" + " 1. --wandb-key argument\n" + " 2. WANDB_API_KEY environment variable\n" + " 3. `wandb login` command" + ) + + def __init__(self, args, output_dir: str): + super().__init__(args, output_dir) + if self.rank == 0: + wandb.login(key=args.wandb_key) + wandb.init( + project=args.wandb_project, name=args.wandb_name, config=vars(args) + ) + self.is_initialized = True + + def log(self, log_dict: Dict[str, Any], step: Optional[int] = None): + if self.rank == 0 and self.is_initialized: + wandb.log(log_dict, step=step) + + def close(self): + if self.rank == 0 and self.is_initialized and wandb.run: + wandb.finish() + self.is_initialized = False + + +class SwanlabTracker(Tracker): + """Tracks experiments using SwanLab.""" + + @classmethod + def validate_args(cls, parser, args): + if swanlab is None: + parser.error( + "To use --report-to swanlab, you must install swanlab: 'pip install swanlab'" + ) + + if args.swanlab_key is not None: + return + if "SWANLAB_API_KEY" in os.environ: + args.swanlab_key = os.environ["SWANLAB_API_KEY"] + return + # Swanlab can run in anonymous mode if no key is provided in a non-distributed env. + # However, a key is often required for distributed runs to sync correctly. + if ( + dist.is_initialized() + and dist.get_world_size() > 1 + and args.swanlab_key is None + ): + parser.error( + "In a distributed environment, when --report-to is 'swanlab', you must provide a swanlab API key via:\n" + " 1. --swanlab-key argument\n" + " 2. 
SWANLAB_API_KEY environment variable" + ) + + def __init__(self, args, output_dir: str): + super().__init__(args, output_dir) + if self.rank == 0: + if args.swanlab_key: + swanlab.login(api_key=args.swanlab_key) + + swanlog_dir = os.path.join(output_dir, "swanlog") + os.makedirs(swanlog_dir, exist_ok=True) + swanlab.init( + project=args.swanlab_project, + experiment_name=args.swanlab_name, + config=vars(args), + logdir=swanlog_dir, + ) + self.is_initialized = True + + def log(self, log_dict: Dict[str, Any], step: Optional[int] = None): + if self.rank == 0 and self.is_initialized: + swanlab.log(log_dict, step=step) + + def close(self): + if self.rank == 0 and self.is_initialized and swanlab.get_run() is not None: + swanlab.finish() + self.is_initialized = False + + +class TensorboardTracker(Tracker): + """Tracks experiments using TensorBoard.""" + + @classmethod + def validate_args(cls, parser, args): + if SummaryWriter is None: + parser.error( + "To use --report-to tensorboard, you must have tensorboard installed: 'pip install tensorboard'" + ) + + def __init__(self, args, output_dir: str): + super().__init__(args, output_dir) + if self.rank == 0: + log_dir = os.path.join(output_dir, "runs") + self.writer = SummaryWriter(log_dir=log_dir) + self.is_initialized = True + + def log(self, log_dict: Dict[str, Any], step: Optional[int] = None): + if self.rank == 0 and self.is_initialized: + for key, value in log_dict.items(): + if isinstance(value, (int, float)): + self.writer.add_scalar(key, value, global_step=step) + + def close(self): + if self.rank == 0 and self.is_initialized: + self.writer.close() + self.is_initialized = False + + +class MLflowTracker(Tracker): + """Tracks experiments using MLflow.""" + + @classmethod + def validate_args(cls, parser, args): + if mlflow is None: + parser.error( + "To use --report-to mlflow, you must install mlflow: 'pip install mlflow'" + ) + # Set tracking URI from environment variable if not explicitly provided + if 
args.mlflow_tracking_uri is None and "MLFLOW_TRACKING_URI" in os.environ: + args.mlflow_tracking_uri = os.environ["MLFLOW_TRACKING_URI"] + elif args.mlflow_tracking_uri is None: + print( + "Warning: MLflow tracking URI not set. Defaulting to local './mlruns'." + ) + + # Set experiment name from environment variable if not explicitly provided + if ( + args.mlflow_experiment_name is None + and "MLFLOW_EXPERIMENT_NAME" in os.environ + ): + args.mlflow_experiment_name = os.environ["MLFLOW_EXPERIMENT_NAME"] + + def __init__(self, args, output_dir: str): + super().__init__(args, output_dir) + if self.rank == 0: + if args.mlflow_tracking_uri: + mlflow.set_tracking_uri(args.mlflow_tracking_uri) + + # This will either use the set URI or the default + mlflow.set_experiment(args.mlflow_experiment_name) + mlflow.start_run(run_name=args.mlflow_run_name) + mlflow.log_params(vars(args)) + self.is_initialized = True + + def log(self, log_dict: Dict[str, Any], step: Optional[int] = None): + if self.rank == 0 and self.is_initialized: + # MLflow's log_metrics takes a dictionary directly + mlflow.log_metrics(log_dict, step=step) + + def close(self): + if self.rank == 0 and self.is_initialized: + mlflow.end_run() + self.is_initialized = False + + +# --- Tracker Factory --- +TRACKER_REGISTRY = { + "wandb": WandbTracker, + "swanlab": SwanlabTracker, + "tensorboard": TensorboardTracker, + "mlflow": MLflowTracker, + "none": NoOpTracker, +} + + +def get_tracker_class(report_to: str) -> Optional[Tracker]: + """Returns the tracker class based on the name.""" + return TRACKER_REGISTRY.get(report_to) + + +def create_tracker(args, output_dir: str) -> Tracker: + """Factory function to create an experiment tracker instance.""" + tracker_class = get_tracker_class(args.report_to) + if not tracker_class: + raise ValueError(f"Unsupported report_to type: {args.report_to}") + return tracker_class(args, output_dir) diff --git a/SpecForge-ext/specforge/utils.py b/SpecForge-ext/specforge/utils.py new file 
import json
import logging
import os
import re
from contextlib import contextmanager

import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor, Shard, distribute_tensor

# transformers is a heavyweight dependency used only by the config helpers
# below; guard the import so the lightweight utilities stay usable without it.
try:
    from transformers import AutoConfig, PretrainedConfig
except ImportError:  # NOTE(review): expected to be installed in production
    AutoConfig = None
    PretrainedConfig = None

logger = logging.getLogger(__name__)


@contextmanager
def rank_0_priority():
    """Let rank 0 execute the guarded section before all other ranks.

    Rank 0 runs the body first and then reaches the barrier; every other
    rank waits at the barrier before running the body. Requires an
    initialized default process group.
    """
    if dist.get_rank() == 0:
        yield
        dist.barrier()
    else:
        dist.barrier()
        yield


@contextmanager
def default_torch_dtype(dtype: torch.dtype):
    """Temporarily set the global default torch dtype.

    The previous default is restored even if the guarded block raises —
    the original implementation leaked the temporary dtype on the error
    path because the restore was not in a ``finally`` clause.
    """
    previous = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(previous)


@torch.no_grad()
def padding(tensor: torch.Tensor, left: bool = True) -> torch.Tensor:
    """Shift ``tensor`` by one position along dim 1, zero-filling the gap.

    Args:
        tensor: at least 2-D; the shift happens along dimension 1.
        left: if True, prepend a zero column and drop the last column;
            otherwise drop the first column and append a zero column.

    Returns:
        A new tensor with the same shape as ``tensor``.
    """
    zero_col = torch.zeros_like(tensor[:, -1:])
    if left:
        return torch.cat((zero_col, tensor[:, :-1]), dim=1)
    return torch.cat((tensor[:, 1:], zero_col), dim=1)


def load_config_from_file(config_path: str):
    """Load a JSON config file and wrap it as a ``PretrainedConfig``."""
    with open(config_path, "r") as f:
        config = json.load(f)
    return PretrainedConfig.from_dict(config)


def print_with_rank(message):
    """Log ``message`` prefixed with the caller's distributed rank, if any."""
    if dist.is_available() and dist.is_initialized():
        logger.info(f"rank {dist.get_rank()}: {message}")
    else:
        logger.info(f"non-distributed: {message}")


def print_args_with_dots(args):
    """Pretty-print an argparse namespace on rank 0, one dotted row per arg."""
    if dist.get_rank() != 0:
        return
    args_dict = vars(args)
    max_key_length = max(len(key) for key in args_dict.keys())
    total_width = 50

    # The banner/filler glyphs were mojibake (UTF-8 misdecoded); restored to
    # the intended characters.
    print("\n -----------【args】-----------")
    for key, value in args_dict.items():
        key_str = f"{key:<{max_key_length}}"
        value_str = str(value)
        # Negative counts simply produce an empty filler for very long rows.
        dot_fill = "·" * (total_width - len(key_str) - len(value_str))
        print(f"{key_str} {dot_fill} {value_str}")


def print_on_rank0(message):
    """Log ``message`` only on global rank 0 (requires an initialized group)."""
    if dist.get_rank() == 0:
        logger.info(message)
prefix="epoch"): + content = os.listdir(folder) + _re_checkpoint = re.compile(r"^" + prefix + r"_(\d+)$") + checkpoints = [ + path + for path in content + if _re_checkpoint.search(path) is not None + and os.path.isdir(os.path.join(folder, path)) + ] + if len(checkpoints) == 0: + return + return os.path.join( + folder, + max(checkpoints, key=lambda x: int(_re_checkpoint.search(x).groups()[0])), + ) + + +def generate_draft_model_config( + target_model_path: str, template_config_path: str = None, cache_dir: str = None +): + """ + Auto-generate draft model config based on target model parameters aligned with template config + + Args: + target_model_path (str): Path to the target model + template_config_path (str, optional): Template config file path, defaults to llama3-8B-eagle3.json + cache_dir (str, optional): Cache directory + + Returns: + dict: Generated draft model config dictionary + """ + # Get target model config + target_config = AutoConfig.from_pretrained(target_model_path, cache_dir=cache_dir) + + # If no template specified, use default llama3-8B-eagle3.json + if template_config_path is None: + # Use the script execution directory as base + import sys + + script_dir = os.path.dirname(os.path.abspath(sys.argv[0])) + project_root = os.path.dirname(script_dir) # Go up one level from scripts/ + template_config_path = os.path.join( + project_root, "configs", "llama3-8B-eagle3.json" + ) + + # Read template config + with open(template_config_path, "r") as f: + draft_config = json.load(f) + + # Adjust architecture config based on target model type + if hasattr(target_config, "model_type"): + # Default to llama architecture + draft_config["model_type"] = "llama" + + # Align key parameters + param_mappings = { + "vocab_size": "vocab_size", + "hidden_size": "hidden_size", + "num_attention_heads": "num_attention_heads", + "num_key_value_heads": "num_key_value_heads", + "intermediate_size": "intermediate_size", + "max_position_embeddings": "max_position_embeddings", + 
"rms_norm_eps": "rms_norm_eps", + "hidden_act": "hidden_act", + "bos_token_id": "bos_token_id", + "eos_token_id": "eos_token_id", + "torch_dtype": "torch_dtype", + } + + # Copy parameters from target model to draft config + for target_param, draft_param in param_mappings.items(): + if hasattr(target_config, target_param): + value = getattr(target_config, target_param) + # Special handling for torch_dtype to make it JSON serializable + if target_param == "torch_dtype" and isinstance(value, torch.dtype): + value = str(value).replace("torch.", "") + draft_config[draft_param] = value + + # Special handling for some parameters + # Ensure num_hidden_layers is always 1 (EAGLE3 feature) + draft_config["num_hidden_layers"] = 1 + + # Keep some fixed draft model specific parameters + draft_config["tie_word_embeddings"] = False + draft_config["use_cache"] = True + + # If template doesn't have draft_vocab_size, set default + if "draft_vocab_size" not in draft_config: + draft_config["draft_vocab_size"] = 32000 # Default value + + return draft_config + + +def save_draft_model_config(config_dict: dict, output_path: str): + """ + Save draft model config to file + + Args: + config_dict (dict): Config dictionary + output_path (str): Output file path + """ + os.makedirs(os.path.dirname(output_path), exist_ok=True) + + with open(output_path, "w", encoding="utf-8") as f: + json.dump(config_dict, f, indent=2, ensure_ascii=False) + + print(f"Draft model config saved to: {output_path}") + + +def create_draft_config_from_target( + target_model_path: str, + output_dir: str = None, + template_config_path: str = None, + cache_dir: str = None, +): + """ + Convenient function to create draft model config file from target model + + Args: + target_model_path (str): Target model path + output_dir (str, optional): Output directory, defaults to configs folder in current directory + template_config_path (str, optional): Template config path + cache_dir (str, optional): Cache directory + + Returns: + 
str: Generated config file path + """ + # Generate config + rank = dist.get_rank() + + if rank == 0: + print_with_rank( + "No draft model config provided, auto-generating from target model..." + ) + config_dict = generate_draft_model_config( + target_model_path, template_config_path, cache_dir + ) + dist.barrier() + + # Determine output path + if output_dir is None: + # Use the script execution directory as base + import sys + + script_dir = os.path.dirname(os.path.abspath(sys.argv[0])) + project_root = os.path.dirname(script_dir) # Go up one level from scripts/ + output_dir = os.path.join(project_root, "configs") + + # Extract model name from model path + model_name = target_model_path.split("/")[-1].lower() + output_filename = f"{model_name}-eagle3-auto.json" + output_path = os.path.join(output_dir, output_filename) + + # Save config + if rank == 0: + save_draft_model_config(config_dict, output_path) + print_with_rank(f"Auto-generated draft model config saved to: {output_path}") + dist.barrier() + + return output_path + + +def get_full_optimizer_state(optimizer_state_dict: dict): + """ + Convert optimizer state dict with DTensor to full tensors for saving + + Args: + optimizer_state_dict (dict): Optimizer state dict possibly containing DTensors + Returns: + dict: Optimizer state dict with full tensors + """ + full_optimizer_state_dict = { + k: v for k, v in optimizer_state_dict.items() if k != "state" + } + if "state" in optimizer_state_dict: + full_optimizer_state_dict["state"] = { + param_id: { + state_key: ( + state_tensor.full_tensor() + if isinstance(state_tensor, torch.distributed.tensor.DTensor) + else state_tensor + ) + for state_key, state_tensor in param_state.items() + } + for param_id, param_state in optimizer_state_dict["state"].items() + } + return full_optimizer_state_dict + + +def shard_optimizer_state_with_dtensor(bf16_optimizer, device_mesh): + """ + Shards the optimizer state tensors of a BF16Optimizer instance using DTensor. 
+ + Args: + bf16_optimizer (BF16Optimizer): An instance of BF16Optimizer, which contains + the actual optimizer (e.g., torch.optim.Adam) as its `.optimizer` attribute. + """ + + optim = bf16_optimizer.optimizer + + for group in optim.param_groups: + for p in group["params"]: + if not isinstance(p, DTensor): + continue + + state = optim.state.get(p, None) + if state is None: + continue + + mesh = device_mesh + placements = (Shard(dim=0),) + + for k, v in list(state.items()): + if k == "step": + continue + + if isinstance(v, DTensor): + continue + + if not isinstance(v, torch.Tensor): + continue + + state[k] = distribute_tensor( + v.to(p.device), device_mesh=mesh, placements=placements + ) + + +def safe_conversations_generator(file_path): + """ + Generator that: + 1. Extracts the 'conversations' field. + 2. Preserves all original fields within each message. + 3. [Key step] Converts all list/dict-type field values to strings to resolve mixed-type conflicts (e.g., for Arrow compatibility). + """ + with open(file_path, "r", encoding="utf-8") as f: + for i, line in enumerate(f): + line = line.strip() + if not line: + continue + try: + row = json.loads(line) + raw_convs = row.get("conversations", []) + + # 1. Ensure 'conversations' is a list + if not isinstance(raw_convs, list): + # If it's None or some unexpected type, treat as empty or skip + if raw_convs is None: + raw_convs = [] + else: + # Edge case: 'conversations' is a plain string or non-iterableโ€”skip this line + logger.warning( + f"Line {i + 1}: 'conversations' is not a list. Please check!" + ) + continue + + cleaned_convs = [] + for msg in raw_convs: + # 2. Ensure each item in the list is a dictionary + if not isinstance(msg, dict): + # Skip if an element is not a dict (e.g., malformed like ["user", "hi"]) + continue + + # 3. [Core logic] Iterate over all fields in the message (role, content, tools, etc.) 
+ new_msg = {} + for k, v in msg.items(): + # If the value is a list or dict, serialize it to a JSON string + # This ensures Arrow treats the column as string type instead of list/struct + if isinstance(v, (list, dict)): + new_msg[k] = json.dumps(v, ensure_ascii=False) + else: + # Keep primitive types (str, int, float, bool, None) unchanged + new_msg[k] = v + + cleaned_convs.append(new_msg) + + # Yield only the processed 'conversations' + yield {"conversations": cleaned_convs} + + except Exception as e: + logger.warning(f"Skipping line {i + 1}: {e}") + continue diff --git a/SpecForge-ext/test_connection.py b/SpecForge-ext/test_connection.py new file mode 100644 index 0000000000000000000000000000000000000000..3c0bfbf426c1160da21a1abdc07529ccd8c114be --- /dev/null +++ b/SpecForge-ext/test_connection.py @@ -0,0 +1,47 @@ +import os +import sys +sys.path.insert(0, '/workspace/sglang/python') + +from sglang.utils import http_request + +# ่ฎพ็ฝฎ็Žฏๅขƒๅ˜้‡ +os.environ['NO_PROXY'] = 'localhost,127.0.0.1,::1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16' +os.environ['no_proxy'] = 'localhost,127.0.0.1,::1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16' + +print("Testing connection to http://10.10.101.31:30000/get_model_info") +print(f"NO_PROXY: {os.environ.get('NO_PROXY')}") + +# Debug: test the pattern matching +from urllib.parse import urlparse +url = "http://10.10.101.31:30000/get_model_info" +parsed = urlparse(url) +hostname = parsed.hostname +print(f"Hostname: {hostname}") + +# Test pattern matching +no_proxy = os.environ.get('NO_PROXY', '') +for pattern in no_proxy.split(','): + pattern = pattern.strip() + print(f"Testing pattern: {pattern}") + if '/' in pattern: + network_parts = pattern.split('/')[0].split('.') + hostname_parts = hostname.split('.') + cidr = int(pattern.split('/')[1]) + octets_to_check = (cidr + 7) // 8 + print(f" Network parts: {network_parts[:octets_to_check]}") + print(f" Hostname parts: {hostname_parts[:octets_to_check]}") + if 
hostname_parts[:octets_to_check] == network_parts[:octets_to_check]: + print(f" MATCH!") + +print("\nActual request:") +try: + res = http_request("http://10.10.101.31:30000/get_model_info") + print(f"Status: {res.status_code}") + if res.status_code == 200: + print(f"Response: {res.json()}") + else: + print(f"Error: {res.text}") +except Exception as e: + print(f"Exception: {e}") + import traceback + traceback.print_exc() diff --git a/SpecForge-ext/tests/__init__.py b/SpecForge-ext/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/SpecForge-ext/tests/utils.py b/SpecForge-ext/tests/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d1cc20609907eeacf62c7e76be4bee5e1caf1ea2 --- /dev/null +++ b/SpecForge-ext/tests/utils.py @@ -0,0 +1,107 @@ +import os +import socket +import subprocess +import time + +import requests +from sglang.utils import print_highlight + + +def is_port_in_use(port: int) -> bool: + """Check if a port is in use""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + try: + s.bind(("localhost", port)) + return False + except OSError: + return True + + +def get_available_port(): + # get a random available port + # and try to find a port that is not in use + for port in range(10000, 65535): + if not is_port_in_use(port): + return port + raise RuntimeError("No available port found") + + +def execute_shell_command( + command: str, disable_proxy: bool = False, enable_hf_mirror: bool = False +): + """ + Execute a shell command and return its process handle. 
+ """ + command = command.replace("\\\n", " ").replace("\\", " ") + parts = command.split() + env = os.environ.copy() + + if disable_proxy: + env.pop("http_proxy", None) + env.pop("https_proxy", None) + env.pop("no_proxy", None) + env.pop("HTTP_PROXY", None) + env.pop("HTTPS_PROXY", None) + env.pop("NO_PROXY", None) + + if enable_hf_mirror: + env["HF_ENDPOINT"] = "https://hf-mirror.com" + return subprocess.Popen(parts, text=True, stderr=subprocess.STDOUT, env=env) + + +def wait_for_server( + base_url: str, timeout: int = None, disable_proxy: bool = False +) -> None: + """Wait for the server to be ready by polling the /v1/models endpoint. + + Args: + base_url: The base URL of the server + timeout: Maximum time to wait in seconds. None means wait forever. + """ + start_time = time.perf_counter() + + if disable_proxy: + http_proxy = os.environ.pop("http_proxy", None) + https_proxy = os.environ.pop("https_proxy", None) + no_proxy = os.environ.pop("no_proxy", None) + http_proxy_capitalized = os.environ.pop("HTTP_PROXY", None) + https_proxy_capitalized = os.environ.pop("HTTPS_PROXY", None) + no_proxy_capitalized = os.environ.pop("NO_PROXY", None) + + while True: + try: + response = requests.get( + f"{base_url}/v1/models", + headers={"Authorization": "Bearer None"}, + ) + if response.status_code == 200: + time.sleep(5) + print_highlight( + """\n + NOTE: Typically, the server runs in a separate terminal. + In this notebook, we run the server and notebook code together, so their outputs are combined. + To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue. + To reduce the log length, we set the log level to warning for the server, the default log level is info. + We are running those notebooks in a CI environment, so the throughput is not representative of the actual performance. 
+ """ + ) + break + + if timeout and time.perf_counter() - start_time > timeout: + raise TimeoutError("Server did not become ready within timeout period") + except requests.exceptions.RequestException: + time.sleep(1) + + if disable_proxy: + if http_proxy: + os.environ["http_proxy"] = http_proxy + if https_proxy: + os.environ["https_proxy"] = https_proxy + if no_proxy: + os.environ["no_proxy"] = no_proxy + if http_proxy_capitalized: + os.environ["HTTP_PROXY"] = http_proxy_capitalized + if https_proxy_capitalized: + os.environ["HTTPS_PROXY"] = https_proxy_capitalized + if no_proxy_capitalized: + os.environ["NO_PROXY"] = no_proxy_capitalized diff --git a/SpecForge-ext/training.log b/SpecForge-ext/training.log new file mode 100644 index 0000000000000000000000000000000000000000..00fb9bf2f9d65353fb39e9f811a3b9f1249501c4 --- /dev/null +++ b/SpecForge-ext/training.log @@ -0,0 +1,263 @@ +nohup: ignoring input +bash: /workspace/specforge/lib/libtinfo.so.6: no version information available (required by bash) +W0211 11:26:05.180000 2473 site-packages/torch/distributed/run.py:803] +W0211 11:26:05.180000 2473 site-packages/torch/distributed/run.py:803] ***************************************** +W0211 11:26:05.180000 2473 site-packages/torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0211 11:26:05.180000 2473 site-packages/torch/distributed/run.py:803] ***************************************** +/workspace/specforge/lib/python3.11/site-packages/tvm_ffi/_optional_torch_c_dlpack.py:174: UserWarning: Failed to JIT torch c dlpack extension, EnvTensorAllocator will not be enabled. 
+We recommend installing via `pip install torch-c-dlpack-ext` + warnings.warn( +Set TORCH_CUDA_ARCH_LIST to 9.0 +/workspace/hanrui/SpecForge-ext/specforge/modeling/draft/llama3_eagle.py:29: UserWarning: flash_attn is not found, falling back to flex_attention. Please install flash_attn if you want to use the flash attention backend. + warnings.warn( +/workspace/specforge/lib/python3.11/site-packages/tvm_ffi/_optional_torch_c_dlpack.py:174: UserWarning: Failed to JIT torch c dlpack extension, EnvTensorAllocator will not be enabled. +We recommend installing via `pip install torch-c-dlpack-ext` + warnings.warn( +Set TORCH_CUDA_ARCH_LIST to 9.0 +/workspace/hanrui/SpecForge-ext/specforge/modeling/draft/llama3_eagle.py:29: UserWarning: flash_attn is not found, falling back to flex_attention. Please install flash_attn if you want to use the flash attention backend. + warnings.warn( +/workspace/specforge/lib/python3.11/site-packages/tvm_ffi/_optional_torch_c_dlpack.py:174: UserWarning: Failed to JIT torch c dlpack extension, EnvTensorAllocator will not be enabled. +We recommend installing via `pip install torch-c-dlpack-ext` + warnings.warn( +Set TORCH_CUDA_ARCH_LIST to 9.0 +/workspace/hanrui/SpecForge-ext/specforge/modeling/draft/llama3_eagle.py:29: UserWarning: flash_attn is not found, falling back to flex_attention. Please install flash_attn if you want to use the flash attention backend. + warnings.warn( +/workspace/specforge/lib/python3.11/site-packages/tvm_ffi/_optional_torch_c_dlpack.py:174: UserWarning: Failed to JIT torch c dlpack extension, EnvTensorAllocator will not be enabled. +We recommend installing via `pip install torch-c-dlpack-ext` + warnings.warn( +Set TORCH_CUDA_ARCH_LIST to 9.0 +/workspace/hanrui/SpecForge-ext/specforge/modeling/draft/llama3_eagle.py:29: UserWarning: flash_attn is not found, falling back to flex_attention. Please install flash_attn if you want to use the flash attention backend. 
+ warnings.warn( +Set draft model tie_word_embeddings to False +/workspace/specforge/lib/python3.11/site-packages/tvm_ffi/_optional_torch_c_dlpack.py:174: UserWarning: Failed to JIT torch c dlpack extension, EnvTensorAllocator will not be enabled. +We recommend installing via `pip install torch-c-dlpack-ext` + warnings.warn( +Set TORCH_CUDA_ARCH_LIST to 9.0 +/workspace/hanrui/SpecForge-ext/specforge/modeling/draft/llama3_eagle.py:29: UserWarning: flash_attn is not found, falling back to flex_attention. Please install flash_attn if you want to use the flash attention backend. + warnings.warn( +Set draft model tie_word_embeddings to False +/workspace/specforge/lib/python3.11/site-packages/tvm_ffi/_optional_torch_c_dlpack.py:174: UserWarning: Failed to JIT torch c dlpack extension, EnvTensorAllocator will not be enabled. +We recommend installing via `pip install torch-c-dlpack-ext` + warnings.warn( +Set TORCH_CUDA_ARCH_LIST to 9.0 +/workspace/hanrui/SpecForge-ext/specforge/modeling/draft/llama3_eagle.py:29: UserWarning: flash_attn is not found, falling back to flex_attention. Please install flash_attn if you want to use the flash attention backend. + warnings.warn( +Set draft model tie_word_embeddings to False +/workspace/specforge/lib/python3.11/site-packages/tvm_ffi/_optional_torch_c_dlpack.py:174: UserWarning: Failed to JIT torch c dlpack extension, EnvTensorAllocator will not be enabled. +We recommend installing via `pip install torch-c-dlpack-ext` + warnings.warn( +Set TORCH_CUDA_ARCH_LIST to 9.0 +/workspace/hanrui/SpecForge-ext/specforge/modeling/draft/llama3_eagle.py:29: UserWarning: flash_attn is not found, falling back to flex_attention. Please install flash_attn if you want to use the flash attention backend. + warnings.warn( +Set draft model tie_word_embeddings to False +/workspace/specforge/lib/python3.11/site-packages/tvm_ffi/_optional_torch_c_dlpack.py:174: UserWarning: Failed to JIT torch c dlpack extension, EnvTensorAllocator will not be enabled. 
+We recommend installing via `pip install torch-c-dlpack-ext` + warnings.warn( +Set TORCH_CUDA_ARCH_LIST to 9.0 +/workspace/hanrui/SpecForge-ext/specforge/modeling/draft/llama3_eagle.py:29: UserWarning: flash_attn is not found, falling back to flex_attention. Please install flash_attn if you want to use the flash attention backend. + warnings.warn( +Set draft model tie_word_embeddings to False +Set draft model tie_word_embeddings to False +Set draft model tie_word_embeddings to False + + -----------ใ€argsใ€‘----------- +target_model_path /workspace/Qwen3-8B +trust_remote_code ยทยทยทยทยทยทยท False +draft_model_config /workspace/hanrui/SpecForge-ext/configs/qwen3-8b-qwen3eagle-5layer.json +embedding_key model.embed_tokens.weight +lm_head_key lm_head.weight +is_vlm ยทยทยทยทยทยทยท False +target_model_backend ยทยทยทยทยทยท sglang +train_data_path /workspace/hanrui/qwen3-8b_dflash_regen/sharegpt_train_regenerated.jsonl +train_hidden_states_path ยทยทยทยทยทยทยทยท None +eval_hidden_states_path ยทยทยทยทยทยทยทยท None +eval_data_path ยทยทยทยทยทยทยทยท None +chat_template ยทยทยทยทยทยทยทยท qwen +is_preformatted ยทยทยทยทยทยทยท False +train_only_last_turn ยทยทยทยทยทยทยท False +build_dataset_num_proc ยทยทยทยทยทยทยทยทยทยทยท 8 +dataloader_num_workers ยทยทยทยทยทยทยทยทยทยทยท 4 +num_epochs ยทยทยทยทยทยทยทยทยทยท 10 +max_num_steps ยทยทยทยทยทยทยทยท None +batch_size ยทยทยทยทยทยทยทยทยทยทยท 8 +learning_rate ยทยทยทยทยทยท 0.0001 +max_length ยทยทยทยทยทยทยทยท 2048 +warmup_ratio ยทยทยทยทยทยทยท 0.015 +total_steps ยทยทยทยทยทยทยทยท None +max_grad_norm ยทยทยทยทยทยทยทยทยท 0.5 +ttt_length ยทยทยทยทยทยทยทยทยทยทยท 7 +resume ยทยทยทยทยทยทยท False +ckpt_dir ยทยทยทยทยทยทยทยท None +eval_interval ยทยทยทยทยทยทยทยท 5000 +save_interval ยทยทยทยทยทยทยทยท 5000 +log_interval ยทยทยทยทยทยทยทยทยท 100 +seed ยทยทยทยทยทยทยทยทยทยทยท 0 +draft_accumulation_steps ยทยทยทยทยทยทยทยทยทยทยท 1 +tp_size ยทยทยทยทยทยทยทยทยทยทยท 1 +sp_ulysses_size ยทยทยทยทยทยทยทยทยทยทยท 1 +sp_ring_size ยทยทยทยทยทยทยทยทยทยทยท 1 +attention_backend 
flex_attention +cache_key ยทยทยทยทยทยทยทยท None +cache_dir /workspace/hanrui/SpecForge-ext/cache +output_dir /workspace/hanrui/SpecForge-ext/outputs/qwen3-8b-qwen3eagle-5layer +verbose ยทยทยทยทยทยทยท False +dist_timeout ยทยทยทยทยทยทยทยทยทยท 20 +model_download_dir ยทยทยทยทยทยทยทยท None +min_pixels ยทยทยทยทยทยทยท 50176 +max_pixels ยทยทยทยทยทยท 802816 +profile ยทยทยทยทยทยทยท False +profile_start_step ยทยทยทยทยทยทยทยทยทยท 30 +profile_num_steps ยทยทยทยทยทยทยทยทยทยทยท 4 +profile_record_shapes ยทยทยทยทยทยทยท False +sglang_attention_backend ยทยท flashinfer +sglang_mem_fraction_static ยทยทยทยทยทยทยทยทยท 0.4 +sglang_context_length ยทยทยทยทยทยทยทยท None +sglang_enable_nccl_nvls ยทยทยทยทยทยทยท False +sglang_enable_symm_mem ยทยทยทยทยทยทยท False +sglang_enable_torch_compile ยทยทยทยทยทยทยท False +sglang_enable_dp_attention ยทยทยทยทยทยทยท False +sglang_enable_dp_lm_head ยทยทยทยทยทยทยท False +sglang_enable_piecewise_cuda_graph ยทยทยทยทยทยทยท False +sglang_piecewise_cuda_graph_max_tokens ยทยทยทยทยทยทยทยท 4096 +sglang_piecewise_cuda_graph_tokens ยทยทยทยทยทยทยทยท None +sglang_ep_size ยทยทยทยทยทยทยทยทยทยทยท 1 +report_to ยทยทยทยทยทยทยทยท none +wandb_project ยทยทยทยทยทยทยทยท None +wandb_name ยทยทยทยทยทยทยทยท None +wandb_key ยทยทยทยทยทยทยทยท None +swanlab_project ยทยทยทยทยทยทยทยท None +swanlab_name ยทยทยทยทยทยทยทยท None +swanlab_key ยทยทยทยทยทยทยทยท None +mlflow_tracking_uri ยทยทยทยทยทยทยทยท None +mlflow_experiment_name ยทยทยทยทยทยทยทยท None +mlflow_run_name ยทยทยทยทยทยทยทยท None +dp_size ยทยทยทยทยทยทยทยทยทยทยท 8 +target_batch_size ยทยทยทยทยทยทยทยทยทยทยท 8 +Set draft model tie_word_embeddings to False +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. 
Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +/bin/bash: /workspace/specforge/lib/libtinfo.so.6: no version information available (required by /bin/bash) +WARNING:sglang.srt.models.registry:Ignore import error when loading sglang.srt.models.mindspore: name 'ms' is not defined +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +/bin/bash: /workspace/specforge/lib/libtinfo.so.6: no version information available (required by /bin/bash) +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +WARNING:sglang.srt.models.registry:Ignore import error when loading sglang.srt.models.mindspore: name 'ms' is not defined +/bin/bash: /workspace/specforge/lib/libtinfo.so.6: no version information available (required by /bin/bash) +WARNING:sglang.srt.models.registry:Ignore import error when loading sglang.srt.models.mindspore: name 'ms' is not defined +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. 
Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +/bin/bash: /workspace/specforge/lib/libtinfo.so.6: no version information available (required by /bin/bash) +WARNING:sglang.srt.models.registry:Ignore import error when loading sglang.srt.models.mindspore: name 'ms' is not defined +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +/bin/bash: /workspace/specforge/lib/libtinfo.so.6: no version information available (required by /bin/bash) +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +WARNING:sglang.srt.models.registry:Ignore import error when loading sglang.srt.models.mindspore: name 'ms' is not defined +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. 
Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +/bin/bash: /workspace/specforge/lib/libtinfo.so.6: no version information available (required by /bin/bash) +/bin/bash: /workspace/specforge/lib/libtinfo.so.6: no version information available (required by /bin/bash) +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +WARNING:sglang.srt.models.registry:Ignore import error when loading sglang.srt.models.mindspore: name 'ms' is not defined +/bin/bash: /workspace/specforge/lib/libtinfo.so.6: no version information available (required by /bin/bash) +WARNING:sglang.srt.models.registry:Ignore import error when loading sglang.srt.models.mindspore: name 'ms' is not defined +WARNING:sglang.srt.models.registry:Ignore import error when loading sglang.srt.models.mindspore: name 'ms' is not defined + Loading safetensors checkpoint shards: 0% Completed | 0/5 [00:00 + +
+ DFlash Architecture +
+ +https://github.com/user-attachments/assets/5b29cabb-eb95-44c9-8ffe-367c0758de8c + +
+ +## ๐Ÿ“ฆ Model Support Plan + +### โœ… Supported +- **openai/gpt-oss-20b**: https://huggingface.co/z-lab/gpt-oss-20b-DFlash +- **Qwen3-4B**: https://huggingface.co/z-lab/Qwen3-4B-DFlash-b16 +- **Qwen3-8B**: https://huggingface.co/z-lab/Qwen3-8B-DFlash-b16 +- **Qwen3-Coder-30B-A3B**: https://huggingface.co/z-lab/Qwen3-Coder-30B-A3B-DFlash +- **Llama-3.1-8B-Instruct**: https://huggingface.co/z-lab/LLaMA3.1-8B-Instruct-DFlash-UltraChat + +### ๐Ÿšง Coming Soon +- **Qwen/Qwen3-Coder-Next** (Very soon) +- **openai/gpt-oss-120b** +- **zai-org/GLM-4.7** +- **zai-org/GLM-4.7-Flash** + +> ๐Ÿ’ก Feel free to open a GitHub issue if youโ€™d like to request support for additional models! +> We will also open-source the training recipe soon, so you can train your own DFlash draft model to accelerate any LLM. + +
+ +## ๐Ÿš€ Quick Start + +### Installation +```bash +conda create -n dflash python=3.11 +conda activate dflash + +git clone https://github.com/z-lab/dflash.git +cd dflash + +pip install uv +uv pip install -r requirements.txt + +# Optionally install flash-attn. +# If unavailable, evaluation falls back to torch.sdpa in the Transformers backend. +# The measured speedup will be slower, but the acceptance length remains comparable. + +# uv pip install flash-attn --no-build-isolation +``` + +### SGLang + +```bash +export SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1 + +python -m sglang.launch_server \ + --model-path Qwen/Qwen3-Coder-30B-A3B-Instruct \ + --speculative-algorithm DFLASH \ + --speculative-draft-model-path z-lab/Qwen3-Coder-30B-A3B-DFlash \ + --tp-size 1 \ + --dtype bfloat16 \ + --attention-backend fa3 \ + --mem-fraction-static 0.75 \ + --trust-remote-code +``` + +### Transformers + +```python +from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer + +model = AutoModel.from_pretrained( + "z-lab/Qwen3-8B-DFlash-b16", + trust_remote_code=True, + dtype="auto", + device_map="cuda:0" +).eval() + +target = AutoModelForCausalLM.from_pretrained( + "Qwen/Qwen3-8B", + dtype="auto", + device_map="cuda:0" +).eval() + +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B") +prompt = "How many positive whole-number divisors does 196 have?" 
+messages = [ + {"role": "user", "content": prompt} +] +# Note: this draft model is used for thinking mode disabled +text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False +) +model_inputs = tokenizer([text], return_tensors="pt").to(model.device) + +generate_ids = model.spec_generate( + input_ids=model_inputs["input_ids"], + max_new_tokens=2048, + temperature=0.0, + target=target, + stop_token_ids=[tokenizer.eos_token_id] +) + +print(tokenizer.decode(generate_ids[0], skip_special_tokens=False)) +``` + +## ๐Ÿ“Š Evaluation +We provide scripts to reproduce the speedup and acceptance length metrics in the paper. The reported results were tested on NVIDIA H200 or B200 GPUs. + +To run benchmark on Transformers backend: +```bash +bash run_benchmark.sh +``` + +To run benchmark on SGLang: +```bash +export SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1 + +python benchmark_sglang.py \ + --target-model Qwen/Qwen3-8B \ + --draft-model z-lab/Qwen3-8B-DFlash-b16 \ + --concurrencies 1,4,8,16,32 \ + --dataset-name math500 \ + --attention-backends fa3,flashinfer \ + --tp-size 1 \ + --output-md sglang_results.md +``` + +
+ +
import argparse
import time
import random
from itertools import chain
from types import SimpleNamespace
from loguru import logger
import numpy as np
import torch
from rich import print
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache
from model import DFlashDraftModel, sample, load_and_process_dataset, extract_context_feature
import distributed as dist

def cuda_time() -> float:
    # Synchronize so the wall-clock reading reflects completed GPU work.
    torch.cuda.synchronize()
    return time.perf_counter()

@torch.inference_mode()
def dflash_generate(
    model: DFlashDraftModel,
    target: AutoModelForCausalLM,
    input_ids: torch.Tensor,
    mask_token_id: int,
    max_new_tokens: int,
    block_size: int,
    stop_token_ids: list[int],
    temperature: float = 0.0,
) -> SimpleNamespace:
    """Speculative block-diffusion decoding: the draft `model` proposes a block of
    tokens which the `target` model verifies in a single forward pass.

    With block_size == 1 the draft branch is skipped entirely, so this degrades to
    plain autoregressive decoding with the target model (used as the baseline).

    Returns a SimpleNamespace with output_ids, token counts, TTFT, per-token decode
    time, and the per-step acceptance lengths.
    """
    num_input_tokens = input_ids.shape[1]
    max_length = num_input_tokens + max_new_tokens

    # Over-allocate by one block so a full block write near max_length cannot
    # run out of bounds; unused slots keep mask_token_id and are dropped later.
    output_ids = torch.full(
        (1, max_length + block_size),
        mask_token_id,
        dtype=torch.long,
        device=model.device,
    )
    position_ids = torch.arange(output_ids.shape[1], device=model.device).unsqueeze(0)
    past_key_values_target = DynamicCache()
    past_key_values_draft = DynamicCache()

    # Prefill stage: run the prompt through the target once, sampling the first
    # output token. Hidden states are only needed when a draft block will be built.
    prefill_start = cuda_time()
    output = target(
        input_ids,
        position_ids=position_ids[:, :num_input_tokens],
        past_key_values=past_key_values_target,
        use_cache=True,
        logits_to_keep=1,
        output_hidden_states=True if block_size > 1 else False,
    )

    output_ids[:, :num_input_tokens] = input_ids
    output_ids[:, num_input_tokens:num_input_tokens+1] = sample(output.logits, temperature)
    if block_size > 1:
        target_hidden = extract_context_feature(output.hidden_states, model.target_layer_ids)

    time_to_first_token = cuda_time() - prefill_start

    # Decode stage
    decode_start = cuda_time()
    start = input_ids.shape[1]
    acceptance_lengths = []
    draft_prefill = True

    while start < max_length:
        # Current block: position `start` holds an already-verified token; the
        # remaining block_size-1 positions are mask tokens for the draft to fill.
        block_output_ids = output_ids[:, start : start + block_size].clone()
        block_position_ids = position_ids[:, start : start + block_size]
        if block_size > 1:
            # Draft proposal: embed the noisy block and run the draft model
            # conditioned on target hidden features, then project with the
            # target's own lm_head. Only the last block_size-1 positions are drafted.
            noise_embedding = target.model.embed_tokens(block_output_ids)
            draft_logits = target.lm_head(model(
                target_hidden=target_hidden,
                noise_embedding=noise_embedding,
                position_ids=position_ids[:, past_key_values_draft.get_seq_length(): start + block_size],
                past_key_values=past_key_values_draft,
                use_cache=True,
                is_causal=False,
            )[:, -block_size+1:, :])
            # Drop the in-block draft KV entries; only verified prefix is kept.
            past_key_values_draft.crop(start)
            block_output_ids[:, 1:] = sample(draft_logits)
            if draft_prefill:
                # Exclude the one-time draft prefill from per-token decode timing.
                draft_prefill = False
                decode_start = cuda_time()

        # Verification: one target forward over the whole proposed block.
        output = target(
            block_output_ids,
            position_ids=block_position_ids,
            past_key_values=past_key_values_target,
            use_cache=True,
            output_hidden_states=True if block_size > 1 else False,
        )

        posterior = sample(output.logits, temperature)
        # Longest prefix where draft token i+1 equals the target's prediction at i;
        # cumprod zeroes out everything after the first mismatch.
        acceptance_length = (block_output_ids[:, 1:] == posterior[:, :-1]).cumprod(dim=1).sum(dim=1)[0].item()
        output_ids[:, start : start + acceptance_length + 1] = block_output_ids[:, : acceptance_length + 1]
        # "Bonus" token: the target's own prediction right after the accepted prefix.
        output_ids[:, start + acceptance_length + 1] = posterior[:, acceptance_length]

        acceptance_lengths.append(acceptance_length+1)
        start += acceptance_length + 1
        # Roll the target cache back to the verified length before the next block.
        past_key_values_target.crop(start)
        if block_size > 1:
            target_hidden = extract_context_feature(output.hidden_states, model.target_layer_ids)[:, :acceptance_length + 1, :]

        # Stop as soon as any stop token appears anywhere in the generated suffix.
        if stop_token_ids is not None and any(
            stop_token_id in output_ids[:, num_input_tokens:] for stop_token_id in stop_token_ids
        ):
            break

    # Trim over-allocation and any remaining mask placeholders, then truncate at
    # the first stop token (inclusive) if one was generated.
    output_ids = output_ids[:, :max_length]
    output_ids = output_ids[:, output_ids[0] != mask_token_id]
    if stop_token_ids is not None:
        stop_token_ids = torch.tensor(stop_token_ids, device=output_ids.device)
        stop_token_indices = torch.isin(output_ids[0][num_input_tokens:], stop_token_ids).nonzero(as_tuple=True)[0]
        if stop_token_indices.numel() > 0:
            output_ids = output_ids[:, : num_input_tokens + stop_token_indices[0] + 1]

    num_output_tokens = output_ids.shape[1] - num_input_tokens
    total_decode_time = cuda_time() - decode_start
    time_per_output_token = total_decode_time / num_output_tokens

    return SimpleNamespace(
        output_ids=output_ids,
        num_input_tokens=num_input_tokens,
        num_output_tokens=num_output_tokens,
        time_to_first_token=time_to_first_token,
        time_per_output_token=time_per_output_token,
        acceptance_lengths=acceptance_lengths,
    )
def main() -> None:
    """Benchmark DFlash speculative decoding against plain target-only decoding.

    Runs every dataset instance at block_size 1 (baseline) and at the draft
    model's block size, then reports decoding speedup, mean acceptance length,
    and an acceptance-length histogram on the main rank.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name-or-path", type=str, required=True)
    parser.add_argument("--draft-name-or-path", type=str, required=True)
    parser.add_argument("--block-size", type=int, default=None)
    parser.add_argument("--dataset", type=str, required=True)
    parser.add_argument("--max-samples", type=int, default=None)
    parser.add_argument("--max-new-tokens", type=int, default=16384)
    parser.add_argument("--temperature", type=float, default=0.0)
    args = parser.parse_args()

    # Fix every RNG so shuffling and sampling are reproducible across runs/ranks.
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    dist.init()
    torch.cuda.set_device(dist.local_rank())
    device = torch.device(f"cuda:{dist.local_rank()}")

    def has_flash_attn() -> bool:
        # Probe for availability without importing the heavy extension module.
        import importlib.util

        if importlib.util.find_spec("flash_attn") is not None:
            return True
        logger.warning("flash_attn is not installed. Falling back to torch.sdpa. The speedup will be lower.")
        return False

    installed_flash_attn = has_flash_attn()
    attn_implementation = "flash_attention_2" if installed_flash_attn else "sdpa"

    target = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        attn_implementation=attn_implementation,
        dtype=torch.bfloat16,
    ).to(device).eval()

    draft_model = DFlashDraftModel.from_pretrained(
        args.draft_name_or_path,
        attn_implementation=attn_implementation,
        dtype=torch.bfloat16,
    ).to(device).eval()

    block_size = args.block_size if args.block_size is not None else draft_model.block_size

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    dataset = load_and_process_dataset(args.dataset)

    if args.max_samples is not None and len(dataset) > args.max_samples:
        dataset = dataset.shuffle(seed=0).select(range(args.max_samples))

    responses = []
    # Round-robin shard the dataset across ranks.
    indices = range(dist.rank(), len(dataset), dist.size())
    for idx in tqdm(indices, disable=not dist.is_main()):
        instance = dataset[idx]
        messages = []
        for user_content in instance["turns"]:
            messages.append({"role": "user", "content": user_content})
            # NOTE(review): generation is per turn, and the assistant reply is fed
            # back into `messages` for multi-turn datasets — confirm against data.
            input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, enable_thinking=False)
            input_ids = tokenizer.encode(input_text, return_tensors="pt").to(target.device)

            response = {}
            # dict.fromkeys de-duplicates while preserving order, so block_size == 1
            # does not run the identical baseline configuration twice.
            for bs in dict.fromkeys((1, block_size)):
                response[bs] = dflash_generate(
                    model=draft_model,
                    target=target,
                    input_ids=input_ids,
                    mask_token_id=draft_model.mask_token_id,
                    max_new_tokens=args.max_new_tokens,
                    block_size=bs,
                    stop_token_ids=[tokenizer.eos_token_id],
                    temperature=args.temperature,
                )

            spec_response = response[block_size]
            generated_ids = spec_response.output_ids[0, spec_response.num_input_tokens:]
            output_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
            messages.append({"role": "assistant", "content": output_text})
            responses.append(response)

    if dist.size() > 1:
        # gather() returns one list per rank on the main rank (None elsewhere);
        # only a gathered list-of-lists may be flattened.
        gathered = dist.gather(responses, dst=0)
        if not dist.is_main():
            return
        responses = list(chain(*gathered))

    # Speedup: baseline (block_size 1) per-token time over speculative per-token time.
    t1 = np.mean([r[1].time_per_output_token for r in responses])
    tb = np.mean([r[block_size].time_per_output_token for r in responses])
    print(f"Decoding speedup: {t1 / tb:.2f}")

    tau = np.mean([np.mean(r[block_size].acceptance_lengths) for r in responses])
    print(f"Average Acceptance length: {tau:.2f}")

    # Bin 0 is always empty (recorded lengths are >= 1); kept for a stable layout.
    acceptance_lengths = list(chain(*[r[block_size].acceptance_lengths for r in responses]))
    histogram = [acceptance_lengths.count(b) / len(acceptance_lengths) for b in range(block_size + 1)]
    print(f"Acceptance length histogram: {[f'{x * 100:.1f}%' for x in histogram]}")

if __name__ == "__main__":
    main()
+from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + find_available_port, + popen_launch_server, +) + +def _is_blackwell() -> bool: + if envs.IS_BLACKWELL.get(): + return True + return get_device_sm() >= 100 + + +def _flush_cache(base_url: str) -> None: + resp = requests.get(base_url + "/flush_cache", timeout=60) + resp.raise_for_status() + + +def _send_generate( + base_url: str, + prompt: str, + *, + max_new_tokens: int, + stop: list[str], + timeout_s: int, +) -> dict: + sampling_params: dict = { + "temperature": 0.0, + "top_p": 1.0, + "top_k": 1, + "max_new_tokens": int(max_new_tokens), + } + if stop: + sampling_params["stop"] = stop + resp = requests.post( + base_url + "/generate", + json={ + "text": prompt, + "sampling_params": sampling_params, + }, + timeout=int(timeout_s), + ) + resp.raise_for_status() + return resp.json() + + +def _send_generate_batch( + base_url: str, + prompts: list[str], + *, + max_new_tokens: int, + stop: list[str], + timeout_s: int, +) -> list[dict]: + if not prompts: + return [] + sampling_params: dict = { + "temperature": 0.0, + "top_p": 1.0, + "top_k": 1, + "max_new_tokens": int(max_new_tokens), + } + if stop: + sampling_params["stop"] = stop + resp = requests.post( + base_url + "/generate", + json={ + "text": prompts, + "sampling_params": sampling_params, + }, + timeout=int(timeout_s), + ) + resp.raise_for_status() + out = resp.json() + if not isinstance(out, list): + raise RuntimeError( + "Expected a list response for batched /generate, but got " + f"type={type(out).__name__}." 
+ ) + return out + + +@dataclass(frozen=True) +class BenchMetrics: + latency_s: float + output_tokens: int + output_toks_per_s: float + spec_accept_length: Optional[float] + spec_verify_ct_sum: int + + +def _run_bench_requests( + base_url: str, + *, + prompts: list[str], + max_new_tokens: int, + concurrency: int, + batch_requests: bool, + stop: list[str], + timeout_s: int, + expect_dflash: bool, +) -> BenchMetrics: + # Drop the first batch from metrics to exclude one-time JIT/cuda-graph overhead + bs = max(int(concurrency), 1) + if len(prompts) > bs: + warmup_prompts = prompts[:bs] + if batch_requests: + _send_generate_batch( + base_url, + warmup_prompts, + max_new_tokens=max_new_tokens, + stop=stop, + timeout_s=timeout_s, + ) + else: + with ThreadPoolExecutor(max_workers=int(concurrency)) as pool: + futures = [ + pool.submit( + _send_generate, + base_url, + prompt, + max_new_tokens=max_new_tokens, + stop=stop, + timeout_s=timeout_s, + ) + for prompt in warmup_prompts + ] + for fut in as_completed(futures): + fut.result() + + prompts = prompts[bs:] + + start = time.perf_counter() + total_tokens = 0 + spec_verify_ct_sum = 0 + spec_accept_lengths: list[float] = [] + + if batch_requests: + bs = max(int(concurrency), 1) + for start_idx in range(0, len(prompts), bs): + chunk_prompts = prompts[start_idx : start_idx + bs] + outs = _send_generate_batch( + base_url, + chunk_prompts, + max_new_tokens=max_new_tokens, + stop=stop, + timeout_s=timeout_s, + ) + if len(outs) != len(chunk_prompts): + raise RuntimeError( + "Batched /generate output length mismatch: " + f"got {len(outs)} outputs for {len(chunk_prompts)} prompts." 
+ ) + + for j, out in enumerate(outs): + meta = out.get("meta_info", {}) or {} + total_tokens += int(meta.get("completion_tokens", 0)) + spec_verify_ct_sum += int(meta.get("spec_verify_ct", 0)) + if "spec_accept_length" in meta: + try: + spec_accept_lengths.append(float(meta["spec_accept_length"])) + except (TypeError, ValueError): + pass + else: + with ThreadPoolExecutor(max_workers=int(concurrency)) as pool: + futures = { + pool.submit( + _send_generate, + base_url, + prompt, + max_new_tokens=max_new_tokens, + stop=stop, + timeout_s=timeout_s, + ): i + for i, prompt in enumerate(prompts) + } + for fut in as_completed(futures): + out = fut.result() + meta = out.get("meta_info", {}) or {} + total_tokens += int(meta.get("completion_tokens", 0)) + spec_verify_ct_sum += int(meta.get("spec_verify_ct", 0)) + if "spec_accept_length" in meta: + try: + spec_accept_lengths.append(float(meta["spec_accept_length"])) + except (TypeError, ValueError): + pass + + latency = time.perf_counter() - start + toks_per_s = total_tokens / max(latency, 1e-6) + + if expect_dflash and spec_verify_ct_sum <= 0: + raise RuntimeError( + "DFLASH sanity check failed: did not observe any `spec_verify_ct` in responses " + "(DFLASH may not have been enabled)." 
+ ) + + spec_accept_length = ( + float(statistics.mean(spec_accept_lengths)) if spec_accept_lengths else None + ) + + return BenchMetrics( + latency_s=float(latency), + output_tokens=int(total_tokens), + output_toks_per_s=float(toks_per_s), + spec_accept_length=spec_accept_length, + spec_verify_ct_sum=int(spec_verify_ct_sum), + ) + + +def _format_table( + *, + concurrencies: list[int], + values: dict[int, Optional[float]], + float_fmt: str, +) -> str: + header = ["conc"] + [str(c) for c in concurrencies] + lines = [ + "| " + " | ".join(header) + " |", + "| " + " | ".join(["---"] * len(header)) + " |", + ] + row = ["value"] + for c in concurrencies: + v = values.get(c, None) + row.append("N/A" if v is None else format(v, float_fmt)) + lines.append("| " + " | ".join(row) + " |") + return "\n".join(lines) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument( + "--output-md", + type=str, + default=None, + help="Write a markdown report to this file (disabled by default).", + ) + parser.add_argument("--dataset-name", type=str, default="gsm8k") + parser.add_argument("--target-model", type=str, default="Qwen/Qwen3-8B") + parser.add_argument("--draft-model", type=str, default="z-lab/Qwen3-8B-DFlash-b16") + parser.add_argument( + "--skip-baseline", + action="store_true", + help="Skip running the baseline (target-only) sweep; only run DFLASH and report N/A for baseline/speedup.", + ) + parser.add_argument( + "--batch-requests", + action="store_true", + help="Send prompts as server-side batched /generate requests (batch size = concurrency) instead of client-side concurrent requests.", + ) + parser.add_argument("--max-new-tokens", type=int, default=2048) + parser.add_argument("--timeout-s", type=int, default=3600) + parser.add_argument("--mem-fraction-static", type=float, default=0.75) + parser.add_argument("--disable-radix-cache", action="store_true") + parser.add_argument("--dtype", type=str, default="bfloat16") + 
def main() -> None:
    """Sweep SGLang servers (baseline vs. DFLASH speculative decoding) across
    attention backends and client concurrencies, then emit a markdown report.

    For each backend a baseline server (unless --skip-baseline) and a DFLASH
    server are launched in turn; each is benchmarked at every concurrency and
    throughput/acceptance-length tables are rendered at the end.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--output-md",
        type=str,
        default=None,
        help="Write a markdown report to this file (disabled by default).",
    )
    parser.add_argument("--dataset-name", type=str, default="gsm8k")
    parser.add_argument("--target-model", type=str, default="Qwen/Qwen3-8B")
    parser.add_argument("--draft-model", type=str, default="z-lab/Qwen3-8B-DFlash-b16")
    parser.add_argument(
        "--skip-baseline",
        action="store_true",
        help="Skip running the baseline (target-only) sweep; only run DFLASH and report N/A for baseline/speedup.",
    )
    parser.add_argument(
        "--batch-requests",
        action="store_true",
        help="Send prompts as server-side batched /generate requests (batch size = concurrency) instead of client-side concurrent requests.",
    )
    parser.add_argument("--max-new-tokens", type=int, default=2048)
    parser.add_argument("--timeout-s", type=int, default=3600)
    parser.add_argument("--mem-fraction-static", type=float, default=0.75)
    parser.add_argument("--disable-radix-cache", action="store_true")
    parser.add_argument("--dtype", type=str, default="bfloat16")
    parser.add_argument("--max-running-requests", type=int, default=64)
    parser.add_argument(
        "--tp-size",
        type=int,
        default=1,
        help="Tensor parallel size (single value, no sweep).",
    )
    parser.add_argument(
        "--concurrencies",
        type=str,
        default="1,2,4,8,16,32",
        help="Comma-separated list of client concurrency levels.",
    )
    parser.add_argument(
        "--questions-per-concurrency-base",
        type=int,
        default=128,
        help="num_questions = base * concurrency (default matches the sweep plan).",
    )
    parser.add_argument(
        "--max-questions-per-config",
        type=int,
        default=1024,
        help="Cap num_questions per (tp, concurrency) run (default: 1024).",
    )
    parser.add_argument(
        "--attention-backends",
        type=str,
        default="flashinfer,fa3,fa4",
        help="Comma-separated list. Will auto-skip fa3 unless SM90 (Hopper), and fa4 unless SM100+ (Blackwell).",
    )
    args = parser.parse_args()

    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for this sweep.")

    concurrencies = [int(x) for x in args.concurrencies.split(",") if x.strip()]
    concurrencies = [c for c in concurrencies if c >= 1]
    if not concurrencies:
        raise RuntimeError("No concurrencies specified.")

    # Scale the question count with concurrency, capped per configuration.
    num_questions_by_conc = {
        c: min(int(args.questions_per_concurrency_base) * int(c), int(args.max_questions_per_config))
        for c in concurrencies
    }
    max_questions = max(num_questions_by_conc.values())
    max_concurrency = max(concurrencies)

    # Filter backends by GPU capability: fa3 is Hopper-only, fa4 needs SM100+.
    attention_backends = [s.strip() for s in args.attention_backends.split(",") if s.strip()]
    is_blackwell = _is_blackwell()
    device_sm = get_device_sm()
    if device_sm != 90:
        attention_backends = [b for b in attention_backends if b != "fa3"]
    if device_sm < 100:
        attention_backends = [b for b in attention_backends if b != "fa4"]
    attention_backends = attention_backends or ["flashinfer"]

    print(f"Loading dataset: {args.dataset_name}...")
    dataset = load_and_process_dataset(args.dataset_name)
    # Extra prompts per run feed the warmup batch, which is excluded from metrics.
    required_questions = max_questions + max_concurrency

    if len(dataset) < required_questions:
        print(f"Warning: Dataset has {len(dataset)} items, but need up to {required_questions}. Reusing items.")

    tokenizer = AutoTokenizer.from_pretrained(args.target_model)

    # Build exactly required_questions prompts, wrapping around short datasets.
    prompts: list[str] = []
    for i in range(required_questions):
        item = dataset[i % len(dataset)]
        user_content = item["turns"][0]
        prompt_text = tokenizer.apply_chat_template(
            [{"role": "user", "content": user_content}],
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )
        prompts.append(prompt_text)

    # Results indexed by (backend, concurrency); TP is fixed, not swept.
    baseline_toks: dict[tuple[str, int], Optional[float]] = {}
    dflash_toks: dict[tuple[str, int], Optional[float]] = {}
    dflash_accept_len: dict[tuple[str, int], Optional[float]] = {}

    tp = args.tp_size

    for backend in attention_backends:
        port_base = find_available_port(20000)

        common_server_args: list[str] = [
            "--trust-remote-code",
            "--attention-backend",
            backend,
            "--tp-size",
            str(tp),
            "--dtype",
            str(args.dtype),
            "--mem-fraction-static",
            str(args.mem_fraction_static),
            "--max-running-requests",
            str(args.max_running_requests),
        ]
        # Capture cuda graphs for every batch size up to 32.
        common_server_args.extend(
            ["--cuda-graph-bs", *[str(i) for i in range(1, 33)], "--cuda-graph-max-bs", "32"]
        )
        if args.disable_radix_cache:
            common_server_args.append("--disable-radix-cache")

        if not args.skip_baseline:
            print(f"\n=== backend={backend} tp={tp} (baseline) ===")
            baseline_port = port_base
            baseline_url = f"http://127.0.0.1:{baseline_port}"
            baseline_proc = popen_launch_server(
                args.target_model,
                baseline_url,
                timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
                other_args=common_server_args,
            )
            try:
                # Single cheap request to confirm the server answers.
                _send_generate(
                    baseline_url,
                    "Hello",
                    max_new_tokens=8,
                    stop=[],
                    timeout_s=min(int(args.timeout_s), 300),
                )

                for conc in concurrencies:
                    n = num_questions_by_conc[conc]
                    _flush_cache(baseline_url)
                    print(
                        f"[warmup] run 1 warmup batch (size={conc}) after /flush_cache; excluded from metrics."
                    )
                    metrics = _run_bench_requests(
                        baseline_url,
                        prompts=prompts[: n + conc],
                        max_new_tokens=int(args.max_new_tokens),
                        concurrency=int(conc),
                        batch_requests=bool(args.batch_requests),
                        stop=[],
                        timeout_s=int(args.timeout_s),
                        expect_dflash=False,
                    )
                    baseline_toks[(backend, conc)] = metrics.output_toks_per_s
                    print(
                        f"[baseline] conc={conc:>2} n={n:<4} "
                        f"toks/s={metrics.output_toks_per_s:,.2f} "
                        f"latency={metrics.latency_s:.1f}s "
                    )
            finally:
                kill_process_tree(baseline_proc.pid)
                try:
                    baseline_proc.wait(timeout=30)
                except Exception:
                    pass

        print(f"\n=== backend={backend} tp={tp} (DFLASH) ===")
        dflash_port = find_available_port(port_base + 1)
        dflash_url = f"http://127.0.0.1:{dflash_port}"
        dflash_proc = popen_launch_server(
            args.target_model,
            dflash_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                *common_server_args,
                "--speculative-algorithm",
                "DFLASH",
                "--speculative-draft-model-path",
                args.draft_model,
            ],
        )
        try:
            _send_generate(
                dflash_url,
                "Hello",
                max_new_tokens=8,
                stop=[],
                timeout_s=min(int(args.timeout_s), 300),
            )
            for conc in concurrencies:
                n = num_questions_by_conc[conc]
                _flush_cache(dflash_url)
                print(
                    f"[warmup] run 1 warmup batch (size={conc}) after /flush_cache; excluded from metrics."
                )
                metrics = _run_bench_requests(
                    dflash_url,
                    prompts=prompts[: n + conc],
                    max_new_tokens=int(args.max_new_tokens),
                    concurrency=int(conc),
                    batch_requests=bool(args.batch_requests),
                    stop=[],
                    timeout_s=int(args.timeout_s),
                    expect_dflash=True,
                )
                dflash_toks[(backend, conc)] = metrics.output_toks_per_s
                dflash_accept_len[(backend, conc)] = metrics.spec_accept_length
                # spec_accept_length may be None (no per-request value reported);
                # format conditionally instead of crashing on `None:.3f`.
                accept_len_str = (
                    f"{metrics.spec_accept_length:.3f}"
                    if metrics.spec_accept_length is not None
                    else "N/A"
                )
                print(
                    f"[DFLASH] conc={conc:>2} n={n:<4} "
                    f"toks/s={metrics.output_toks_per_s:,.2f} "
                    f"latency={metrics.latency_s:.1f}s "
                    f"accept_len={accept_len_str} "
                    f"spec_verify_ct_sum={metrics.spec_verify_ct_sum}"
                )
        finally:
            kill_process_tree(dflash_proc.pid)
            try:
                dflash_proc.wait(timeout=30)
            except Exception:
                pass

    # Render markdown.
    md_lines: list[str] = []
    md_lines.append("# DFLASH Bench Report")
    md_lines.append("")
    md_lines.append("## Settings")
    md_lines.append(f"- dataset: `{args.dataset_name}`")
    md_lines.append(f"- target_model: `{args.target_model}`")
    md_lines.append(f"- draft_model: `{args.draft_model}`")
    md_lines.append(f"- max_new_tokens: `{args.max_new_tokens}`")
    md_lines.append(f"- attention_backends: `{', '.join(attention_backends)}`")
    md_lines.append(f"- tp_size: `{tp}`")
    md_lines.append(f"- concurrencies: `{', '.join(str(x) for x in concurrencies)}`")
    md_lines.append(f"- questions_per_concurrency: `base={args.questions_per_concurrency_base}`")
    md_lines.append(f"- device_sm: `{device_sm}`")
    md_lines.append(f"- is_blackwell: `{is_blackwell}`")
    md_lines.append(f"- skip_baseline: `{bool(args.skip_baseline)}`")
    md_lines.append("- drop_first_batch: `true`")
    md_lines.append("")

    for backend in attention_backends:
        md_lines.append(f"## Backend: `{backend}`")
        md_lines.append("")

        baseline_values = {
            c: baseline_toks.get((backend, c), None) for c in concurrencies
        }
        dflash_values = {
            c: dflash_toks.get((backend, c), None) for c in concurrencies
        }
        speedup_values: dict[int, Optional[float]] = {}
        for c in concurrencies:
            b = baseline_values.get(c, None)
            d = dflash_values.get(c, None)
            speedup_values[c] = None if (b is None or d is None or b <= 0) else (d / b)

        md_lines.append("### Baseline output tok/s")
        md_lines.append(
            _format_table(
                concurrencies=concurrencies,
                values=baseline_values,
                float_fmt=",.2f",
            )
        )
        md_lines.append("")

        md_lines.append("### DFLASH output tok/s")
        md_lines.append(
            _format_table(
                concurrencies=concurrencies,
                values=dflash_values,
                float_fmt=",.2f",
            )
        )
        md_lines.append("")

        md_lines.append("### Speedup (DFLASH / baseline)")
        md_lines.append(
            _format_table(
                concurrencies=concurrencies,
                values=speedup_values,
                float_fmt=".3f",
            )
        )
        md_lines.append("")

        md_lines.append("### DFLASH acceptance length")
        md_lines.append(
            _format_table(
                concurrencies=concurrencies,
                values={
                    c: dflash_accept_len.get((backend, c), None)
                    for c in concurrencies
                },
                float_fmt=".3f",
            )
        )
        md_lines.append("")

    if args.output_md:
        with open(args.output_md, "w", encoding="utf-8") as f:
            f.write("\n".join(md_lines))
            f.write("\n")
        print(f"\nWrote markdown report to: {args.output_md}")
    else:
        print("\nMarkdown report disabled (pass --output-md to write one).")


if __name__ == "__main__":
    main()
import os
import warnings
from typing import Any, List, Optional

from torch import distributed as dist

__all__ = [
    "init",
    "is_initialized",
    "size",
    "rank",
    "local_size",
    "local_rank",
    "is_main",
    "barrier",
    "gather",
    "all_gather",
]


def init() -> None:
    """Initialize the default NCCL process group from torchrun-style env vars.

    When `RANK` is absent (plain single-process launch), warn and do nothing.
    """
    if "RANK" in os.environ:
        dist.init_process_group(backend="nccl", init_method="env://")
        return
    warnings.warn("Environment variable `RANK` is not set. Skipping distributed initialization.")


def is_initialized() -> bool:
    """Whether torch.distributed has an initialized process group."""
    return dist.is_initialized()


def size() -> int:
    """World size from the environment (1 when not launched distributed)."""
    return int(os.getenv("WORLD_SIZE", "1"))


def rank() -> int:
    """Global rank from the environment (0 when not launched distributed)."""
    return int(os.getenv("RANK", "0"))


def local_size() -> int:
    """Number of processes on this node, from the environment."""
    return int(os.getenv("LOCAL_WORLD_SIZE", "1"))


def local_rank() -> int:
    """Rank of this process within its node, from the environment."""
    return int(os.getenv("LOCAL_RANK", "0"))


def is_main() -> bool:
    """True on the global main process (rank 0)."""
    return rank() == 0


def barrier() -> None:
    """Synchronize all ranks; a no-op when distributed is not initialized."""
    if is_initialized():
        dist.barrier()


def gather(obj: Any, dst: int = 0) -> Optional[List[Any]]:
    """Gather a picklable object from every rank onto `dst`.

    Returns the list of per-rank objects on the destination rank, None on the
    others, and degrades to `[obj]` when distributed is not initialized.
    """
    if not is_initialized():
        return [obj]
    if not is_main():
        dist.gather_object(obj, dst=dst)
        return None
    collected: List[Any] = [None] * size()
    dist.gather_object(obj, collected, dst=dst)
    return collected


def all_gather(obj: Any) -> List[Any]:
    """All-gather a picklable object across ranks; `[obj]` when not initialized."""
    if not is_initialized():
        return [obj]
    collected: List[Any] = [None] * size()
    dist.all_gather_object(collected, obj)
    return collected
b/dflash/model/__pycache__/__init__.cpython-313.pyc differ diff --git a/dflash/model/__pycache__/dflash.cpython-311.pyc b/dflash/model/__pycache__/dflash.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43b0583fa958332926a3dfa72c0be3a219453cbc Binary files /dev/null and b/dflash/model/__pycache__/dflash.cpython-311.pyc differ diff --git a/dflash/model/__pycache__/dflash.cpython-313.pyc b/dflash/model/__pycache__/dflash.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b36d5029df7a7576bb7f680ade13eac4b6b82ea Binary files /dev/null and b/dflash/model/__pycache__/dflash.cpython-313.pyc differ diff --git a/dflash/model/__pycache__/utils.cpython-311.pyc b/dflash/model/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d7e3be332f3d7bd6ddc7ea7e09e786ac3447d244 Binary files /dev/null and b/dflash/model/__pycache__/utils.cpython-311.pyc differ diff --git a/dflash/model/__pycache__/utils.cpython-313.pyc b/dflash/model/__pycache__/utils.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19f1e93199dff332920e332f4b5f2218c95e6bb3 Binary files /dev/null and b/dflash/model/__pycache__/utils.cpython-313.pyc differ diff --git a/dflash/model/dflash.py b/dflash/model/dflash.py new file mode 100644 index 0000000000000000000000000000000000000000..f3cd5d26057e92f4721977cf53195c2d738a6b76 --- /dev/null +++ b/dflash/model/dflash.py @@ -0,0 +1,277 @@ +from typing import Optional, Callable +from typing_extensions import Unpack, Tuple +import torch +from torch import nn +from transformers.models.qwen3.modeling_qwen3 import ( + Qwen3RMSNorm, + Qwen3RotaryEmbedding, + Qwen3Config, + Qwen3PreTrainedModel, + Qwen3MLP, + GradientCheckpointingLayer, + FlashAttentionKwargs, + rotate_half, + eager_attention_forward, + ALL_ATTENTION_FUNCTIONS, +) +from transformers import DynamicCache +from transformers.modeling_outputs import 
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    # RoPE for asymmetric sequence lengths: k covers context + noise tokens while
    # q covers only the noise block, so q takes the trailing q_len slice of cos/sin.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_len = q.size(-2)
    q_embed = (q * cos[..., -q_len:, :]) + (rotate_half(q) * sin[..., -q_len:, :])
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

class Qwen3DFlashAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: Qwen3Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        # Non-causal: draft (noise) tokens may attend to the whole block.
        self.is_causal = False
        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
        # Qwen3-style per-head QK normalization.
        self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None

    def forward(
        self,
        hidden_states: torch.Tensor,
        target_hidden: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Cross/self attention over [target context features ++ noise tokens].

        Queries come only from the noise block (`hidden_states`); keys/values are
        the concatenation of projected target features and noise tokens, so the
        draft can condition on the verified context.
        """
        bsz, q_len = hidden_states.shape[:-1]
        ctx_len = target_hidden.shape[1]
        q = self.q_proj(hidden_states)
        q = q.view(bsz, q_len, -1, self.head_dim)
        q = self.q_norm(q).transpose(1, 2)
        # Separate K/V projections for context features and noise tokens,
        # concatenated along the sequence dimension: [ctx_len + q_len].
        k_ctx = self.k_proj(target_hidden)
        k_noise = self.k_proj(hidden_states)
        v_ctx = self.v_proj(target_hidden)
        v_noise = self.v_proj(hidden_states)
        k = torch.cat([k_ctx, k_noise], dim=1).view(bsz, ctx_len + q_len, -1, self.head_dim)
        v = torch.cat([v_ctx, v_noise], dim=1).view(bsz, ctx_len + q_len, -1, self.head_dim)
        k = self.k_norm(k).transpose(1, 2)
        v = v.transpose(1, 2)
        cos, sin = position_embeddings
        # See apply_rotary_pos_emb above: q is rotated with the trailing positions only.
        q, k = apply_rotary_pos_emb(q, k, cos, sin)
        if past_key_values is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            k, v = past_key_values.update(k, v, self.layer_idx, cache_kwargs)
        attn_fn: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attn_fn = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
        attn_output, attn_weights = attn_fn(
            self,
            q,
            k,
            v,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            **kwargs,
        )
        attn_output = attn_output.reshape(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights

class Qwen3DFlashDecoderLayer(GradientCheckpointingLayer):
    # Standard pre-norm transformer layer wired to the DFlash attention above.
    def __init__(self, config: Qwen3Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Qwen3DFlashAttention(config=config, layer_idx=layer_idx)
        self.mlp = Qwen3MLP(config)
        self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        target_hidden: Optional[torch.Tensor] = None,
        hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Pre-norm attention + MLP with residual connections; returns hidden states only."""
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        # [0] drops the attention weights returned by the attention module.
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            target_hidden=target_hidden,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )[0]
        hidden_states = residual + hidden_states
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states
self.hidden_norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.block_size = config.block_size + self.mask_token_id = self.config.dflash_config.get("mask_token_id", None) + self.post_init() + + def forward( + self, + position_ids: torch.LongTensor, + attention_mask: Optional[torch.Tensor] = None, + noise_embedding: Optional[torch.Tensor] = None, + target_hidden: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: bool = False, + **kwargs, + ) -> CausalLMOutputWithPast: + hidden_states = noise_embedding + target_hidden = self.hidden_norm(self.fc(target_hidden)) + position_embeddings = self.rotary_emb(hidden_states, position_ids) + for layer in self.layers: + hidden_states = layer( + hidden_states=hidden_states, + target_hidden=target_hidden, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + use_cache=use_cache, + position_embeddings=position_embeddings, + **kwargs, + ) + return self.norm(hidden_states) + + @torch.inference_mode() + def spec_generate( + self, + target: nn.Module, + input_ids: torch.LongTensor, + max_new_tokens: int, + stop_token_ids: list[int], + temperature: float, + ): + self.eval() + num_input_tokens = input_ids.shape[1] + max_length = num_input_tokens + max_new_tokens + + block_size = self.block_size + output_ids = torch.full( + (1, max_length + block_size), + self.mask_token_id, + dtype=torch.long, + device=target.device, + ) + position_ids = torch.arange(output_ids.shape[1], device=target.device).unsqueeze(0) + + past_key_values_target = DynamicCache() + past_key_values_draft = DynamicCache() + + # Prefill stage + output = target( + input_ids, + position_ids=position_ids[:, :num_input_tokens], + past_key_values=past_key_values_target, + use_cache=True, + logits_to_keep=1, + output_hidden_states=True, + ) + + output_ids[:, :num_input_tokens] = input_ids + output_ids[:, num_input_tokens:num_input_tokens+1] = sample(output.logits, temperature) + 
target_hidden = extract_context_feature(output.hidden_states, self.target_layer_ids) + + # Decode stage + acceptance_lengths = [] + start = input_ids.shape[1] + while start < max_length: + block_output_ids = output_ids[:, start : start + block_size].clone() + block_position_ids = position_ids[:, start : start + block_size] + noise_embedding = target.model.embed_tokens(block_output_ids) + draft_logits = target.lm_head(self( + target_hidden=target_hidden, + noise_embedding=noise_embedding, + position_ids=position_ids[:, past_key_values_draft.get_seq_length(): start + block_size], + past_key_values=past_key_values_draft, + use_cache=True, + is_causal=False, + )[:, -block_size+1:, :]) + past_key_values_draft.crop(start) + block_output_ids[:, 1:] = sample(draft_logits) + + output = target( + block_output_ids, + position_ids=block_position_ids, + past_key_values=past_key_values_target, + use_cache=True, + output_hidden_states=True, + ) + + posterior = sample(output.logits, temperature) + acceptance_length = (block_output_ids[:, 1:] == posterior[:, :-1]).cumprod(dim=1).sum(dim=1)[0].item() + output_ids[:, start : start + acceptance_length + 1] = block_output_ids[:, : acceptance_length + 1] + output_ids[:, start + acceptance_length + 1] = posterior[:, acceptance_length] + start += acceptance_length + 1 + past_key_values_target.crop(start) + target_hidden = extract_context_feature(output.hidden_states, self.target_layer_ids)[:, :acceptance_length + 1, :] + acceptance_lengths.append(acceptance_length+1) + if stop_token_ids is not None and any( + stop_token_id in output_ids[:, num_input_tokens:] for stop_token_id in stop_token_ids + ): + break + output_ids = output_ids[:, :max_length] + output_ids = output_ids[:, output_ids[0] != self.mask_token_id] + if stop_token_ids is not None: + stop_token_ids = torch.tensor(stop_token_ids, device=output_ids.device) + stop_token_indices = torch.isin(output_ids[0][num_input_tokens:], stop_token_ids).nonzero(as_tuple=True)[0] + if 
stop_token_indices.numel() > 0: + output_ids = output_ids[:, : num_input_tokens + stop_token_indices[0] + 1] + + return output_ids diff --git a/dflash/model/utils.py b/dflash/model/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a05eee433eac10792f8ab6bb75b7d3cf7027ed56 --- /dev/null +++ b/dflash/model/utils.py @@ -0,0 +1,116 @@ +import torch +from typing import Optional +from datasets import load_dataset, Features, Sequence, Value + +def build_target_layer_ids(num_target_layers: int, num_draft_layers: int): + if num_draft_layers == 1: + return [(num_target_layers // 2)] + start = 1 + end = num_target_layers - 3 + span = end - start + target_layer_ids = [ + int(round(start + (i * span) / (num_draft_layers - 1))) + for i in range(num_draft_layers) + ] + return target_layer_ids + +def extract_context_feature( + hidden_states: list[torch.Tensor], + layer_ids: Optional[list[int]], +) -> torch.Tensor: + offset = 1 + selected_states = [] + for layer_id in layer_ids: + selected_states.append(hidden_states[layer_id + offset]) + target_hidden = torch.cat(selected_states, dim=-1) + return target_hidden + +def sample(logits: torch.Tensor, temperature: float = 0.0) -> torch.Tensor: + if temperature < 1e-5: + return torch.argmax(logits, dim=-1) + bsz, seq_len, vocab_size = logits.shape + logits = logits.view(-1, vocab_size) + logits = logits / temperature + probs = torch.softmax(logits, dim=-1) + return torch.multinomial(probs, num_samples=1).view(bsz, seq_len) + +def load_and_process_dataset(data_name: str): + # Math datasets + if data_name == "gsm8k": + dataset = load_dataset("openai/gsm8k", "main", split="test") + prompt_fmt = "{question}\nPlease reason step by step, and put your final answer within \\boxed{{}}." 
+ dataset = dataset.map(lambda x: {"turns": [prompt_fmt.format(**x)]}) + + elif data_name == "math500": + dataset = load_dataset("HuggingFaceH4/MATH-500", split="test") + prompt_fmt = "{problem}\nPlease reason step by step, and put your final answer within \\boxed{{}}." + dataset = dataset.map(lambda x: {"turns": [prompt_fmt.format(**x)]}) + + elif data_name == "aime24": + dataset = load_dataset("HuggingFaceH4/aime_2024", split="train") + prompt_fmt = "{problem}\nPlease reason step by step, and put your final answer within \\boxed{{}}." + dataset = dataset.map(lambda x: {"turns": [prompt_fmt.format(**x)]}) + + elif data_name == "aime25": + dataset = load_dataset("MathArena/aime_2025", split="train") + prompt_fmt = "{problem}\nPlease reason step by step, and put your final answer within \\boxed{{}}." + dataset = dataset.map(lambda x: {"turns": [prompt_fmt.format(**x)]}) + + # Chat datasets + elif data_name == "alpaca": + dataset = load_dataset("tatsu-lab/alpaca", split="train") + dataset = dataset.map(lambda x: {"formatted_input": (f"{x['instruction']}\n\nInput:\n{x['input']}" if x['input'] else x['instruction'])}) + dataset = dataset.map(lambda x: {"turns": [x["formatted_input"]]}) + + elif data_name == "mt-bench": + dataset = load_dataset("HuggingFaceH4/mt_bench_prompts", split="train") + dataset = dataset.map(lambda x: {"turns": x["prompt"]}) + + # Coding datasets + elif data_name == "humaneval": + dataset = load_dataset("openai/openai_humaneval", split="test") + prompt_fmt = "Write a solution to the following problem and make sure that it passes the tests:\n```python\n{prompt}\n```" + dataset = dataset.map(lambda x: {"turns": [prompt_fmt.format(**x)]}) + + elif data_name == "mbpp": + dataset = load_dataset("google-research-datasets/mbpp", "sanitized", split="test") + dataset = dataset.map(lambda x: {"turns": [x["prompt"]]}) + + elif data_name == "lbpp": + LBPP_PY_TEST_URL = "https://huggingface.co/datasets/CohereLabs/lbpp/resolve/main/python/test.parquet" + 
dataset = load_dataset("parquet", data_files={"test": LBPP_PY_TEST_URL})["test"] + dataset = dataset.map(lambda x: {"turns": [x["instruction"]]}) + + elif data_name == "swe-bench": + dataset = load_dataset("princeton-nlp/SWE-bench_Lite", split="test") + prompt_fmt = "Problem Statement:\n{problem_statement}\nPlease fix the issue described above." + dataset = dataset.map(lambda x: {"turns": [prompt_fmt.format(**x)]}) + + elif data_name == "livecodebench": + base = "https://huggingface.co/datasets/livecodebench/code_generation_lite/resolve/main/" + allowed_files = ["test.jsonl", "test2.jsonl", "test3.jsonl", "test4.jsonl", "test5.jsonl", "test6.jsonl"] + urls = [base + fn for fn in allowed_files] + dataset = load_dataset("json", data_files={"test": urls})["test"] + def format_lcb(doc): + system_prompt = ( + "You are an expert Python programmer. You will be given a question (problem specification) " + "and will generate a correct Python program that matches the specification and passes all tests. 
" + "You will NOT return anything except for the program" + ) + question_block = f"### Question:\n{doc['question_content']}" + if doc.get("starter_code"): + format_message = "### Format: Use the following code structure:" + code_block = f"```python\n{doc['starter_code']}\n```" + else: + format_message = "### Format: Write your code in the following format:" + code_block = "```python\n# YOUR CODE HERE\n```" + answer_footer = "### Answer: (use the provided format with backticks)" + return f"{system_prompt}\n\n{question_block}\n\n{format_message}\n{code_block}\n\n{answer_footer}" + target_features = Features({"turns": Sequence(Value("large_string"))}) + dataset = dataset.map( + lambda x: {"turns": [format_lcb(x)]}, + remove_columns=dataset.column_names, + features=target_features + ) + + return dataset \ No newline at end of file diff --git a/dflash/requirements.txt b/dflash/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..a995b709522944cd4f9f49ec5ee513bb7840659a --- /dev/null +++ b/dflash/requirements.txt @@ -0,0 +1,5 @@ +git+https://github.com/sgl-project/sglang.git@refs/pull/16818/head#subdirectory=python +accelerate +rich +packaging +ninja \ No newline at end of file diff --git a/dflash/run_benchmark.sh b/dflash/run_benchmark.sh new file mode 100644 index 0000000000000000000000000000000000000000..f11eccc7138635814a392bf60f0362793a9db6fa --- /dev/null +++ b/dflash/run_benchmark.sh @@ -0,0 +1,37 @@ +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 + +mkdir -p logs + +TASKS=( + "gsm8k:128" + "math500:128" + "aime24:30" + "aime25:30" + "humaneval:164" + "mbpp:128" + "livecodebench:128" + "swe-bench:128" + "mt-bench:80" + "alpaca:128" +) + +for task in "${TASKS[@]}"; do + IFS=':' read -r DATASET_NAME MAX_SAMPLES <<< "$task" + + echo "========================================================" + echo "Running Benchmark: $DATASET_NAME with $MAX_SAMPLES samples" + echo "========================================================" + + torchrun \ 
+ --nproc_per_node=8 \ + --master_port=29600 \ + benchmark.py \ + --dataset "$DATASET_NAME" \ + --max-samples "$MAX_SAMPLES" \ + --model-name-or-path Qwen/Qwen3-4B \ + --draft-name-or-path z-lab/Qwen3-4B-DFlash-b16 \ + --max-new-tokens 2048 \ + --temperature 0.0 \ + 2>&1 | tee "logs/${DATASET_NAME}.log" + +done \ No newline at end of file diff --git a/idea1/MANIFEST.in b/idea1/MANIFEST.in new file mode 100644 index 0000000000000000000000000000000000000000..7e3c8f05614505dc88691fa12babee86f8d1995e --- /dev/null +++ b/idea1/MANIFEST.in @@ -0,0 +1,2 @@ +include requirements.txt +include version.txt diff --git a/idea1/README.md b/idea1/README.md new file mode 100644 index 0000000000000000000000000000000000000000..141963e25880130165ca60532f9f0b4292adea41 --- /dev/null +++ b/idea1/README.md @@ -0,0 +1,70 @@ +
+logo + +[![documentation](https://img.shields.io/badge/๐Ÿ“–-Documentation-red.svg?style=flat)](https://docs.sglang.ai/SpecForge/) +[![SpecBundle](https://img.shields.io/badge/๐Ÿค—%20SpecBundle-yellow.svg?style=flat)](https://huggingface.co/collections/lmsys/specbundle) +[![DeepWiki](https://img.shields.io/badge/DeepWiki-SpecForge-blue.svg?logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAACwAAAAyCAYAAAAnWDnqAAAAAXNSR0IArs4c6QAAA05JREFUaEPtmUtyEzEQhtWTQyQLHNak2AB7ZnyXZMEjXMGeK/AIi+QuHrMnbChYY7MIh8g01fJoopFb0uhhEqqcbWTp06/uv1saEDv4O3n3dV60RfP947Mm9/SQc0ICFQgzfc4CYZoTPAswgSJCCUJUnAAoRHOAUOcATwbmVLWdGoH//PB8mnKqScAhsD0kYP3j/Yt5LPQe2KvcXmGvRHcDnpxfL2zOYJ1mFwrryWTz0advv1Ut4CJgf5uhDuDj5eUcAUoahrdY/56ebRWeraTjMt/00Sh3UDtjgHtQNHwcRGOC98BJEAEymycmYcWwOprTgcB6VZ5JK5TAJ+fXGLBm3FDAmn6oPPjR4rKCAoJCal2eAiQp2x0vxTPB3ALO2CRkwmDy5WohzBDwSEFKRwPbknEggCPB/imwrycgxX2NzoMCHhPkDwqYMr9tRcP5qNrMZHkVnOjRMWwLCcr8ohBVb1OMjxLwGCvjTikrsBOiA6fNyCrm8V1rP93iVPpwaE+gO0SsWmPiXB+jikdf6SizrT5qKasx5j8ABbHpFTx+vFXp9EnYQmLx02h1QTTrl6eDqxLnGjporxl3NL3agEvXdT0WmEost648sQOYAeJS9Q7bfUVoMGnjo4AZdUMQku50McDcMWcBPvr0SzbTAFDfvJqwLzgxwATnCgnp4wDl6Aa+Ax283gghmj+vj7feE2KBBRMW3FzOpLOADl0Isb5587h/U4gGvkt5v60Z1VLG8BhYjbzRwyQZemwAd6cCR5/XFWLYZRIMpX39AR0tjaGGiGzLVyhse5C9RKC6ai42ppWPKiBagOvaYk8lO7DajerabOZP46Lby5wKjw1HCRx7p9sVMOWGzb/vA1hwiWc6jm3MvQDTogQkiqIhJV0nBQBTU+3okKCFDy9WwferkHjtxib7t3xIUQtHxnIwtx4mpg26/HfwVNVDb4oI9RHmx5WGelRVlrtiw43zboCLaxv46AZeB3IlTkwouebTr1y2NjSpHz68WNFjHvupy3q8TFn3Hos2IAk4Ju5dCo8B3wP7VPr/FGaKiG+T+v+TQqIrOqMTL1VdWV1DdmcbO8KXBz6esmYWYKPwDL5b5FA1a0hwapHiom0r/cKaoqr+27/XcrS5UwSMbQAAAABJRU5ErkJggg==)](https://deepwiki.com/sgl-project/SpecForge) + +[![github badge](https://img.shields.io/badge/๐Ÿ“ƒ%20LMSYS-Blog-black.svg?style=flat)](https://lmsys.org/blog/2025-07-25-spec-forge/) +[![slack badge](https://img.shields.io/badge/Slack-join-blueviolet?logo=slack&)](https://sgl-fru7574.slack.com/archives/C09784E3EN6) +[![license](https://img.shields.io/badge/License-MIT%202.0-blue)](./LICENSE) + +
+ +## ๐Ÿ“ Overview + +SpecForge is an ecosystem project developed by the SGLang team. It is a framework for training speculative decoding models so that you can smoothly port them over to the SGLang serving framework to speed up your inference. + +We have seen many open-source projects for speculative decoding, but most of them are not well-maintained or not directly compatible with SGLang. We prepared this project because we wish that the open-source community can enjoy a speculative decoding framework that is +- regularly maintained by the SpecForge team: the code is runnable out-of-the-box +- directly compatible with SGLang: there is no additional efforts for porting to SGLang +- provide performant training capabilities: we provided online/offline/tensor-parallel/FSDP to suit your needs + + +Check out [**our documentation**](https://docs.sglang.ai/SpecForge/) to get started. + + +## ๐Ÿš€ Accelerate with SpecBundle + +SpecBundle is a collection of production-grade speculative decoding models that are released by the SpecForge team and our industry partners. They provide higher acceptance rate compared to the existing open-source checkpoints over a wide range of domains. Together with SGLang, you can experience up to 4x speedup for inference. Check out our resources below: + + +| Item | Link | +| --- | --- | +| ๐Ÿ“ Documentation | [Link](https://docs.sglang.io/SpecForge/community_resources/specbundle.html) | +| ๐Ÿ“Š Performance Dashboard | [Link](https://docs.sglang.io/SpecForge/SpecBundle/index.html) | +| ๐Ÿค— Hugging Face Collection | [Link](https://huggingface.co/collections/lmsys/specbundle) | + + +## ๐ŸŽ‰ News + +- [2025-12] ๐ŸŽ‰ Released SpecBundle (phase 1) and SpecForge v0.2. Check out our blog at [LMSYS.org](https://lmsys.org/blog/2025-12-23-spec-bundle-phase-1/) +- [2025-12] ๐Ÿ”” Released the roadmap for 2026 Q1. +- [2025-08] ๐Ÿ”” SpecForge is listed as a [flagship project](https://lmsys.org/about/) in LMSYS. Congratulations to the SpecForge team! 
+- [2025-08] ๐Ÿ”ฅ SpecForge powered the Eagle3 draft model for GPT-OSS. Check out the blog at [LMSYS.org](https://lmsys.org/blog/2025-08-27-gpt-oss/) +- [2025-07] ๐Ÿ”ฅ SpecForge is released together with Llama4-Eagle3 checkpoints. Check out our blog at [LMSYS.org](https://lmsys.org/blog/2025-07-25-spec-forge/) + +## โœจ Acknowledgements + +acknowledgements + +We would like to express our sincere gratitude to the official EAGLE team, especially Hongyang Zhang and Yuhui Li, for their invaluable contributions and support. Our thanks also go to the NVIDIA teamโ€”particularly Avery H and Izzy Puttermanโ€”and to the Google team, especially Ying Wang, for their insightful discussions and generous assistance throughout the project. + +We are especially grateful to Meituan for their strong backing and meaningful contributions, which played a vital role in driving this project forward. + +This project has also been inspired by many outstanding open-source projects from the LLM community, including [EAGLE](https://github.com/SafeAILab/EAGLE), [BaldEagle](https://github.com/NickL77/BaldEagle), and [TensorRT-Model-Optimizer](https://github.com/NVIDIA/TensorRT-Model-Optimizer) and others. Their contributions and shared knowledge have greatly benefited our work. + +## ๐Ÿ’ก Special Thanks to Voltage Park + +We would like to extend our sincere thanks to [Voltage Park](https://www.voltagepark.com/), our official infrastructure partner. As part of a formal collaboration with the SGLang team, Voltage Park provided critical GPU resources that empowered us to train and evaluate large-scale speculative decoding models efficiently and reliably. This partnership was instrumental in making SpecForge possible. We deeply appreciate Voltage Parkโ€™s mission to make cutting-edge AI infrastructure more accessible, and we look forward to continued collaboration as we push the boundaries of open-source LLM serving and optimization. 
+ +## ๐Ÿ“ƒ Citation + +```bibtex +@misc{specforge2025, + title={SpecForge: Train speculative decoding models effortlessly}, + author={Shenggui Li, Yikai Zhu, Chao Wang, Fan Yin, Shuai Shi, Yubo Wang, Yi Zhang, Yingyi Huang, Haoshuai Zheng, Yineng Zhang}, + year={2025}, + publisher={GitHub}, + howpublished={\url{https://github.com/sgl-project/specforge}}, +} diff --git a/idea1/pyproject.toml b/idea1/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..4698a5e811892b633345a09ba8c261fd67a1b10d --- /dev/null +++ b/idea1/pyproject.toml @@ -0,0 +1,47 @@ +[build-system] +requires = ["setuptools>=61.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "specforge" +dynamic = ["version"] +readme = "README.md" +requires-python = ">=3.11" +description = "SpecForge: Speculative Decoding Training Framework" +authors = [{name = "SGLang Team"}] +urls = {Homepage = "https://github.com/sgl-project/SpecForge"} +dependencies = [ + "pre-commit", + "torch==2.9.1", + "torchaudio==2.9.1", + "torchvision==0.24.1", + "transformers==4.57.1", + "qwen-vl-utils==0.0.11", + "datasets", + "setuptools", + "tqdm", + "wandb", + "psutil", + "numpy", + "accelerate", + "pydantic", + "sglang==0.5.9", + "openai-harmony", + "ninja", + "packaging", + "yunchang", + "tensorboard", +] + +[tool.setuptools.packages.find] +exclude = ["configs*", "scripts*", "tests*"] + +[project.optional-dependencies] +dev = [ + "pre-commit", + "unittest" +] +fa = ["flash-attn"] + +[tool.setuptools.dynamic] +version = {file = "version.txt"} diff --git a/syxin/train_dflash_lora_inject.log b/syxin/train_dflash_lora_inject.log new file mode 100644 index 0000000000000000000000000000000000000000..4af7e198d9564fd7ae69e8eef1d792a3fdedcb10 --- /dev/null +++ b/syxin/train_dflash_lora_inject.log @@ -0,0 +1,105 @@ +nohup: ignoring input + +***************************************** +Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system 
being overloaded, please further tune the variable for optimal performance in your application as needed. +***************************************** +Set TORCH_CUDA_ARCH_LIST to 9.0 +/workspace/hanrui/syxin_old/Specforge/specforge/modeling/draft/llama3_eagle.py:29: UserWarning: flash_attn is not found, falling back to flex_attention. Please install flash_attn if you want to use the flash attention backend. + warnings.warn( +Set TORCH_CUDA_ARCH_LIST to 9.0 +/workspace/hanrui/syxin_old/Specforge/specforge/modeling/draft/llama3_eagle.py:29: UserWarning: flash_attn is not found, falling back to flex_attention. Please install flash_attn if you want to use the flash attention backend. + warnings.warn( +Set TORCH_CUDA_ARCH_LIST to 9.0 +/workspace/hanrui/syxin_old/Specforge/specforge/modeling/draft/llama3_eagle.py:29: UserWarning: flash_attn is not found, falling back to flex_attention. Please install flash_attn if you want to use the flash attention backend. + warnings.warn( +Set TORCH_CUDA_ARCH_LIST to 9.0 +Set TORCH_CUDA_ARCH_LIST to 9.0 +/workspace/hanrui/syxin_old/Specforge/specforge/modeling/draft/llama3_eagle.py:29: UserWarning: flash_attn is not found, falling back to flex_attention. Please install flash_attn if you want to use the flash attention backend. + warnings.warn( +Set TORCH_CUDA_ARCH_LIST to 9.0 +/workspace/hanrui/syxin_old/Specforge/specforge/modeling/draft/llama3_eagle.py:29: UserWarning: flash_attn is not found, falling back to flex_attention. Please install flash_attn if you want to use the flash attention backend. + warnings.warn( +/workspace/hanrui/syxin_old/Specforge/specforge/modeling/draft/llama3_eagle.py:29: UserWarning: flash_attn is not found, falling back to flex_attention. Please install flash_attn if you want to use the flash attention backend. 
+ warnings.warn( +Set TORCH_CUDA_ARCH_LIST to 9.0 +Set TORCH_CUDA_ARCH_LIST to 9.0 +/workspace/hanrui/syxin_old/Specforge/specforge/modeling/draft/llama3_eagle.py:29: UserWarning: flash_attn is not found, falling back to flex_attention. Please install flash_attn if you want to use the flash attention backend. + warnings.warn( +/workspace/hanrui/syxin_old/Specforge/specforge/modeling/draft/llama3_eagle.py:29: UserWarning: flash_attn is not found, falling back to flex_attention. Please install flash_attn if you want to use the flash attention backend. + warnings.warn( +`torch_dtype` is deprecated! Use `dtype` instead! +`torch_dtype` is deprecated! Use `dtype` instead! +`torch_dtype` is deprecated! Use `dtype` instead! +The following generation flags are not valid and may be ignored: ['output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details. +`torch_dtype` is deprecated! Use `dtype` instead! +The following generation flags are not valid and may be ignored: ['output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details. +The following generation flags are not valid and may be ignored: ['output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details. +The following generation flags are not valid and may be ignored: ['output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details. +`torch_dtype` is deprecated! Use `dtype` instead! +The following generation flags are not valid and may be ignored: ['output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details. +`torch_dtype` is deprecated! Use `dtype` instead! +`torch_dtype` is deprecated! Use `dtype` instead! +The following generation flags are not valid and may be ignored: ['output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details. +The following generation flags are not valid and may be ignored: ['output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details. +`torch_dtype` is deprecated! Use `dtype` instead! 
+The following generation flags are not valid and may be ignored: ['output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details. + Loading checkpoint shards: 0%| | 0/5 [00:00", line 198, in _run_module_as_main + File "", line 88, in _run_code + File "/workspace/hanrui/specforge/lib/python3.11/site-packages/torch/distributed/run.py", line 940, in + File "/workspace/hanrui/specforge/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 357, in wrapper + File "/workspace/hanrui/specforge/lib/python3.11/site-packages/torch/distributed/run.py", line 936, in main + File "/workspace/hanrui/specforge/lib/python3.11/site-packages/torch/distributed/run.py", line 927, in run + File "/workspace/hanrui/specforge/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 156, in __call__ + File "/workspace/hanrui/specforge/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 284, in launch_agent + File "/workspace/hanrui/specforge/lib/python3.11/site-packages/torch/distributed/elastic/metrics/api.py", line 138, in wrapper + File "/workspace/hanrui/specforge/lib/python3.11/site-packages/torch/distributed/elastic/agent/server/api.py", line 717, in run + File "/workspace/hanrui/specforge/lib/python3.11/site-packages/torch/distributed/elastic/agent/server/api.py", line 881, in _invoke_run + File "/workspace/hanrui/specforge/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 85, in _terminate_process_handler +torch.distributed.elastic.multiprocessing.api.SignalException: Process 62173 got signal: 15 diff --git a/test/eval_accepted_length.md b/test/eval_accepted_length.md new file mode 100644 index 0000000000000000000000000000000000000000..30902c1206bb3e4bc9d9b147bf8d57c288422b03 --- /dev/null +++ b/test/eval_accepted_length.md @@ -0,0 +1,327 @@ +# DFlash-LoRA ่ฏ„ๆต‹๏ผšAccepted Length & Accuracy + +ๅฎŒๆ•ดๆญฅ้ชค๏ผš็”จ SGLang **่ตฐๆœบๅ™จๅ†…็ฝ‘ `10.1.1.72`** 
ๅฏๅŠจๆœๅŠก๏ผŒๅœจ +**HumanEval / MT-Bench / GSM8K** ไธ‰ไธช bench ไธŠๆต‹่ฏ•่ฎญ็ปƒๅฅฝ็š„ +`qwen3-8b-sft-32gpu` checkpoint ็š„ **accepted length** ๅ’Œ **accuracy**ใ€‚ + +--- + +## ๅŸบๆœฌไฟกๆฏ + +| ้กน็›ฎ | ่ทฏๅพ„ / ๅ€ผ | +|---|---| +| conda ็Žฏๅขƒ | `sglang` | +| ๅŸบๅบงๆจกๅž‹๏ผˆtarget๏ผ‰ | `/workspace/models/Qwen3-8B` | +| ่ฎญ็ปƒ่พ“ๅ‡บ๏ผˆๆœ€็ปˆ ckpt๏ผ‰ | `/workspace/hanrui/syxin_old/Specforge/outputs/qwen3-8b-sft-32gpu/epoch_1_step_6000` | +| ๅˆๅนถๅŽ draft ๆจกๅž‹ | `/workspace/hanrui/syxin_old/Specforge/outputs/qwen3-8b-sft-32gpu-merged` | +| Benchmark ่„šๆœฌ็›ฎๅฝ• | `/workspace/hanrui/syxin_old/Specforge/benchmarks/` | +| ๆœฌๅœฐๆ•ฐๆฎ้›† | `/workspace/hanrui/datasets/{humaneval,mtbench,gsm8k}` | +| ็ป“ๆžœ่พ“ๅ‡บ็›ฎๅฝ• | `/workspace/hanrui/syxin_old/Specforge/benchmarks/results/` | +| **ๆœบๅ™จๅ†…็ฝ‘ IP** | **`10.1.1.72`**๏ผˆ`hostname -I` ็กฎ่ฎค๏ผ‰ | +| GPU | 8 ร— H100 80GB | + +--- + +## Step 1๏ผšๅˆๅนถ LoRA ๆƒ้‡ + +DFlash-LoRA ่ฎญ็ปƒๅชไฟๅญ˜ไบ† adapter ๆƒ้‡๏ผŒSGLang ็š„ STANDALONE ๆŠ•ๆœบ่งฃ็ ้œ€่ฆไธ€ไธช +**ๅฎŒๆ•ด็‹ฌ็ซ‹็š„ๆจกๅž‹ๆ–‡ไปถ**ไฝœไธบ draft model๏ผŒๆ‰€ไปฅๅ…ˆ mergeใ€‚ + +```bash +conda activate sglang +python3 /workspace/hanrui/syxin_old/merge_lora.py +``` + +> ่€—ๆ—ถ็บฆ 3โ€“5 ๅˆ†้’Ÿ๏ผŒCPU ๅ†…ๅญ˜ๅ ็”จ โ‰ˆ 16 GBใ€‚ๅทฒๅญ˜ๅœจๅˆ™่‡ชๅŠจ่ทณ่ฟ‡ใ€‚ + +--- + +## Step 2๏ผšๅฏๅŠจ SGLang Server๏ผˆๅ†…็ฝ‘ + STANDALONE ๆŠ•ๆœบ่งฃ็ ๏ผ‰ + +**ๅผ€ไธ€ไธชๆ–ฐ็ปˆ็ซฏ๏ผˆ็ปˆ็ซฏ A๏ผ‰**๏ผŒๆ‰ง่กŒไปฅไธ‹ๅ‘ฝไปคใ€‚Server ไผšไธ€็›ดๅœจๅ‰ๅฐ่ฟ่กŒ๏ผŒไธ่ฆๅ…ณใ€‚ + +```bash +conda activate sglang +bash /workspace/hanrui/syxin_old/start_server.sh 8 +``` + +> ้ป˜่ฎค tp=8๏ผŒ็”จๅ…จ้ƒจ 8 ๅผ  H100ใ€‚ๅฆ‚้œ€ tp=4 ๆ”นไธบ `start_server.sh 4`ใ€‚ + +### ๅ‚ๆ•ฐ่ฏดๆ˜Ž + +| ๅ‚ๆ•ฐ | ่ฏดๆ˜Ž | +|---|---| +| `--host 10.1.1.72` | **ๅฟ…้กป็ป‘ๅฎšๅ†…็ฝ‘ IP**๏ผŒไธ่ƒฝ็”จ `127.0.0.1` ๆˆ– `0.0.0.0` | +| `--speculative-algorithm STANDALONE` | ไฝฟ็”จ็‹ฌ็ซ‹ draft model ๅšๆŠ•ๆœบ่งฃ็ ๏ผŒๆ˜ฏๆต‹ accepted length ็š„ๅ…ณ้”ฎ | +| `--speculative-draft-model-path` | merge ๅŽ็š„ DFlash-LoRA ๆจกๅž‹๏ผˆdraft๏ผ‰๏ผŒไธŽ target 
ๅ…ฑ็”จๅŒไธ€ๆ‰น GPU | +| `--speculative-num-steps 4` | draft model ๆฏ่ฝฎ็”Ÿๆˆ 4 ไธชๅ€™้€‰ token๏ผˆๅฏ่ฐƒ 3โ€“8๏ผ‰ | +| `--speculative-eagle-topk 1` | ๆฏๆญฅๅชไฟ็•™ๆฆ‚็އๆœ€้ซ˜็š„ 1 ไธชๅ€™้€‰๏ผˆ่ดชๅฟƒ๏ผŒไฟ่ฏ accepted length ๆŒ‡ๆ ‡ๅ‡†็กฎ๏ผ‰ | +| `--speculative-num-draft-tokens 4` | ๆฏๆฌก้ชŒ่ฏ 4 ไธช draft token | +| `--tp-size 4` | 4 ่ทฏๅผ ้‡ๅนถ่กŒ๏ผŒtarget + draft ๅ…ฑไบซๅŒ 4 ๅผ  H100 | +| `--mem-fraction-static 0.80` | ๆฏๅก 80% ๆ˜พๅญ˜็”จไบŽ้™ๆ€ KV cache | + +### ้ชŒ่ฏ Server ๅฐฑ็ปช๏ผˆ็ปˆ็ซฏ B๏ผ‰ + +```bash +curl http://10.1.1.72:30000/v1/models +``` + +่ฟ”ๅ›žๅซๆจกๅž‹ๅ็š„ JSON ๅณ่กจ็คบๅฐฑ็ปช๏ผŒๅฏไปฅ็ปง็ปญ Step 3ใ€‚ + +--- + +## Step 3๏ผš่ฟ่กŒ Benchmark + +**ๅœจ็ปˆ็ซฏ B ไธญๆ‰ง่กŒ**๏ผˆไฟๆŒ็ปˆ็ซฏ A ็š„ server ่ฟ่กŒ๏ผ‰ใ€‚ + +### ไธ‰ไธช Bench ไธ€ๆฌกๆ€งๅ…จ่ท‘๏ผˆๆŽจ่๏ผ‰ + +```bash +conda activate sglang +bash /workspace/hanrui/syxin_old/run_bench.sh +``` + +### ๅ•็‹ฌ่ท‘ๆŸไธช Bench + +```bash +conda activate sglang +bash /workspace/hanrui/syxin_old/run_bench.sh humaneval # ๅช่ท‘ HumanEval +bash /workspace/hanrui/syxin_old/run_bench.sh mtbench # ๅช่ท‘ MT-Bench +bash /workspace/hanrui/syxin_old/run_bench.sh gsm8k # ๅช่ท‘ GSM8K +bash /workspace/hanrui/syxin_old/run_bench.sh humaneval gsm8k # ไปปๆ„็ป„ๅˆ +``` + +็ป“ๆžœๆ—ฅๅฟ—ๅ’Œ jsonl ๆ–‡ไปถไฟๅญ˜ๅœจ `/workspace/hanrui/syxin_old/Specforge/benchmarks/results/`ใ€‚ + +--- + +## Step 4๏ผˆๅฏ้€‰๏ผ‰๏ผšๅฏนๆฏ” baseline๏ผˆๅŽŸๅง‹ Qwen3-8B๏ผŒๆ—  LoRA๏ผ‰ + +ๅ…ณๆމ Step 2 ็š„ server๏ผŒๆขไธ€ไธชๆ›ด็ฎ€ๅ•็š„ baseline server๏ผŒ็”จไบŽๅฏนๆฏ”ๆฒกๆœ‰ DFlash-LoRA ๆ—ถ็š„ accepted length๏ผš + +```bash +# ็ปˆ็ซฏ A๏ผšๅฏๅŠจ baseline server๏ผˆๆ— ๆŠ•ๆœบ่งฃ็ ๏ผ‰ +conda activate sglang + +python3 -m sglang.launch_server \ + --model-path /workspace/models/Qwen3-8B \ + --tp-size 4 \ + --mem-fraction-static 0.85 \ + --trust-remote-code \ + --host 10.1.1.72 \ + --port 30000 \ + --dtype bfloat16 +``` + +```bash +# ็ปˆ็ซฏ B๏ผš่ท‘ baseline bench +python3 bench_eagle3.py \ + --model-path $BASE_MODEL \ + --host $INTRANET_IP \ + --port 
$PORT \ + --config-list "1,0,0,0" \ + --benchmark-list "humaneval:164" "mtbench:80" "gsm8k:1319" \ + --output-dir $RESULT_DIR \ + --name baseline_qwen3_8b \ + --skip-launch-server +``` + +> `"1,0,0,0"` = batch 1๏ผŒๆ— ๆŠ•ๆœบ่งฃ็ ๏ผˆsteps=0๏ผ‰๏ผŒ`accept_length` ๅ›บๅฎšไธบ 1.0๏ผŒ +> ๅฏ็”จไบŽๅฏนๆฏ” accuracy ๆ˜ฏๅฆๅ›  LoRA ่ฎญ็ปƒ่€Œไธ‹้™ใ€‚ + +--- + +## ็ป“ๆžœๆ–‡ไปถ่ฏดๆ˜Ž + +็ป“ๆžœไฟๅญ˜ๅœจ `$RESULT_DIR/` ไธ‹๏ผŒๆ–‡ไปถๅ็คบไพ‹๏ผš +``` +dflash_lora_all_results_20260307_123456.jsonl +``` + +ๅ…ณ้”ฎๅญ—ๆฎต๏ผš + +```json +{ + "humaneval": [{ + "batch_size": 1, "steps": 4, "topk": 1, "num_draft_tokens": 4, + "metrics": [{ + "latency": 45.2, + "output_throughput": 312.5, + "accept_length": 2.73, โ† ๆŠ•ๆœบ่งฃ็ ๅนณๅ‡ๆŽฅๅ—้•ฟๅบฆ๏ผˆ่ถŠ้ซ˜่ถŠๅฅฝ๏ผŒ1.0=ๆ— ๆ•ˆ๏ผ‰ + "accuracy": 0.756, โ† pass@1๏ผˆHumanEval๏ผ‰/ ๆ•ฐๅ€ผๅ‡†็กฎ็އ๏ผˆGSM8K๏ผ‰/ null๏ผˆMTBench๏ผ‰ + "num_questions": 164 + }] + }], + "mtbench": [ ... ], + "gsm8k": [ ... ] +} +``` + +| ๅญ—ๆฎต | ๅซไน‰ | +|---|---| +| `accept_length` | ๅนณๅ‡ๆฏๆฌก verify ๆŽฅๅ—็š„ token ๆ•ฐใ€‚`> 1.0` ่ฏดๆ˜Ž draft model ๆœ‰ๆ•ˆ๏ผŒ่ถŠ้ซ˜่ถŠๅฅฝ | +| `accuracy` | HumanEval: pass@1๏ผ›GSM8K: ๆ•ฐๅ€ผ็ญ”ๆกˆๅ‡†็กฎ็އ๏ผ›MT-Bench: `null` | +| `output_throughput` | tokens/s๏ผˆๅซๆŠ•ๆœบๅŠ ้€Ÿ๏ผ‰ | +| `latency` | ๆ•ดไธช bench ๆ€ป่€—ๆ—ถ๏ผˆ็ง’๏ผ‰ | + +--- + +## ไธ€้”ฎ่„šๆœฌ๏ผˆmerge + server + bench + ๅ…ณserver๏ผ‰ + +ๅฐ†ไปฅไธ‹ๅ†…ๅฎนไฟๅญ˜ไธบ `/workspace/hanrui/syxin_old/run_eval.sh`๏ผš + +```bash +#!/bin/bash +set -e + +# ===== ้…็ฝฎ ===== +INTRANET_IP=10.1.1.72 +PORT=30000 +BASE_MODEL=/workspace/models/Qwen3-8B +CKPT=epoch_1_step_6000 +ADAPTER=/workspace/hanrui/syxin_old/Specforge/outputs/qwen3-8b-sft-32gpu/${CKPT} +MERGED=/workspace/hanrui/syxin_old/Specforge/outputs/qwen3-8b-sft-32gpu-merged +BENCH_DIR=/workspace/hanrui/syxin_old/Specforge/benchmarks +RESULT_DIR=$BENCH_DIR/results +TP=4 +# ================ + +conda activate sglang +export PYTHONPATH=/workspace/hanrui/syxin_old/Specforge:$PYTHONPATH +mkdir -p $RESULT_DIR + +# ---- Step 1: merge LoRA ---- 
+if [ ! -d "$MERGED" ]; then + echo ">>> Merging LoRA ..." + python3 - <>> Merged model exists, skip merge." +fi + +# ---- Step 2: launch server ---- +echo ">>> Starting SGLang server on $INTRANET_IP:$PORT ..." +python3 -m sglang.launch_server \ + --model-path $BASE_MODEL \ + --speculative-algorithm STANDALONE \ + --speculative-draft-model-path $MERGED \ + --speculative-num-steps 4 \ + --speculative-eagle-topk 1 \ + --speculative-num-draft-tokens 4 \ + --tp-size $TP \ + --mem-fraction-static 0.80 \ + --trust-remote-code \ + --host $INTRANET_IP \ + --port $PORT \ + --dtype bfloat16 \ + 2>&1 | tee $RESULT_DIR/server.log & +SERVER_PID=$! + +echo ">>> Waiting for server (up to 120s) ..." +for i in $(seq 1 24); do + curl -s http://$INTRANET_IP:$PORT/v1/models > /dev/null 2>&1 && { echo ">>> Server ready!"; break; } + sleep 5 +done + +# ---- Step 3: benchmarks ---- +cd $BENCH_DIR + +echo ">>> HumanEval ..." +python3 bench_eagle3.py \ + --model-path $BASE_MODEL \ + --speculative-draft-model-path $MERGED \ + --host $INTRANET_IP --port $PORT \ + --config-list "1,4,1,4" \ + --benchmark-list "humaneval:164" \ + --output-dir $RESULT_DIR --name ${CKPT}_humaneval \ + --skip-launch-server 2>&1 | tee $RESULT_DIR/humaneval.log + +echo ">>> MT-Bench ..." +python3 bench_eagle3.py \ + --model-path $BASE_MODEL \ + --speculative-draft-model-path $MERGED \ + --host $INTRANET_IP --port $PORT \ + --config-list "1,4,1,4" \ + --benchmark-list "mtbench:80" \ + --output-dir $RESULT_DIR --name ${CKPT}_mtbench \ + --skip-launch-server 2>&1 | tee $RESULT_DIR/mtbench.log + +echo ">>> GSM8K ..." +python3 bench_eagle3.py \ + --model-path $BASE_MODEL \ + --speculative-draft-model-path $MERGED \ + --host $INTRANET_IP --port $PORT \ + --config-list "1,4,1,4" \ + --benchmark-list "gsm8k:1319" \ + --output-dir $RESULT_DIR --name ${CKPT}_gsm8k \ + --skip-launch-server 2>&1 | tee $RESULT_DIR/gsm8k.log + +# ---- Step 4: shutdown ---- +echo ">>> Shutting down server (PID $SERVER_PID) ..." 
+kill $SERVER_PID 2>/dev/null || true +wait $SERVER_PID 2>/dev/null || true +echo ">>> All done. Results in $RESULT_DIR" +ls -lh $RESULT_DIR/*.jsonl 2>/dev/null +``` + +่ฟ่กŒ๏ผš + +```bash +chmod +x /workspace/hanrui/syxin_old/run_eval.sh +bash /workspace/hanrui/syxin_old/run_eval.sh 2>&1 | tee /workspace/hanrui/syxin_old/eval.log +``` + +--- + +## ๅธธ่ง้—ฎ้ข˜ + +### Q1๏ผšaccept_length ๅง‹็ปˆๆ˜ฏ 1.0 + +Server ๆฒกๆœ‰ๅผ€ๅฏๆŠ•ๆœบ่งฃ็ ใ€‚็กฎ่ฎค server ๅฏๅŠจๆ—ถๆœ‰ `--speculative-algorithm STANDALONE`๏ผŒ +ไธ” `--speculative-draft-model-path` ๆŒ‡ๅ‘ **merge ๅŽ็š„ๅฎŒๆ•ดๆจกๅž‹**๏ผˆไธๆ˜ฏ adapter ็›ฎๅฝ•๏ผ‰ใ€‚ + +### Q2๏ผšConnection refused / ่ฟžๆŽฅ่ถ…ๆ—ถ + +- ็กฎ่ฎค server `--host` ๆ˜ฏ `10.1.1.72`๏ผˆไธๆ˜ฏ `127.0.0.1` ๆˆ– `0.0.0.0`๏ผ‰ +- bench ๅ‘ฝไปค้‡Œ `--host` ไนŸๆ˜ฏ `10.1.1.72` +- `bench_eagle3.py` ๅทฒไฟฎๅค `base_url = f"http://{args.host}:{args.port}"`๏ผˆๅŽŸๆฅ็กฌ็ผ–็  `localhost`๏ผ‰ + +### Q3๏ผšๆ•ฐๆฎ้›†ไธ‹่ฝฝๅคฑ่ดฅ๏ผˆๆ— ๅค–็ฝ‘๏ผ‰ + +ไธ‰ไธช benchmarker ๅทฒๆ”นไธบไผ˜ๅ…ˆ่ฏปๆœฌๅœฐๆ–‡ไปถ๏ผš + +| bench | ๆœฌๅœฐๆ–‡ไปถ | +|---|---| +| GSM8K | `/workspace/hanrui/datasets/gsm8k/test.jsonl` | +| MT-Bench | `/workspace/hanrui/datasets/mtbench/question.jsonl` | +| HumanEval | `/workspace/hanrui/datasets/humaneval/test.jsonl` | + +### Q4๏ผšOOM + +- ๅ‡ๅฐ `--mem-fraction-static`๏ผˆ่ฏ• `0.70`๏ผ‰ +- ๅ‡ๅฐ `--tp-size`๏ผˆ่ฏ• `2`๏ผŒไฝ†ๆ›ดๆ…ข๏ผ‰ +- ๅ‡ๅฐ‘ `--speculative-num-steps`๏ผˆ่ฏ• `3`๏ผ‰ + +### Q5๏ผšๅฆ‚ไฝ•ๆต‹ๅ…ถไป– checkpoint + +ไฟฎๆ”น `CKPT` ๅ˜้‡๏ผŒ้‡ๆ–ฐ merge๏ผˆไฟๅญ˜ๅˆฐไธๅŒ็›ฎๅฝ•๏ผ‰๏ผš + +```bash +CKPT=epoch_2_step_15000 +ADAPTER=/workspace/hanrui/syxin_old/Specforge/outputs/qwen3-8b-sft-32gpu/${CKPT} +MERGED=/workspace/hanrui/syxin_old/Specforge/outputs/qwen3-8b-sft-32gpu-merged-${CKPT} +# ้‡ๆ–ฐ merge ๅŽ้‡ๅฏ server ๅณๅฏ +``` + +--- + +*ๅ†…็ฝ‘ IP๏ผš`10.1.1.72` | ๅŸบๅบง๏ผš`/workspace/models/Qwen3-8B` | ๆœ€็ปˆ ckpt๏ผš`epoch_1_step_6000`*