---
language:
- en
license: llama3
library_name: transformers
tags:
- mathematics
datasets:
- hkust-nlp/dart-math-hard
metrics:
- accuracy
pipeline_tag: text-generation
base_model: meta-llama/Meta-Llama-3-8B
model-index:
- name: dart-math-llama3-8b-prop2diff
  results:
  - task:
      type: text-generation
      name: Mathematical Problem-Solving
    dataset:
      type: hendrycks/competition_math
      name: MATH
      split: test
    metrics:
    - type: accuracy
      name: Pass@1 (0-shot CoT)
      value: 46.6
  - task:
      type: text-generation
      name: Mathematical Problem-Solving
    dataset:
      type: openai/gsm8k
      name: GSM8K
      config: main
      split: test
    metrics:
    - type: accuracy
      name: Pass@1 (0-shot CoT)
      value: 81.1
  - task:
      type: text-generation
      name: Mathematical Problem-Solving
    dataset:
      type: college-math
      name: CollegeMath
    metrics:
    - type: accuracy
      name: Pass@1 (0-shot CoT)
      value: 28.8
  - task:
      type: text-generation
      name: Mathematical Problem-Solving
    dataset:
      type: deepmind-mathematics
      name: DeepMind-Mathematics
    metrics:
    - type: accuracy
      name: Pass@1 (0-shot CoT)
      value: 48.0
  - task:
      type: text-generation
      name: Mathematical Problem-Solving
    dataset:
      type: Hothan/OlympiadBench
      name: OlympiadBench-OE_TO_maths_en_COMP
      config: OE_TO_maths_en_COMP
      split: train
    metrics:
    - type: accuracy
      name: Pass@1 (0-shot CoT)
      value: 14.5
  - task:
      type: text-generation
      name: Mathematical Problem-Solving
    dataset:
      type: TIGER-Lab/TheoremQA
      name: TheoremQA
      split: test
    metrics:
    - type: accuracy
      name: Pass@1 (0-shot CoT)
      value: 19.4
---
# DART-Math: Difficulty-Aware Rejection Tuning for Mathematical Problem-Solving
📝 [Paper@arXiv](https://arxiv.org/abs/2407.13690) | 🤗 [Datasets&Models@HF](https://huggingface.co/collections/hkust-nlp/dart-math-665704599b35de59f8fdf6c1) | 🐱 [Code@GitHub](https://github.com/hkust-nlp/dart-math)
🐦 [Thread@X(Twitter)](https://x.com/tongyx361/status/1811413243350454455) | 🐶 [中文博客@知乎](https://zhuanlan.zhihu.com/p/708371895) | 📊 [Leaderboard@PapersWithCode](https://paperswithcode.com/paper/dart-math-difficulty-aware-rejection-tuning#results) | 📑 [BibTeX](https://github.com/hkust-nlp/dart-math?tab=readme-ov-file#citation)
## Models: `DART-Math`
`DART-Math` models achieve performance **superior or competitive to previous SOTAs** on 2 in-domain and 4 challenging out-of-domain mathematical reasoning benchmarks, despite using **much smaller datasets** and **no proprietary models such as GPT-4**.
| Model | [MATH](https://huggingface.co/datasets/hendrycks/competition_math) | [GSM8K](https://huggingface.co/datasets/gsm8k) | [College](https://github.com/hkust-nlp/dart-math/tree/main/data/eval-dsets/mwpbench/college-math-test.jsonl) | [DM](https://github.com/hkust-nlp/dart-math/tree/main/data/eval-dsets/deepmind-mathematics.json) | [Olympiad](https://github.com/hkust-nlp/dart-math/tree/main/data/eval-dsets/olympiadbench/OE_TO_maths_en_COMP.json) | [Theorem](https://github.com/hkust-nlp/dart-math/tree/main/data/eval-dsets/theoremqa.json) | AVG |
| :----------------------------------------------------------------------------------------------------- | -----------------------------------------------------------------: | ---------------------------------------------: | -----------------------------------------------------------------------------------------------------------: | -----------------------------------------------------------------------------------------------: | ------------------------------------------------------------------------------------------------------------------: | -----------------------------------------------------------------------------------------: | -------: |
| GPT-4 (0314) | [52.6](https://arxiv.org/abs/2403.04706) | [94.7](https://arxiv.org/abs/2403.04706) | [24.4](https://arxiv.org/abs/2403.02884) | -- | -- | -- | -- |
| Llama-3-70B-MetaMath | 44.9 | 88.0 | 31.9 | 53.2 | 11.6 | 21.9 | 41.9 |
| [`DART-Math-Llama-3-70B` (Uniform)](https://huggingface.co/hkust-nlp/dart-math-llama3-70b-uniform) | 54.9 | **90.4** | **38.5** | **64.1** | 19.1 | 27.4 | 49.1 |
| [`DART-Math-Llama-3-70B` (Prop2Diff)](https://huggingface.co/hkust-nlp/dart-math-llama3-70b-prop2diff) | **56.1** | 89.6 | 37.9 | **64.1** | **20.0** | **28.2** | **49.3** |
| DeepSeekMath-7B-MetaMath | 43.7 | 81.8 | 33.7 | 53.0 | 13.6 | 23.2 | 41.5 |
| [DeepSeekMath-7B-RL](https://huggingface.co/deepseek-ai/deepseek-math-7b-rl) | 53.1 | 88.4 | 41.3 | 58.3 | 18.7 | 35.9 | 49.3 |
| [`DART-Math-DSMath-7B` (Uniform)](https://huggingface.co/hkust-nlp/dart-math-dsmath-7b-uniform) | 52.9 | **88.2** | 40.1 | 60.2 | 21.3 | **32.5** | 49.2 |
| [`DART-Math-DSMath-7B` (Prop2Diff)](https://huggingface.co/hkust-nlp/dart-math-dsmath-7b-prop2diff) | **53.6** | 86.8 | **40.7** | **61.6** | **21.7** | 32.2 | **49.4** |
| Mistral-7B-MetaMath | 29.8 | 76.5 | 19.3 | 28.0 | 5.9 | 14.0 | 28.9 |
| [`DART-Math-Mistral-7B` (Uniform)](https://huggingface.co/hkust-nlp/dart-math-mistral-7b-uniform) | 43.5 | **82.6** | 26.9 | 42.0 | 13.2 | 16.4 | 37.4 |
| [`DART-Math-Mistral-7B` (Prop2Diff)](https://huggingface.co/hkust-nlp/dart-math-mistral-7b-prop2diff) | **45.5** | 81.1 | **29.4** | **45.1** | **14.7** | **17.0** | **38.8** |
| Llama-3-8B-MetaMath | 32.5 | 77.3 | 20.6 | 35.0 | 5.5 | 13.8 | 30.8 |
| [`DART-Math-Llama-3-8B` (Uniform)](https://huggingface.co/hkust-nlp/dart-math-llama3-8b-uniform) | 45.3 | **82.5** | 27.1 | **48.2** | 13.6 | 15.4 | 38.7 |
| [`DART-Math-Llama-3-8B` (Prop2Diff)](https://huggingface.co/hkust-nlp/dart-math-llama3-8b-prop2diff) | **46.6** | 81.1 | **28.8** | 48.0 | **14.5** | **19.4** | **39.7** |
***Abbreviations**: College (CollegeMath), DM (DeepMind Mathematics), Olympiad (OlympiadBench-Math), Theorem (TheoremQA).
**Bold** marks the best score among SFT models built on the same base model.
To reproduce our results, please refer to [the `DART-Math` GitHub repository](https://github.com/hkust-nlp/dart-math).*
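
The AVG column appears to be the unweighted mean of the six benchmark scores in each row; this is inferred from the numbers rather than stated in the card. A minimal sketch that recomputes it for two rows copied from the table:

```python
# Scores copied from the table, in column order: MATH, GSM8K, College, DM, Olympiad, Theorem.
ROWS = {
    "DART-Math-Llama-3-8B (Prop2Diff)": [46.6, 81.1, 28.8, 48.0, 14.5, 19.4],
    "DART-Math-Mistral-7B (Uniform)": [43.5, 82.6, 26.9, 42.0, 13.2, 16.4],
}


def benchmark_average(scores: list[float]) -> float:
    """Unweighted mean over the benchmarks, rounded to one decimal like the table."""
    return round(sum(scores) / len(scores), 1)


if __name__ == "__main__":
    for name, scores in ROWS.items():
        print(f"{name}: {benchmark_average(scores)}")
```

Because every benchmark gets equal weight, a gain on a small, hard set such as OlympiadBench moves AVG as much as the same gain on GSM8K.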
## Prompt Template
All the `DART-Math` models use the [Alpaca](https://github.com/tatsu-lab/stanford_alpaca) prompt template:
```
Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{query}\n\n### Response:\n
```
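
The template is a plain string with a single `{query}` slot. A minimal sketch of filling it in Python, assuming the standard Alpaca spacing (`### Instruction:`, `### Response:`); `build_prompt` and `extract_response` are hypothetical helpers, not part of the `DART-Math` code base:

```python
# Alpaca-style template; assumes the standard Alpaca header spacing.
ALPACA_TEMPLATE = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{query}\n\n### Response:\n"
)


def build_prompt(query: str) -> str:
    """Insert one math question into the template (braces inside `query` are kept verbatim)."""
    return ALPACA_TEMPLATE.format(query=query)


def extract_response(generated_text: str) -> str:
    """Hypothetical helper: keep only the text after the last '### Response:' header."""
    return generated_text.split("### Response:\n")[-1].strip()


if __name__ == "__main__":
    print(build_prompt("What is 12 * 7?"))
```

The resulting string can be passed as-is to any text-generation backend (for example a `transformers` pipeline or vLLM); the model is expected to continue right after `### Response:\n`.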
## Training Dataset
We construct our training datasets by applying **Difficulty-Aware Rejection Sampling** (`DARS`) to the **MATH and GSM8K** training sets.
`DARS` tackles the **severe bias towards easy queries** in previous datasets, which **frequently fail to include any correct response for the most challenging queries**.
This bias is caused primarily by vanilla rejection sampling, where **the same number of responses is sampled for each query**, even though the likelihood of obtaining a correct response is significantly lower for difficult queries, sometimes even zero.
Please refer to [`DART-Math-Hard`](https://huggingface.co/datasets/hkust-nlp/dart-math-hard) / [`DART-Math-Uniform`](https://huggingface.co/datasets/hkust-nlp/dart-math-uniform) for more details.
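
The contrast between vanilla rejection sampling and `DARS` can be sketched as a toy simulation. This is not the released pipeline: `make_simulated_sampler` is a hypothetical stand-in for "generate one response and check its final answer", and the Prop2Diff allocation rule below (correct-response count proportional to a query's fail rate, with the hardest query getting `k` and every query at least one) is an assumed simplification. See the dataset cards and the paper for the exact procedure.

```python
import random
from typing import Callable, Dict, List

# Hypothetical stand-in: returns True if one sampled response to query q is correct.
Sampler = Callable[[str], bool]


def make_simulated_sampler(success_prob: Dict[str, float], seed: int = 0) -> Sampler:
    """Simulate a policy whose responses to q are correct with probability success_prob[q]."""
    rng = random.Random(seed)
    return lambda q: rng.random() < success_prob[q]


def vanilla_rejection_sampling(queries: List[str], sample: Sampler, n: int) -> Dict[str, int]:
    """Same number of samples n per query; return how many correct responses survive."""
    return {q: sum(sample(q) for _ in range(n)) for q in queries}


def difficulty_aware_sampling(
    queries: List[str], sample: Sampler, target: Dict[str, int], max_attempts: int
) -> Dict[str, int]:
    """Keep sampling each query until target[q] correct responses are kept or the budget runs out."""
    kept = {}
    for q in queries:
        correct, attempts = 0, 0
        while correct < target[q] and attempts < max_attempts:
            attempts += 1
            correct += sample(q)
        kept[q] = correct
    return kept


def uniform_targets(queries: List[str], k: int) -> Dict[str, int]:
    """Uniform: the same number of correct responses for every query."""
    return {q: k for q in queries}


def prop2diff_targets(fail_rates: Dict[str, float], k: int) -> Dict[str, int]:
    """Assumed Prop2Diff form: counts proportional to fail rate, hardest gets k, at least 1 each."""
    hardest = max(fail_rates.values()) or 1.0  # avoid dividing by zero if every query is solved
    return {q: max(1, round(k * f / hardest)) for q, f in fail_rates.items()}


if __name__ == "__main__":
    queries = ["easy", "hard"]
    sampler = make_simulated_sampler({"easy": 0.9, "hard": 0.05}, seed=0)
    print(vanilla_rejection_sampling(queries, sampler, n=4))
    print(difficulty_aware_sampling(queries, sampler, uniform_targets(queries, 4), max_attempts=10_000))
    print(prop2diff_targets({"easy": 0.1, "hard": 0.95}, k=8))
```

With a fixed budget of 4 samples, the hard query often ends up with no correct response at all, whereas the difficulty-aware loop keeps sampling it until its quota is met.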
## Training Setup
We perform standard instruction tuning of several base models, including Llama3-8B, Mistral-7B, and Llama3-70B as representatives of general models and DeepSeekMath-7B as the representative of math-specialized models,
on our synthetic datasets [`DART-Math-Hard`](https://huggingface.co/datasets/hkust-nlp/dart-math-hard) and [`DART-Math-Uniform`](https://huggingface.co/datasets/hkust-nlp/dart-math-uniform),
yielding `DART-Math (Prop2Diff)` and `DART-Math (Uniform)`, respectively.
For simplicity, we keep most hyper-parameters the same across different models and datasets:
- Model max length (of [packed](https://github.com/MeetKai/functionary/tree/main/functionary/train/packing) sequence): 4096
- Batch size: 64
- Warm-up ratio: 0.03
- Learning rate scheduler: cosine
- Prompt template: [Alpaca](https://github.com/tatsu-lab/stanford_alpaca)
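
The warm-up ratio and cosine scheduler correspond to the common "linear warm-up, then cosine decay" schedule (the shape produced by `transformers`' `get_cosine_schedule_with_warmup`). A minimal pure-Python sketch, assuming decay to zero at the final step and a warm-up length rounded to the nearest step; the total step count in the usage lines is hypothetical:

```python
import math


def lr_at_step(step: int, total_steps: int, max_lr: float, warmup_ratio: float = 0.03) -> float:
    """Linear warm-up to max_lr over the first warmup_ratio of steps, then cosine decay to 0."""
    warmup_steps = max(1, round(total_steps * warmup_ratio))
    if step < warmup_steps:
        return max_lr * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return max_lr * 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))


if __name__ == "__main__":
    total = 1000  # hypothetical number of optimizer steps
    for s in (0, 15, 30, 500, 1000):
        print(s, lr_at_step(s, total, max_lr=5e-5))
```

The "Max. L.R." values in the table below are the peak reached at the end of warm-up.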
Several other key hyper-parameters are tuned as follows:
| Base Model | Max. L.R. | # of Epochs | # of Grad. Acc. Steps | # of A100 GPUs |
|:--------------- | ---------:| -----------:| ---------------------:| --------------:|
| Mistral-7B | `1e-5` | 3 | 1 | 8 |
| Llama3-8B | `5e-5` | 1 | 2 | 8 |
| Llama3-70B | `2e-5` | 1 | 1 | 32 |
| DeepSeekMath-7B | `5e-5` | 3 | 1 | 8 |
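
The card does not list per-device batch sizes. Assuming that the batch size of 64 is the global batch and that training is purely data-parallel (global batch = per-device batch × grad. acc. steps × GPUs), they can be derived from the table as in this sketch; `per_device_batch_size` is a hypothetical helper:

```python
GLOBAL_BATCH_SIZE = 64  # shared hyper-parameter listed above

# (gradient accumulation steps, number of A100 GPUs), copied from the table.
SETUPS = {
    "Mistral-7B": (1, 8),
    "Llama3-8B": (2, 8),
    "Llama3-70B": (1, 32),
    "DeepSeekMath-7B": (1, 8),
}


def per_device_batch_size(global_batch: int, grad_acc_steps: int, num_gpus: int) -> int:
    """Assumes global batch = per-device batch x grad. acc. steps x GPUs."""
    denom = grad_acc_steps * num_gpus
    if global_batch % denom != 0:
        raise ValueError(f"global batch {global_batch} is not divisible by {denom}")
    return global_batch // denom


if __name__ == "__main__":
    for model, (acc, gpus) in SETUPS.items():
        print(model, per_device_batch_size(GLOBAL_BATCH_SIZE, acc, gpus))
```

Under this assumption, each step of Llama3-70B sees 2 packed sequences per GPU across 32 GPUs, while the 7B/8B models use 4 to 8 per GPU on 8 GPUs.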
- For the **maximum learning rate**, we determine the value by **searching** over `1e-6,5e-6,1e-5,2e-5,5e-5,1e-4` according to MATH performance after training on MMIQC for 1 epoch. The exception is Llama3-70B, for which the search is too expensive; we instead derive its learning rate from Llama3-8B's, by analogy with the ratio of pre-training learning rates between [Llama2-7B](https://huggingface.co/meta-llama/Llama-2-7b-hf) and [Llama2-70B](https://huggingface.co/meta-llama/Llama-2-70b-hf) (\~2:1).
- For **Llama3** models, preliminary experiments indicate that **training for 1 epoch consistently outperforms 3 epochs**.
Please refer to [Appendix A.1 of our paper](https://tongyx361.github.io/assets/dart-math/paper-dart-math.pdf) for more details.
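
The learning-rate selection can be sketched as follows. `pick_best_lr` simply takes the grid value with the highest MATH accuracy from the 1-epoch MMIQC probe. For Llama3-70B, dividing Llama3-8B's `5e-5` by \~2 gives `2.5e-5`; snapping that to the nearest grid value yields the `2e-5` in the table, but this snapping step is only one reading consistent with the table, not a documented part of the procedure. Both helpers are hypothetical:

```python
# Grid searched for the maximum learning rate.
LR_GRID = [1e-6, 5e-6, 1e-5, 2e-5, 5e-5, 1e-4]


def pick_best_lr(math_accuracy_by_lr: dict[float, float]) -> float:
    """Return the learning rate whose probe run scored highest on MATH."""
    return max(math_accuracy_by_lr, key=math_accuracy_by_lr.get)


def derive_lr_by_ratio(small_model_lr: float, ratio: float = 2.0, grid: list[float] = LR_GRID) -> float:
    """Divide the smaller model's LR by the assumed ratio, then snap to the closest grid value."""
    target = small_model_lr / ratio
    return min(grid, key=lambda lr: abs(lr - target))


if __name__ == "__main__":
    print(derive_lr_by_ratio(5e-5))  # Llama3-8B -> Llama3-70B under the ~2:1 assumption
```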
## Other Details
- For Mistral-7B-based models, we disable `sliding_window` by default, following [the newest Mistral-7B-Instruct](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3/blob/main/config.json), because Flash Attention 2 does not support `sliding_window` and the xFormers backend in vLLM showed \~10% lower throughput in our experiments.
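
If a locally downloaded Mistral-7B-based checkpoint still sets a window size in its `config.json`, one way to mirror this default is to set the field to `null`. A minimal standard-library sketch, assuming the model files are already in a local directory; `disable_sliding_window` is a hypothetical helper. With `transformers`, setting `sliding_window = None` on the loaded config before instantiating the model should have the same effect.

```python
import json
from pathlib import Path


def disable_sliding_window(config_path: str) -> dict:
    """Set `sliding_window` to null in a local Hugging Face config.json and write it back."""
    path = Path(config_path)
    config = json.loads(path.read_text())
    config["sliding_window"] = None  # serialized as JSON null
    path.write_text(json.dumps(config, indent=2))
    return config
```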
## Citation
If you find our data, model or code useful for your work, please kindly cite [our paper](https://arxiv.org/abs/2407.13690):
```bibtex
@article{tong2024dartmath,
title={DART-Math: Difficulty-Aware Rejection Tuning for Mathematical Problem-Solving},
author={Yuxuan Tong and Xiwen Zhang and Rui Wang and Ruidong Wu and Junxian He},
year={2024},
eprint={2407.13690},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2407.13690},
}
```