mebubo committed bb5d096 (1 parent: 0d14d7b)

mv BLOG.md README.md

Files changed (2):
  1. BLOG.md: deleted (-186)
  2. README.md: added (+186)

README.md (new file, +186 lines):

# GPTed blog post part 1

What I want to cover:
- The original blog post
- Improvements that I wanted to make:
  - In addition to highlighting low-probability words, show replacement suggestions that are more likely
  - Operate at the level of whole words, not tokens
- Justification for using a local model
- Limitations of the logprobs returned by the APIs
- Main parts of the project
  - Combining tokens into words to get the probabilities of whole words
  - The batched multi-token expansion with a probability budget
  - Testable abstract implementation

This post describes my attempt to build an improved version of GPTed from https://vgel.me/posts/gpted-launch/ and what I learned from it.

Here is what was done in the original GPTed:
- Use the logprobs returned by the OpenAI API (in particular, the legacy /v1/completions API, https://platform.openai.com/docs/api-reference/completions) for tokens _in the existing text_ (as opposed to generated text) to detect the tokens the model is surprised by
- Provide a basic text editing UI with a mode in which tokens whose logprob is below a given threshold are highlighted. Not every highlighted token is necessarily a mistake, but the idea is that it may be worth checking that a low-probability token is indeed intended.

Here are the improvements that I wanted to make:
- Operate at the word level instead of the token level: compute the logprob of whole words, even multi-token ones, and highlight whole words
- Propose replacement words for the highlighted words
  - Specifically, words with probability higher than the flagging threshold

### On logprobs in the OpenAI API

The original GPTed project relied on two features of the legacy OpenAI /v1/completions API:

> logprobs: Include the log probabilities on the `logprobs` most likely output tokens, as well the chosen tokens. For example, if `logprobs` is 5, the API will return a list of the 5 most likely tokens. The API will always return the `logprob` of the sampled token, so there may be up to `logprobs+1` elements in the response. The maximum value for `logprobs` is 5.

> echo: Echo back the prompt in addition to the completion
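
To make the setup concrete, here is roughly how scoring existing text looked with that endpoint: a sketch using the pre-1.0 `openai` Python package; the model name and prompt are placeholders, not taken from the original post.

```python
import openai  # openai < 1.0

# Score the prompt itself: generate nothing, echo the prompt back with logprobs.
resp = openai.Completion.create(
    model="text-davinci-003",
    prompt="I like to eat a apple",
    max_tokens=0,   # no new tokens
    echo=True,      # return the prompt tokens in the response
    logprobs=5,     # per-token logprob plus up to 5 alternatives
)
token_logprobs = resp["choices"][0]["logprobs"]["token_logprobs"]
```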

The echo parameter no longer exists in the modern chat completions API (/v1/chat/completions), which makes it impossible to get logprobs for existing text (as opposed to generated text). The legacy completions API is not available for modern models like GPT4 (FIXME verify this claim).

The limit of 5 on the number of logprobs is also quite restrictive: there may well be more than 5 candidate tokens above the threshold, and I would like to be able to take all of them into account.

Finally, the case of multi-token words meant that it would be convenient to use batching, which is not available over the OpenAI API.
For these three reasons, I decided to switch to using local models.

### Local models with huggingface transformers

To run inference locally and get the logits, I used huggingface transformers. As the model, I used Llama 3.2 1B, because it runs fast enough on a CPU to enable local development on my laptop.
The basic usage to get logits for every token in an input is straightforward:

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "unsloth/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

input_text = "Hello world!"
inputs = tokenizer(input_text, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

logits = outputs.logits  # Shape: [batch_size, sequence_length, vocab_size]
```

Here is how I compute the logprob for every token in the input:

```python
def calculate_log_probabilities(model: PreTrainedModel, tokenizer: Tokenizer, inputs: BatchEncoding) -> list[tuple[int, float]]:
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
    # B x (T - 1) x V: drop the prediction made from the last position
    logits: torch.Tensor = outputs.logits[:, :-1, :]
    # B x (T - 1) x V
    log_probs: torch.Tensor = torch.log_softmax(logits, dim=-1)
    # T - 1: the actual tokens, minus the start-of-sequence token
    tokens: torch.Tensor = input_ids[0][1:]
    # T - 1
    token_log_probs: torch.Tensor = log_probs[0, range(log_probs.shape[1]), tokens]
    return list(zip(tokens.tolist(), token_log_probs.tolist()))
```

Explanation:
- we drop the logits for the last position, because they correspond to the probability of the next token (which we don't have)
- we compute the log-softmax over the last dimension (vocab size) to obtain a probability distribution over all tokens
- we drop the first token because it is a start-of-sequence token
- `log_probs[0, range(log_probs.shape[1]), tokens]` indexes into `log_probs` so as to extract
  - at position 0 (the probability distribution for the first token after the start-of-sequence token), the logprob value corresponding to the actual first token
  - at position 1 (the probability distribution for the second token after the start-of-sequence token), the logprob value corresponding to the actual second token
  - etc.
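
As a quick illustration of how this feeds into flagging (the prompt and threshold here are arbitrary, just for the example):

```python
inputs = tokenizer("I like to eat a apple", return_tensors="pt")
threshold = -5.0  # logprob threshold; an illustrative value

for token_id, logprob in calculate_log_probabilities(model, tokenizer, inputs):
    if logprob < threshold:
        print(f"flagged: {tokenizer.decode([token_id])!r} ({logprob:.2f})")
```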

Here is how I handled combining tokens into words.

I wrote a very generic `combine` function that takes a list of values and a function that tells it how to combine two adjacent values into a single value. If the function returns `None`, the values are not combined.
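
A minimal sketch of such a function, consistent with the tests below (the actual implementation may differ in details):

```python
from typing import Callable, TypeVar

T = TypeVar("T")

def combine(xs: list[T], combine_fn: Callable[[T, T], T | None]) -> list[T]:
    # Greedily merge adjacent values left to right; when combine_fn returns None,
    # the right value starts a new entry in the result.
    result: list[T] = []
    for x in xs:
        if result:
            merged = combine_fn(result[-1], x)
            if merged is not None:
                result[-1] = merged
                continue
        result.append(x)
    return result
```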

Because it is generic, it is very easy to test:

```python
def test_add_if_even():
    def add_if_even(x: int, y: int) -> int | None:
        if (x + y) % 2 == 0:
            return x + y
        return None

    assert combine([1, 3, 1, 4], add_if_even) == [4, 1, 4]
    assert combine([1, 3, 2, 4], add_if_even) == [10]
```

Applying this function to the problem of combining tokens into words is just a matter of writing the correct `combine_fn`:

```python
from dataclasses import dataclass

@dataclass
class Tok:
    index: int
    ids: list[int]
    str: str
    logprob: float

def is_beginning_of_word(s: str) -> bool:
    return (s[0] == " " and s[1:].isalpha()) or s.isalpha()

def is_continuation_of_word(s: str) -> bool:
    return s.isalpha()

def merge_tokens(a: Tok, b: Tok) -> Tok | None:
    if is_beginning_of_word(a.str) and is_continuation_of_word(b.str):
        return Tok(a.index, a.ids + b.ids, a.str + b.str, a.logprob + b.logprob)
    return None
```

This handles the computation of the combined logprob for multi-token words nicely, and allows me to highlight whole words based on a threshold.
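
For example, two tokens that form a single word get merged, with their logprobs summed (the token ids and logprobs below are made up for illustration):

```python
a = Tok(index=3, ids=[1131], str=" interest", logprob=-2.0)
b = Tok(index=4, ids=[287], str="ing", logprob=-0.5)

assert merge_tokens(a, b) == Tok(3, [1131, 287], " interesting", -2.5)
# A token that starts a new word (leading space) is not merged into the previous one:
assert merge_tokens(b, a) is None
```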

The next step was to produce suggestions for replacement words.

Here is how I do it:

Extract the contexts (lists of token prefixes -- all tokens up to the word in question) for each flagged word:

```python
contexts = [word.context for _, word in low_prob_words]
```

Create a `Series` for each context (a series has a budget), and bundle them into a `Batch`:

```python
series = []
for i, x in enumerate(contexts):
    series.append(Series(id=i, tokens=x, budget=5.0))

batch = Batch(items=series)
```
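
For orientation, the shapes of these containers are roughly as follows (a sketch reconstructed from how they are used in this post; the real definitions may carry more fields):

```python
from dataclasses import dataclass

@dataclass
class Series:
    id: int            # index of the flagged word this series belongs to
    tokens: list[int]  # context: token ids up to (but not including) the word
    budget: float      # how much cumulative logprob the expansion may "spend"

@dataclass
class Batch:
    items: list[Series]
```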

A stopping criterion decides when to stop expanding a series:

```python
stopping_criterion = create_stopping_criterion_llm(tokenizer)
```

In my case, I stop when the budget is exhausted, and I also stop when the expansion reaches a word boundary (I'm only interested in single-word replacements).
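
The word-boundary part boils down to looking at the decoded continuation. A hypothetical helper illustrating the idea (not the actual `create_stopping_criterion_llm`, whose internals are omitted here):

```python
def reached_word_boundary(tokenizer, token_id: int) -> bool:
    # A continuation that is not purely alphabetic (leading space, punctuation,
    # digits, ...) means the candidate word is complete.
    s = tokenizer.decode([token_id])
    return not s.isalpha()
```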

Given the batch and the stopping criterion, we can call the expander:

```python
expander = ExpanderOneBatchLLM(model, tokenizer)
expanded = expand(batch, expander, stopping_criterion)
```

The `expand` logic is the most complex part of the project, and in order to make it testable, I made it generic, with only a small part that is LLM-specific.

Here is what the tests look like:

```python
def test_expander_zero_budget():
    s = Series(id=0, tokens=[1], budget=0.0)
    expanded = expander.expand(Batch(items=[s]))
    expected = ExpansionOneResultBatch(
        items=[ExpansionOneResult(series=s, expansions=[
            Expansion(token=21, cost=-1.0),
            Expansion(token=22, cost=-1.0),
        ])]
    )
    assert expected == expanded
```

The tests use a non-LLM expander backed by a hardcoded list of possible expansions, so they are very easy to write, straightforward to interpret, and run very fast.
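
A sketch of what such a hardcoded expander could look like (the name and exact shape are illustrative, reconstructed from the test above):

```python
class HardcodedExpander:
    # Maps the last token of a series to a fixed list of possible continuations.
    def __init__(self, table: dict[int, list[Expansion]]):
        self.table = table

    def expand(self, batch: Batch) -> ExpansionOneResultBatch:
        return ExpansionOneResultBatch(items=[
            ExpansionOneResult(series=s, expansions=list(self.table.get(s.tokens[-1], [])))
            for s in batch.items
        ])

expander = HardcodedExpander({1: [Expansion(token=21, cost=-1.0), Expansion(token=22, cost=-1.0)]})
```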

### Other potential ideas
- Instead of using a local model, investigate using the API of a provider that exposes logprobs, e.g. Replicate