eubinecto committed
Commit 927768a
1 Parent(s): c7b4a17

a couple of explore scripts for adding two special tokens (<idiom>, </idiom>)

explore/explore_bart_for_conditional_generation.py DELETED
@@ -1,10 +0,0 @@
-
-from transformers import BartTokenizer, BartForConditionalGeneration
-
-
-def main():
-    pass
-
-
-if __name__ == '__main__':
-    main()

explore/explore_bart_tokenizer_add_special_tokens.py ADDED
@@ -0,0 +1,27 @@
+from transformers import BartTokenizer, BartForConditionalGeneration
+
+
+def main():
+    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
+    bart = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
+    num_added_tokens = tokenizer.add_special_tokens({
+        "additional_special_tokens": ["<idiom>", "</idiom>"],  # beginning and end of an idiom
+    })
+    print(num_added_tokens)
+    print(tokenizer.additional_special_tokens)  # more special tokens are added here
+    # and then you should resize the embedding table of your model
+    print(bart.model.shared.weight.shape)  # before
+    bart.resize_token_embeddings(len(tokenizer))
+    print(bart.model.shared.weight.shape)  # after
+
+
+if __name__ == '__main__':
+    main()
+
+"""
+2
+['<idiom>', '</idiom>']
+torch.Size([50265, 768])
+torch.Size([50267, 768])  # you can see that 2 more embedding vectors have been added here.
+later, you may want to save the tokenizer after you add the idiom special tokens.
+"""
explore/explore_bart_tokenizer_special_tokens.py ADDED
@@ -0,0 +1,36 @@
+from transformers import BartTokenizer
+
+
+def main():
+    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
+    print(tokenizer.bos_token)
+    print(tokenizer.cls_token)
+    print(tokenizer.eos_token)
+    print(tokenizer.sep_token)
+    print(tokenizer.mask_token)
+    print(tokenizer.pad_token)
+    print(tokenizer.unk_token)
+
+
+"""
+<s>
+<s>
+</s>
+</s>
+<mask>
+<pad>
+<unk>
+
+right, so these are just like the symbols for BERT, but in lowercase.
+bos = cls
+sep = eos
+would it be okay to use <idiom> = <sep>?
+no, sep implies that a sentence somehow ends.
+"""
+
+
+
+
+
+if __name__ == '__main__':
+    main()
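
Note: a natural follow-up to these two scripts is to confirm that the added markers tokenize atomically, instead of being split into subwords like ordinary text. A minimal sketch (the example sentence is illustrative):

    # check that <idiom> and </idiom> each come out as a single token
    from transformers import BartTokenizer

    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
    tokenizer.add_special_tokens({"additional_special_tokens": ["<idiom>", "</idiom>"]})
    print(tokenizer.tokenize("he <idiom> kicked the bucket </idiom> yesterday"))
    # the markers should appear as '<idiom>' and '</idiom>', each a single token
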