Milan Straka committed on
Commit
a75ff23
·
1 Parent(s): 20ab64f

Add scripts used to generate v1.1.

Browse files
scripts-for-generating-v1.1/pt_fix.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
#!/usr/bin/env python3
"""Grow the vocabulary dimension of a masked-LM checkpoint and re-save it.

The word-embedding matrix (tied with the LM-head decoder weight) and the
LM-head bias are extended from ``mask_id + 1`` rows to ``new_vocab`` rows.
Row ``mask_id - 1`` and every newly added row are initialised with the row
of the unknown token.
"""
import argparse

import torch


def expand_vocab_rows(original, new_vocab, unk_id, mask_id):
    """Return a copy of *original* extended to *new_vocab* rows.

    Args:
        original: Tensor whose first dimension has exactly ``mask_id + 1`` rows.
        new_vocab: Number of rows of the returned tensor.
        unk_id: Row of *original* copied into every remapped/new row.
        mask_id: Index of the last row of *original*.

    Returns:
        A new tensor with the dtype and device of *original*: rows
        ``0..mask_id`` are copied over, then row ``mask_id - 1`` and rows
        ``mask_id + 1..new_vocab - 1`` are overwritten with
        ``original[unk_id]``.

    Raises:
        ValueError: If *original* does not have ``mask_id + 1`` rows or
            *new_vocab* is smaller than that.
    """
    if original.shape[0] != mask_id + 1:
        raise ValueError(f"Expected {mask_id + 1} rows, got shape {tuple(original.shape)}")
    if new_vocab < mask_id + 1:
        raise ValueError(f"new_vocab={new_vocab} is smaller than the current {mask_id + 1} rows")

    expanded = torch.zeros(
        (new_vocab,) + tuple(original.shape[1:]), dtype=original.dtype, device=original.device,
    )
    expanded[:mask_id + 1] = original
    # Row mask_id - 1 and all appended rows start from the unknown-token row;
    # presumably these are the ids tokenizer_fix.py hands out to tokens that
    # used to collide with the unknown token -- confirm against that script.
    unk_rows = [mask_id - 1, *range(mask_id + 1, new_vocab)]
    expanded[unk_rows] = original[unk_id]
    return expanded


def main():
    """Load the checkpoint, enlarge its vocabulary-sized tensors and save it."""
    parser = argparse.ArgumentParser()
    parser.add_argument("input_path", type=str, help="Input directory")
    parser.add_argument("output_path", type=str, help="Output directory")
    args = parser.parse_args()

    # Imported here so that `--help` stays fast and `expand_vocab_rows` can be
    # imported without the transformers package being installed.
    import transformers

    robeczech = transformers.AutoModelForMaskedLM.from_pretrained(args.input_path, add_pooling_layer=True)

    unk_id, mask_id, new_vocab = 3, 51960, 51997

    # The decoder weight/bias must be the very same parameters as the input
    # embeddings / LM-head bias, so resizing the two tensors below also
    # resizes the decoder.
    assert robeczech.roberta.embeddings.word_embeddings.weight is robeczech.lm_head.decoder.weight
    assert robeczech.lm_head.bias is robeczech.lm_head.decoder.bias
    for weight in (robeczech.roberta.embeddings.word_embeddings.weight, robeczech.lm_head.bias):
        weight.data = expand_vocab_rows(weight.data, new_vocab, unk_id, mask_id)

    # Keep the saved config in sync with the enlarged tensors; a stale
    # vocab_size would make the saved weights disagree with config.json.
    robeczech.config.vocab_size = new_vocab

    # Saved twice: once with the library's default serialization and once with
    # safetensors explicitly disabled, presumably so that both weight formats
    # end up in the output directory.
    robeczech.save_pretrained(args.output_path)
    robeczech.save_pretrained(args.output_path, safe_serialization=False)


if __name__ == "__main__":
    main()
scripts-for-generating-v1.1/tf_fix.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
#!/usr/bin/env python3
"""Load a PyTorch masked-LM checkpoint as a TensorFlow model and save it."""
import argparse

import transformers


def main():
    """Convert the checkpoint in ``input_path`` and write it to ``output_path``."""
    parser = argparse.ArgumentParser()
    parser.add_argument("input_path", type=str, help="Input directory")
    parser.add_argument("output_path", type=str, help="Output directory")
    args = parser.parse_args()

    # TFAutoModelForMaskedLM replaces the deprecated TFAutoModelWithLMHead and
    # matches the AutoModelForMaskedLM class used by pt_fix.py.
    # `from_pt=True` makes the loader read the PyTorch weights.
    robeczech = transformers.TFAutoModelForMaskedLM.from_pretrained(args.input_path, from_pt=True)
    robeczech.save_pretrained(args.output_path)


if __name__ == "__main__":
    main()
scripts-for-generating-v1.1/tokenizer_fix.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
#!/usr/bin/env python3
"""Give every token that shares the unknown token's id an id of its own.

Reads ``vocab.json`` from the input directory, writes the fixed mapping to
the output directory and regenerates ``tokenizer.json`` there.
"""
import argparse
import json
import os


def fix_vocab(vocab, mask_id=51960, unk_id=3, unk_token="[UNK]", reserved_token="ĠĊ"):
    """Return a copy of *vocab* in which only *unk_token* keeps id *unk_id*.

    Args:
        vocab: Mapping from token string to integer id; it is not modified.
        mask_id: Highest id in use; fresh ids are handed out from
            ``mask_id + 1`` upwards.
        unk_id: Id that only *unk_token* is allowed to keep.
        unk_token: The token that legitimately owns *unk_id*.
        reserved_token: The one colliding token that receives ``mask_id - 1``
            instead of a fresh id.

    Returns:
        A new dict. Tokens that did not collide keep their id and relative
        order (with *reserved_token* staying in place under its new id); the
        remaining colliding tokens are appended at the end, numbered
        consecutively from ``mask_id + 1`` in their original order.
    """
    fixed_vocab = {}
    remapped = {}
    next_unused = mask_id + 1
    for key, value in vocab.items():
        if value != unk_id or key == unk_token:
            fixed_vocab[key] = value
        elif key == reserved_token:
            fixed_vocab[key] = mask_id - 1
        else:
            remapped[key] = next_unused
            next_unused += 1

    # Remapped tokens go last so that the untouched part of the file keeps
    # its original ordering.
    fixed_vocab.update(remapped)
    return fixed_vocab


def main():
    """Fix ``vocab.json`` and regenerate ``tokenizer.json`` in the output directory."""
    parser = argparse.ArgumentParser()
    parser.add_argument("input_path", type=str, help="Input directory")
    parser.add_argument("output_path", type=str, help="Output directory")
    args = parser.parse_args()

    # Fix vocab.json
    with open(os.path.join(args.input_path, "vocab.json"), "r", encoding="utf-8") as vocab_file:
        vocab = json.load(vocab_file)

    fixed_vocab = fix_vocab(vocab)

    with open(os.path.join(args.output_path, "vocab.json"), "w", encoding="utf-8") as vocab_file:
        json.dump(fixed_vocab, vocab_file, ensure_ascii=False, indent=None)
        print(file=vocab_file)  # Terminate the single-line JSON with a newline.

    # Regenerate tokenizer.json
    # Imported here so that `fix_vocab` can be imported without the
    # transformers package being installed.
    import transformers

    # NOTE(review): the tokenizer is loaded from the output directory, so the
    # remaining tokenizer files must already be there; if a stale
    # tokenizer.json is present it may be picked up instead of the fixed
    # vocab.json -- confirm it is removed beforehand.
    tokenizer = transformers.AutoTokenizer.from_pretrained(args.output_path)
    tokenizer.backend_tokenizer.save(os.path.join(args.output_path, "tokenizer.json"))


if __name__ == "__main__":
    main()