Agon H commited on
Commit
89305f6
1 Parent(s): bfc698f

Upload adapt_tokenizer.py

Browse files
Files changed (1) hide show
  1. adapt_tokenizer.py +41 -0
adapt_tokenizer.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union
2
+ from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
3
+ Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
4
+ NUM_SENTINEL_TOKENS: int = 100
5
+
6
+ def adapt_tokenizer_for_denoising(tokenizer: Tokenizer):
7
+ """Adds sentinel tokens and padding token (if missing).
8
+
9
+ Expands the tokenizer vocabulary to include sentinel tokens
10
+ used in mixture-of-denoiser tasks as well as a padding token.
11
+
12
+ All added tokens are added as special tokens. No tokens are
13
+ added if sentinel tokens and padding token already exist.
14
+ """
15
+ sentinels_to_add = [f'<extra_id_{i}>' for i in range(NUM_SENTINEL_TOKENS)]
16
+ tokenizer.add_tokens(sentinels_to_add, special_tokens=True)
17
+ if tokenizer.pad_token is None:
18
+ tokenizer.add_tokens('<pad>', special_tokens=True)
19
+ tokenizer.pad_token = '<pad>'
20
+ assert tokenizer.pad_token_id is not None
21
+ sentinels = ''.join([f'<extra_id_{i}>' for i in range(NUM_SENTINEL_TOKENS)])
22
+ _sentinel_token_ids = tokenizer(sentinels, add_special_tokens=False).input_ids
23
+ tokenizer.sentinel_token_ids = _sentinel_token_ids
24
+
25
+ class AutoTokenizerForMOD(AutoTokenizer):
26
+ """AutoTokenizer + Adaptation for MOD.
27
+
28
+ A simple wrapper around AutoTokenizer to make instantiating
29
+ an MOD-adapted tokenizer a bit easier.
30
+
31
+ MOD-adapted tokenizers have sentinel tokens (e.g., <extra_id_0>),
32
+ a padding token, and a property to get the token ids of the
33
+ sentinel tokens.
34
+ """
35
+
36
+ @classmethod
37
+ def from_pretrained(cls, *args, **kwargs):
38
+ """See `AutoTokenizer.from_pretrained` docstring."""
39
+ tokenizer = super().from_pretrained(*args, **kwargs)
40
+ adapt_tokenizer_for_denoising(tokenizer)
41
+ return tokenizer