euclaise committed 7b1caad (1 parent: 3b29770): Update README.md
  - en
---

# ReMask: Improving autoregressive language models via regularized masking

## Background

[Self-Play Finetuning (SPIN)](https://arxiv.org/abs/2401.01335) is a recent finetuning method which outperforms standard supervised finetuning (SFT).
Rather than only performing next-token prediction, SPIN is an iterative method which contrasts generations from the previous iteration of the model with the ground-truth completions.
Unlike reinforcement learning or ranking losses, SPIN does not require preference data, which makes it attractive, since preference data can be hard to gather.
However, SPIN's adoption has been limited by the need to repeatedly generate sequences from the model -- generation is much slower than training, so SPIN is far slower and more expensive than SFT.

With this problem in mind, I set out to create an alternative to SPIN which doesn't require generation.

### Why does SPIN work?

SFT trains models to predict the next token given all of the ground-truth previous tokens.
During generation, however, the model doesn't have ground-truth context to predict from; instead, it repeatedly predicts on top of its own predictions.
This mismatch creates a problem known as "exposure bias": the model can often pick a reasonable choice for the next token on average, but can't keep this up for the full sequence.
In other words, it might be easy to predict a *reasonable* next token, but much more difficult to produce a coherent full sequence.

***For instance, consider the following case:***

> The astronomer pointed his telescope at the distant star, hoping to see

The correct continuation here might be "signs of life." However, the model might predict "and" rather than "signs", since "and" is *reasonable* in the immediate context - it's grammatically correct, but implies a strange ending to the sentence.
As a result, the model might end up with something like "The astronomer pointed his telescope at the distant star, hoping to see and hear." - which makes little sense.

---

SPIN's advantage over SFT likely comes from its partial mitigation of exposure bias.
SPIN doesn't just train the model to predict the next token accurately; it repeatedly trains the model to identify and fix discrepancies between its own generations and the ground truth.
To do this, the model must implicitly learn to think ahead, since exposure bias is likely what causes many of those discrepancies.

### How can we simplify this?

Unfortunately, explicitly predicting many steps ahead is very expensive, and considering full model generations requires a slow generation process.

An obvious alternative is to randomly corrupt tokens in the input sequence.
To predict the token that follows a corrupted token, the model must keep an internal estimate of what the corrupted token ought to be - forcing it to think ahead.

The most obvious ways to do this are to randomly replace input tokens with a special `[mask]` token, or to randomly replace input tokens with other random tokens.
These approaches were tried in [Masked Thought](https://arxiv.org/abs/2403.02178), albeit with somewhat different motivations.
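As a rough sketch of these two corruption schemes (the token-id representation, the reserved `MASK_ID`, and the 50/50 split between masking and random replacement are illustrative assumptions, not details from the Masked Thought paper):

```python
import random

MASK_ID = 0  # hypothetical id reserved for the special [mask] token

def corrupt(tokens, mask_prob=0.15, vocab_size=50257, rng=None):
    """Randomly corrupt input token ids: half of the corrupted positions
    become [mask], the other half become random vocabulary tokens."""
    rng = rng or random.Random(0)
    out = []
    for tok in tokens:
        if rng.random() < mask_prob:
            if rng.random() < 0.5:
                out.append(MASK_ID)                       # scheme 1: [mask] token
            else:
                out.append(rng.randrange(1, vocab_size))  # scheme 2: random token
        else:
            out.append(tok)                               # keep the original token
    return out
```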

However, these approaches have a problem: the model can detect when a token is `[mask]` or is highly unlikely, so it may only learn to think ahead when the corruptions are present.

To avoid this issue, we can run the model twice - once on a masked sequence, and once on the full sequence.
We then penalize deviations between these two runs, which forces the model to behave the same regardless of whether the `[mask]` token is present.

This approach was originally introduced by [R-TeaFor](https://aclanthology.org/2022.emnlp-main.423/) for abstractive summarization, but it can easily be applied to standard generation tasks too.

### ReMask and ReMask-CoT

ReMask applies an approach similar to R-TeaFor to typical chat/instruction tuning.

Consider the following chat interaction:

> User: What is 1+1?
> Assistant: **1+1=2**
> **User:**

The model must predict the bolded parts. So, we randomly mask tokens within the bolded parts, and run the model once on the masked sequence and once on the full sequence.
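A minimal sketch of restricting masking to the parts the model must predict; the parallel `is_target` flags and the `MASK_ID` value are hypothetical names for illustration:

```python
import random

MASK_ID = 0  # hypothetical reserved [mask] token id

def mask_targets(tokens, is_target, mask_prob=0.15, seed=0):
    """Return a copy of `tokens` where only the positions the model must
    predict (the assistant's tokens) are eligible for [mask] replacement;
    user/prompt tokens are left untouched."""
    rng = random.Random(seed)
    return [MASK_ID if (tgt and rng.random() < mask_prob) else tok
            for tok, tgt in zip(tokens, is_target)]
```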

We then compute a distance loss `D(p_masked, p_full)` between the two sets of predictions. This approach resembles self-distillation; MSE tends to perform better than KL divergence for distillation, and is easier to tune, so I went with MSE.

Finally, we add this distance loss to the standard cross-entropy language modeling losses from each prediction, with a weighting value:
```
loss = CE(p_masked, labels) + CE(p_full, labels) + weight*D(p_masked, p_full)
```
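For concreteness, the combined objective can be sketched in dependency-free Python over toy per-position logits (in practice this would be computed over batched model logits in a framework like PyTorch; the helper names here are illustrative):

```python
import math

def softmax(logits):
    # numerically stable softmax over one position's logits
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]

def cross_entropy(logits_seq, labels):
    # mean negative log-likelihood of the label token at each position
    return -sum(math.log(softmax(l)[y])
                for l, y in zip(logits_seq, labels)) / len(labels)

def mse(logits_a, logits_b):
    # mean squared difference between the two runs' probabilities: D(p_masked, p_full)
    total = n = 0
    for la, lb in zip(logits_a, logits_b):
        for pa, pb in zip(softmax(la), softmax(lb)):
            total += (pa - pb) ** 2
            n += 1
    return total / n

def remask_loss(logits_masked, logits_full, labels, weight=1.0):
    # CE(p_masked, labels) + CE(p_full, labels) + weight * D(p_masked, p_full)
    return (cross_entropy(logits_masked, labels)
            + cross_entropy(logits_full, labels)
            + weight * mse(logits_masked, logits_full))
```

When the two runs agree exactly, the distance term vanishes and the loss reduces to the two cross-entropy terms.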

***ReMask-CoT:***

For CoT tasks where the reasoning is explicitly separated from the answer, we can add some further improvements.

First, note that CoT rationales are noisy -- many different correct rationales can lead to the same correct answer, and rationales are affected by factors like writing style that don't matter for the actual correctness of the reasoning.

Keeping this in mind:

- We also randomly mask a small portion of the rationale's labels (but not the answer's), so that producing an accurate answer matters more than reproducing the annotated rationale word-for-word.
- The exact answer is always important, and is always only a few tokens. Hence, we never mask the labels or input tokens of the answer value.
- Rarely, we ignore the rationale labels entirely, so that the model is pushed only to learn what leads to the best answer.
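The label-masking rules above can be sketched as follows (the `-100` ignore value follows a common language-modeling convention; the names and probabilities are illustrative assumptions):

```python
import random

IGNORE = -100  # conventional "exclude this position from the loss" label value

def cot_label_mask(labels, is_rationale, label_mask_prob=0.1,
                   drop_rationale_prob=0.05, seed=0):
    """Mask a small fraction of rationale labels (never answer labels);
    rarely, ignore the rationale labels entirely."""
    rng = random.Random(seed)
    drop_all = rng.random() < drop_rationale_prob  # rare: drop whole rationale
    out = []
    for lab, rat in zip(labels, is_rationale):
        if rat and (drop_all or rng.random() < label_mask_prob):
            out.append(IGNORE)   # excluded from the cross-entropy loss
        else:
            out.append(lab)      # answer labels are always kept
    return out
```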

## Results