a couple of explore scripts for: adding two special tokens (<idiom>, </idiom>)
explore/explore_bart_for_conditional_generation.py
DELETED
@@ -1,10 +0,0 @@
-
-from transformers import BartTokenizer, BartForConditionalGeneration
-
-
-def main():
-    pass
-
-
-if __name__ == '__main__':
-    main()
explore/explore_bart_tokenizer_add_special_tokens.py
ADDED
@@ -0,0 +1,27 @@
+from transformers import BartTokenizer, BartForConditionalGeneration
+
+
+def main():
+    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
+    bart = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
+    num_added_tokens = tokenizer.add_special_tokens({
+        "additional_special_tokens": ["<idiom>", "</idiom>"],  # beginning and end of an idiom
+    })
+    print(num_added_tokens)
+    print(tokenizer.additional_special_tokens)  # the newly added special tokens show up here
+    # and then you should resize the embedding table of your model
+    print(bart.model.shared.weight.shape)  # before
+    bart.resize_token_embeddings(len(tokenizer))
+    print(bart.model.shared.weight.shape)  # after
+
+
+if __name__ == '__main__':
+    main()
+
+"""
+2
+['<idiom>', '</idiom>']
+torch.Size([50265, 768])
+torch.Size([50267, 768])  # you can see that 2 more embedding vectors have been added here.
+later, you may want to save the tokenizer after you add the idiom special tokens.
+"""
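The docstring above ends with a note about saving the tokenizer once the idiom tokens are in. As a minimal sketch of that follow-up (not part of this commit; the output directory "out/bart-idiom" is just an assumed example path), saving both the tokenizer and the resized model keeps the extended vocabulary across a reload:

from transformers import BartTokenizer, BartForConditionalGeneration

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
bart = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
tokenizer.add_special_tokens({"additional_special_tokens": ["<idiom>", "</idiom>"]})
bart.resize_token_embeddings(len(tokenizer))

# "out/bart-idiom" is an arbitrary example path, not something defined in this repo
tokenizer.save_pretrained("out/bart-idiom")
bart.save_pretrained("out/bart-idiom")

# reloading from the saved directory restores the extended vocabulary
reloaded = BartTokenizer.from_pretrained("out/bart-idiom")
print(reloaded.additional_special_tokens)  # ['<idiom>', '</idiom>']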
explore/explore_bart_tokenizer_special_tokens.py
ADDED
@@ -0,0 +1,36 @@
+from transformers import BartTokenizer
+
+
+def main():
+    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
+    print(tokenizer.bos_token)
+    print(tokenizer.cls_token)
+    print(tokenizer.eos_token)
+    print(tokenizer.sep_token)
+    print(tokenizer.mask_token)
+    print(tokenizer.pad_token)
+    print(tokenizer.unk_token)
+
+
+"""
+<s>
+<s>
+</s>
+</s>
+<mask>
+<pad>
+<unk>
+
+right, so these are just like BERT's special symbols, only in lowercase.
+bos = cls
+sep = eos
+would it be okay to use <idiom> = <sep>?
+no, sep implies that a sentence somehow ends.
+"""
+
+
+
+
+
+if __name__ == '__main__':
+    main()
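Since the scripts above settle on dedicated <idiom> / </idiom> markers rather than reusing <sep>, here is a minimal sketch (not from this commit; the sentence and idiom are made up for illustration) of how a marked-up input would tokenize once the special tokens have been added -- the markers stay whole instead of being split into sub-words:

from transformers import BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
tokenizer.add_special_tokens({"additional_special_tokens": ["<idiom>", "</idiom>"]})

# hypothetical example sentence with the idiom span wrapped in the new markers
text = "He decided to <idiom> bite the bullet </idiom> and apologise."

print(tokenizer.tokenize(text))  # <idiom> and </idiom> each remain a single token
ids = tokenizer(text).input_ids
print(tokenizer.convert_ids_to_tokens(ids))  # plus the usual <s> ... </s> wrapping from BART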