---
language:
- en
license: cc-by-sa-4.0
datasets:
- euclaise/TinyCoT
- euclaise/reddit-instruct-curated
- sablo/oasst2_curated
---
# ReMask: Improving autoregressive language models via regularized masking
## Background
[Self-Play Finetuning (SPIN)](https://arxiv.org/abs/2401.01335) is a recent finetuning method which outperforms standard supervised finetuning (SFT).
Instead of only performing next-token prediction, SPIN is an iterative method that contrasts generations from the previous iteration of the model with the ground-truth completions.
Unlike methods like reinforcement learning or ranking losses, SPIN does not require preference data, which makes it an attractive method since preference data can be hard to gather.
However, SPIN's popularity has been limited by the need to repeatedly generate sequences from the model. Generation is much slower than training, so SPIN is considerably slower and more expensive than SFT.
With this problem in mind, I set out to create an alternative to SPIN that doesn't require generation.
### Why does SPIN work?
SFT trains models to predict the next token given all the ground-truth previous tokens.
However, in generation, the model doesn't have access to a ground-truth to predict from, and instead repeatedly predicts on top of its own predictions.
This creates a problem known as "exposure bias": models can often pick a reasonable next token on average, but can't keep this up over the full sequence.
In particular, it might be easy to predict a *reasonable* next token, but much more difficult to predict the full sequence.
***For instance, consider the following case:***
> The astronomer pointed his telescope at the distant star, hoping to see
The correct continuation here might be "signs of life." However, the model might predict "and" rather than "signs", since "and" is *reasonable* in the immediate context: it's grammatically correct, but it implies a strange ending to the sentence.
As a result, the model might end up with something like *"The astronomer pointed his telescope at the distant star, hoping to see and hear."*, which makes little sense.
SPIN's advantage over SFT likely comes from its partial mitigation of exposure bias.
SPIN doesn't only train the model to predict the next token accurately, it repeatedly trains the model to identify and fix discrepancies between its generations and the ground-truth.
In order to do this, the model must implicitly learn to think ahead, as exposure bias is likely what causes many of the discrepancies.
### How can we simplify this?
Unfortunately, explicitly predicting ahead for many steps is very expensive, and considering full model generations requires a slow generation process.
A much cheaper option is to simply corrupt random tokens in the sequence.
The model must keep an internal estimate of what the corrupted tokens ought to be in order to predict the token after them, forcing the model to think ahead.
The most obvious ways to do this are to randomly replace input tokens with a special `[mask]` token, or to randomly replace input tokens with other random tokens.
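A minimal sketch of these two corruption schemes, in plain Python over lists of token ids. `MASK_ID`, `corrupt`, and its parameters are hypothetical names chosen for illustration; this is not code from Masked Thought or from this model's training run.

```python
import random

MASK_ID = 0  # hypothetical id of the special [mask] token


def corrupt(tokens, p, vocab_size, mode="mask", rng=None):
    """Return a copy of `tokens` in which each position is corrupted with probability p.

    mode="mask":   replace the token with MASK_ID
    mode="random": replace the token with a uniformly sampled vocabulary id
    """
    rng = rng or random.Random()
    out = []
    for tok in tokens:
        if rng.random() < p:
            out.append(MASK_ID if mode == "mask" else rng.randrange(vocab_size))
        else:
            out.append(tok)
    return out
```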
These approaches were tried in [Masked Thought](https://arxiv.org/abs/2403.02178), albeit with somewhat different motivations.
However, these approaches have a problem: Models can detect when a token is `[mask]` or is highly unlikely, so the model may only learn to think ahead when the corruptions are present.
To avoid this issue, we can run the model twice - once with a masked sequence, and once on the full sequence.
Then, we penalize deviations between these two runs, which forces the model to act the same regardless of whether the `[mask]` token is present.
This approach was initially introduced with [R-TeaFor](https://aclanthology.org/2022.emnlp-main.423/) for abstractive summarization, but can be easily applied to standard generation tasks too.
### ReMask and ReMask-CoT:
ReMask applies an approach similar to R-TeaFor to typical chat/instruction tuning.
Consider the following chat interaction:
> User: What is 1+1?
>
> Assistant: **1+1=2**
>
> **User:**
The model must predict the bolded parts. So, we randomly mask tokens from the bolded parts, and run the model once on the masked sequence and once on the full sequence.
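The sketch below shows how the two inputs could be built for one example. It assumes a per-token boolean `is_target` flag marking the bolded (trained-on) tokens and the common `-100` ignore-label convention; the function name and signature are illustrative, not taken from the actual training code.

```python
import random

IGNORE = -100  # label value excluded from the loss


def make_remask_views(input_ids, is_target, mask_id, p=0.4, rng=None):
    """Build the two inputs ReMask runs through the model, plus their shared labels.

    is_target[i] is True for tokens the model is trained to predict (the bolded parts).
    Only those tokens are eligible for masking; labels always keep the original tokens.
    """
    rng = rng or random.Random()
    masked = [
        mask_id if (tgt and rng.random() < p) else tok
        for tok, tgt in zip(input_ids, is_target)
    ]
    labels = [tok if tgt else IGNORE for tok, tgt in zip(input_ids, is_target)]
    return {"masked": masked, "full": list(input_ids), "labels": labels}
```

Both passes are scored against the same `labels`, so the masked pass still has to predict the original tokens it can no longer see.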
We then compute a divergence loss `D(p_masked, p_full)` between the two predictions. For this, I used the average of the backwards and forwards KL divergences between the predictions.
Finally, we add this loss, scaled by a weighting value, to the average of the standard cross-entropy language-modeling losses from the two passes:
```
loss = 0.5*(CE(p_masked, labels) + CE(p_full, labels)) + weight*D(p_masked, p_full)
```
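To make the formula concrete, here is a small pure-Python sketch operating on per-position probability distributions (in a real model these would be the softmax outputs of each pass). It assumes `D` is the average of forward and backward KL described above, averaged over the trained-on positions; that averaging choice and all names here are assumptions for illustration.

```python
import math

IGNORE = -100  # label value for positions excluded from the loss (e.g. prompt tokens)


def kl(p, q):
    """KL(p || q) for two discrete distributions; assumes q > 0 wherever p > 0 (true for softmax outputs)."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def cross_entropy(dists, labels):
    """Mean negative log-likelihood of the labels over non-ignored positions."""
    nll = [-math.log(d[y]) for d, y in zip(dists, labels) if y != IGNORE]
    return sum(nll) / len(nll)


def remask_loss(p_masked, p_full, labels, weight=0.1):
    """0.5*(CE(p_masked) + CE(p_full)) + weight * D(p_masked, p_full).

    p_masked / p_full: one probability distribution per position, from the masked and full passes.
    D: average of forward and backward KL, then averaged over the non-ignored positions.
    """
    ce = 0.5 * (cross_entropy(p_masked, labels) + cross_entropy(p_full, labels))
    kept = [i for i, y in enumerate(labels) if y != IGNORE]
    d = sum(0.5 * (kl(p_masked[i], p_full[i]) + kl(p_full[i], p_masked[i])) for i in kept) / len(kept)
    return ce + weight * d
```

When both passes produce identical distributions, the divergence term vanishes and the loss reduces to plain cross-entropy. In practice this would be computed on tensors in log space for numerical stability; the list-based version only shows which terms enter the loss.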
***ReMask-CoT:***
For CoT tasks where the reasoning is explicitly separated from the answer, we can add some further improvements.
First, note that CoT rationales are noisy: there are many correct rationales that could lead to the same correct answer, and rationales are shaped by things like writing style that don't affect whether the reasoning is actually correct.
Keeping this in mind:
- We also randomly mask a small portion of the rationale's labels (but never the answer's), so that getting the answer right matters more than reproducing the annotated rationale word for word.
- The exact answer always matters and is always only a few tokens long, so we never mask the labels or the input tokens of the answer.
- Occasionally, we ignore the rationale labels entirely, so that the model is only pushed to learn whatever leads to the best answer.
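One way these rules could combine with input masking for a single example is sketched below. It assumes a per-token `segment` tag (`"prompt"`, `"rationale"`, or `"answer"`), unshifted labels as in the Hugging Face causal-LM convention (the shift happens inside the loss), and default probabilities taken from the training details listed further down. The function and tag names are hypothetical.

```python
import random

IGNORE = -100  # labels with this value contribute no loss


def remask_cot_example(input_ids, labels, segment, mask_id,
                       p_input=0.4, p_label=0.1, p_answer_only=0.1, rng=None):
    """Return (masked_inputs, new_labels) for one TinyCoT-style example.

    segment[i] tags the token at position i as "prompt", "rationale" or "answer".
    Answer tokens are never masked, neither as inputs nor as labels.
    """
    rng = rng or random.Random()
    # With probability p_answer_only, drop every rationale label for this example.
    answer_only = rng.random() < p_answer_only
    new_inputs, new_labels = [], []
    for tok, lab, seg in zip(input_ids, labels, segment):
        if seg == "rationale":
            new_inputs.append(mask_id if rng.random() < p_input else tok)
            drop = answer_only or rng.random() < p_label
            new_labels.append(IGNORE if drop else lab)
        else:
            new_inputs.append(tok)
            new_labels.append(lab)
    return new_inputs, new_labels
```

Presumably the unmasked pass would reuse the original `input_ids` together with the same returned labels, so that both cross-entropy terms see identical targets.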
## Results
I trained StableLM-3B-4e1t three times on [TinyCoT](https://huggingface.co/datasets/euclaise/TinyCoT), along with 1000 examples from [reddit-instruct-curated](https://huggingface.co/datasets/euclaise/reddit-instruct-curated) and 1000 examples from [oasst2-curated](https://huggingface.co/datasets/sablo/oasst2_curated).
I trained once with ReMask/ReMask-CoT, once without regularization to match Masked Thought (w/ partial label-masking for CoT), and once with SFT.
If my hypothesis regarding exposure bias is correct, ReMask should significantly improve generative benchmarks like GSM8K, but would not necessarily improve logprob-based benchmarks like ARC-c (as implemented by the evaluation harness).
Here are some benchmark results, computed using the LM Evaluation Harness with vLLM:
| Model | GSM8K (strict, 5-shot) | ARC-c (acc_norm, 25-shot) |
|:--------------:|-----------------------:|--------------------------:|
| SFT | 24.34% | 42.92% |
| Masked Thought | 24.18% | *43.60%* |
| **ReMask** | **27.90%** | 43.26% |
As expected, ReMask improves GSM8K (27.90% vs. 24.34% for SFT), but has little effect on ARC-c.
## Training details
- Framework: PyTorch Lightning
- Optimizer: [Lilith](https://github.com/euclaise/supertrainer2000/blob/master/src/supertrainer2k/optim/lilith.py)
- Training sequence length: 256
- Input masking probability: 40%
- Label masking probability: 10%
- Answer-only (full rationale label masking) probability: 10%
- Batch size: 16, accumulated to 256
- Epochs: 6
- Learning rate: 1e-5
- Learning rate schedule: One Cycle, cosine, no cycle_momentum
- Regularization weight: 0.1
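For reference, the same settings collected into a small config object. The field names are illustrative; the original training code (PyTorch Lightning with the Lilith optimizer) is not reproduced here.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ReMaskHParams:
    """Hyperparameters listed above (field names are hypothetical)."""
    seq_len: int = 256
    p_input_mask: float = 0.40
    p_label_mask: float = 0.10
    p_answer_only: float = 0.10
    micro_batch_size: int = 16
    effective_batch_size: int = 256
    epochs: int = 6
    lr: float = 1e-5
    reg_weight: float = 0.1

    @property
    def accumulate_grad_batches(self) -> int:
        # Number of micro-batches accumulated per optimizer step (256 / 16 = 16).
        assert self.effective_batch_size % self.micro_batch_size == 0
        return self.effective_batch_size // self.micro_batch_size
```

In PyTorch Lightning, the derived value would presumably be passed as `Trainer(accumulate_grad_batches=...)`, and the schedule line corresponds to the arguments of `torch.optim.lr_scheduler.OneCycleLR` with `anneal_strategy="cos"` and `cycle_momentum=False`.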
## Prompt format
The format for reddit-instruct and oasst2 was:
```
<|user|>
[insert instruction here]
<|assistant|>
[insert response here]
<|user|>
...
```
The format for TinyCoT was:
```
<|user|>
[insert instruction here]
<|rationale|>
[insert reasoning here]
<|answer|>
[insert direct answer here]
```
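A small helper that renders both formats as plain strings, assuming each tag sits on its own line with the text on the following line, as shown above. For chat data it appends a final `<|user|>` tag so the model also learns to end its turn, mirroring the bolded **User:** target in the ReMask example; that detail and the function names are assumptions.

```python
def format_chat(turns, end_with_user_tag=True):
    """Render (role, text) turns, where role is "user" or "assistant"."""
    lines = []
    for role, text in turns:
        lines += [f"<|{role}|>", text]
    if end_with_user_tag:
        # The model is trained to predict the next <|user|> tag after its reply.
        lines.append("<|user|>")
    return "\n".join(lines)


def format_tinycot(instruction, rationale, answer):
    """Render one TinyCoT example in the <|user|>/<|rationale|>/<|answer|> format."""
    return "\n".join([
        "<|user|>", instruction,
        "<|rationale|>", rationale,
        "<|answer|>", answer,
    ])
```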