# DART-Math: Difficulty-Aware Rejection Tuning for Mathematical Problem-Solving

📝 Paper@arXiv | 🤗 Datasets&Models@HF | 🐱 Code@GitHub

🐦 Thread@X(Twitter) | 🐶 Chinese Blog@Zhihu | 📊 Leaderboard@PapersWithCode | 📑 BibTeX

## Models: `DART-Math`

`DART-Math` models achieve performance **superior to or competitive with previous SOTAs** on 2 in-domain and 4 challenging out-of-domain mathematical reasoning benchmarks, despite using **much smaller datasets** and **no proprietary model like GPT-4**.

Model | MATH | GSM8K | College | DM | Olympiad | Theorem | AVG |
---|---|---|---|---|---|---|---|
GPT-4 (0314) | 52.6 | 94.7 | 24.4 | -- | -- | -- | -- |
Llama-3-70B-MetaMath | 44.9 | 88.0 | 31.9 | 53.2 | 11.6 | 21.9 | 41.9 |
`DART-Math-Llama-3-70B` (Uniform) | 54.9 | **90.4** | **38.5** | **64.1** | 19.1 | 27.4 | 49.1 |
`DART-Math-Llama-3-70B` (Prop2Diff) | **56.1** | 89.6 | 37.9 | **64.1** | **20.0** | **28.2** | **49.3** |
DeepSeekMath-7B-MetaMath | 43.7 | 81.8 | 33.7 | 53.0 | 13.6 | 23.2 | 41.5 |
DeepSeekMath-7B-RL | 53.1 | 88.4 | 41.3 | 58.3 | 18.7 | 35.9 | 49.3 |
`DART-Math-DSMath-7B` (Uniform) | 52.9 | **88.2** | 40.1 | 60.2 | 21.3 | **32.5** | 49.2 |
`DART-Math-DSMath-7B` (Prop2Diff) | **53.6** | 86.8 | **40.7** | **61.6** | **21.7** | 32.2 | **49.4** |
Mistral-7B-MetaMath | 29.8 | 76.5 | 19.3 | 28.0 | 5.9 | 14.0 | 28.9 |
`DART-Math-Mistral-7B` (Uniform) | 43.5 | **82.6** | 26.9 | 42.0 | 13.2 | 16.4 | 37.4 |
`DART-Math-Mistral-7B` (Prop2Diff) | **45.5** | 81.1 | **29.4** | **45.1** | **14.7** | **17.0** | **38.8** |
Llama-3-8B-MetaMath | 32.5 | 77.3 | 20.6 | 35.0 | 5.5 | 13.8 | 30.8 |
`DART-Math-Llama-3-8B` (Uniform) | 45.3 | **82.5** | 27.1 | **48.2** | 13.6 | 15.4 | 38.7 |
`DART-Math-Llama-3-8B` (Prop2Diff) | **46.6** | 81.1 | **28.8** | 48.0 | **14.5** | **19.4** | **39.7** |

**Abbreviations**: College (CollegeMath), DM (DeepMind Mathematics), Olympiad (OlympiadBench-Math), Theorem (TheoremQA).
**Bold** means the best score by SFT on the respective base model here.
To reproduce our results, please refer to the `DART-Math` GitHub repository.

## Prompt Template

All the `DART-Math` models use the Alpaca prompt template:

```
Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n###Instruction:\n{query}\n\n### Response:\n
```
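As a minimal sketch, the template above can be filled in with plain string formatting before the prompt is passed to whichever inference stack you use. The names `ALPACA_TEMPLATE` and `build_prompt` are hypothetical and not part of the `DART-Math` codebase; the template string itself is copied from the block above, including its exact spacing.

```python
# Template copied verbatim from the model card; `{query}` is the only placeholder.
ALPACA_TEMPLATE = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "###Instruction:\n{query}\n\n### Response:\n"
)


def build_prompt(query: str) -> str:
    """Wrap a raw math question in the Alpaca-style prompt (hypothetical helper)."""
    return ALPACA_TEMPLATE.format(query=query)


if __name__ == "__main__":
    print(build_prompt("What is the sum of the first 10 positive integers?"))
```

The model's answer is expected to continue directly after the trailing `### Response:\n`, so the prompt should not be stripped of its final newline.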

## Training Dataset

We construct our training datasets by applying **Difficulty-Aware Rejection Sampling** (`DARS`) to the **MATH and GSM8K** training sets.

`DARS` tackles **severe biases towards easy queries, with frequent failures to generate any correct response for the most challenging queries**, in previous datasets.

These biases are primarily caused by vanilla rejection sampling, where **the same number of responses is sampled for each query**, yet the likelihood of obtaining correct responses for difficult queries is significantly lower, sometimes even zero.

Please refer to `DART-Math-Hard` / `DART-Math-Uniform` for more details.
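The bias described above can be illustrated with a small simulation. This is only a sketch under stated assumptions: each query is reduced to a hypothetical per-sample pass rate, correctness checking is replaced by a random draw, and `uniform_target_sampling` is one simple way to make sampling difficulty-aware (keep sampling until a fixed number of correct responses is collected), not necessarily the exact `DARS` procedure. All function and variable names are illustrative.

```python
import random


def vanilla_rejection_sampling(pass_rates, n_samples, rng):
    """Sample the same number of responses for every query; count the correct ones."""
    return {
        query: sum(rng.random() < p for _ in range(n_samples))
        for query, p in pass_rates.items()
    }


def uniform_target_sampling(pass_rates, n_correct, max_samples, rng):
    """Keep sampling each query until `n_correct` correct responses are
    collected, or until the per-query budget `max_samples` is exhausted."""
    kept = {}
    for query, p in pass_rates.items():
        correct = 0
        for _ in range(max_samples):
            if correct >= n_correct:
                break
            if rng.random() < p:  # stand-in for checking the final answer
                correct += 1
        kept[query] = correct
    return kept


if __name__ == "__main__":
    rng = random.Random(0)
    # Hypothetical pass rates: an easy query and a hard one.
    pass_rates = {"easy": 0.9, "hard": 0.02}
    print("vanilla:", vanilla_rejection_sampling(pass_rates, 8, rng))
    print("uniform target:", uniform_target_sampling(pass_rates, 4, 100_000, rng))
```

With a fixed sample count, the number of kept responses tracks the pass rate, so hard queries are under-represented or missing. With a fixed target of correct responses, every solvable query ends up with the same number, at the cost of many more samples for hard queries.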

## Training Setup

We perform standard instruction tuning on several base models, including Llama3-8B, Mistral-7B, and Llama3-70B as representatives of general models and DeepSeekMath-7B as the representative of math-specialized models, using our synthetic datasets `DART-Math-Hard` and `DART-Math-Uniform`, leading to `DART-Math (Prop2Diff)` and `DART-Math (Uniform)` respectively.

For simplicity, we keep most hyper-parameters the same across different models and datasets:

- Model max length (of packed sequence): 4096
- Batch size: 64
- Warm-up ratio: 0.03
- Learning rate scheduler: cosine
- Prompt template: Alpaca
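To make the warm-up ratio and cosine scheduler concrete, here is a minimal sketch of the learning rate as a function of the training step. It assumes linear warm-up from zero and cosine decay to zero, which is a common convention but is not stated in this card; `lr_at` and its arguments are hypothetical names, and the step counts in the usage example are made up.

```python
import math


def lr_at(step, total_steps, max_lr, warmup_ratio=0.03):
    """Learning rate at `step`, assuming linear warm-up then cosine decay to 0."""
    warmup_steps = int(total_steps * warmup_ratio)
    if step < warmup_steps:
        # Linear warm-up from 0 to max_lr.
        return max_lr * step / max(1, warmup_steps)
    # Cosine decay from max_lr down to 0 over the remaining steps.
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return max_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


if __name__ == "__main__":
    total_steps = 1000  # hypothetical number of optimizer steps
    for step in (0, 15, 30, 500, 1000):
        print(step, lr_at(step, total_steps, max_lr=1e-5))
```

With 1000 total steps, the first 30 steps are warm-up, the peak is reached at step 30, and the rate falls back to zero at the final step.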

Several other key hyper-parameters are tuned as follows:

Base Model | Max. L.R. | # of Epochs | # of Grad. Acc. Steps | # of A100 GPUs |
---|---|---|---|---|
Mistral-7B | `1e-5` | 3 | 1 | 8 |
Llama3-8B | `5e-5` | 1 | 2 | 8 |
Llama3-70B | `2e-5` | 1 | 1 | 32 |
DeepSeekMath-7B | `5e-5` | 3 | 1 | 8 |

- For **maximum learning rate**, we determine the values by **searching** through `1e-6, 5e-6, 1e-5, 2e-5, 5e-5, 1e-4` according to the MATH performance after training on MMIQC for 1 epoch. The exception is Llama3-70B, which is too expensive to search for, so we derive its value from Llama3-8B's learning rate by analogy to the relationship between the (pre-training) learning rates of Llama2-7B and Llama2-70B (~2:1).
- For **Llama3** models, preliminary experiments indicate that **training for 1 epoch consistently outperforms 3 epochs**.

Please refer to Appendix A.1 of our paper for more details.
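As a small worked example, the per-GPU micro-batch size implied by the settings above can be derived from the batch size, the gradient accumulation steps, and the GPU count. This sketch assumes that "Batch size: 64" is the global batch size per optimizer step and that training is purely data-parallel; neither is stated explicitly here, and the names below are hypothetical.

```python
GLOBAL_BATCH_SIZE = 64  # assumed to be the global batch size per optimizer step

# (gradient accumulation steps, number of A100 GPUs), copied from the table above.
SETUPS = {
    "Mistral-7B": (1, 8),
    "Llama3-8B": (2, 8),
    "Llama3-70B": (1, 32),
    "DeepSeekMath-7B": (1, 8),
}


def per_device_batch_size(global_batch_size, grad_acc_steps, n_gpus):
    """Micro-batch size per GPU under plain data parallelism."""
    sequences_per_micro_step = grad_acc_steps * n_gpus
    if global_batch_size % sequences_per_micro_step != 0:
        raise ValueError("global batch size is not divisible by grad_acc_steps * n_gpus")
    return global_batch_size // sequences_per_micro_step


if __name__ == "__main__":
    for name, (grad_acc_steps, n_gpus) in SETUPS.items():
        size = per_device_batch_size(GLOBAL_BATCH_SIZE, grad_acc_steps, n_gpus)
        print(f"{name}: {size} packed sequences per GPU per micro-step")
```

Under these assumptions, each micro-batch holds 8 packed sequences per GPU for Mistral-7B and DeepSeekMath-7B, 4 for Llama3-8B, and 2 for Llama3-70B.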

## Other Details

- For Mistral-7B-based models, we disable `sliding_window` by default, following the newest Mistral-7B-Instruct. (Flash Attention 2 does not support `sliding_window`, and the XFormers backend in vLLM has ~10% lower throughput in our experiments.)

## Citation

If you find our data, model or code useful for your work, please kindly cite our paper:

```
@article{tong2024dartmath,
title={DART-Math: Difficulty-Aware Rejection Tuning for Mathematical Problem-Solving},
author={Yuxuan Tong and Xiwen Zhang and Rui Wang and Ruidong Wu and Junxian He},
year={2024},
eprint={2407.13690},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2407.13690},
}
```


## Model tree for hkust-nlp/dart-math-llama3-70b-prop2diff

Base model: meta-llama/Meta-Llama-3-70B

## Evaluation results

- Pass@1 (0-shot CoT) on MATH test set (self-reported): 56.1
- Pass@1 (0-shot CoT) on GSM8K test set (self-reported): 89.6
- Pass@1 (0-shot CoT) on CollegeMath (self-reported): 37.9
- Pass@1 (0-shot CoT) on DeepMind-Mathematics (self-reported): 64.1
- Pass@1 (0-shot CoT) on OlympiadBench-OE_TO_maths_en_COMP (self-reported): 20.0
- Pass@1 (0-shot CoT) on TheoremQA test set (self-reported): 28.2