stas commited on
Commit
2bd01f7
1 Parent(s): 37a0fb2

tiny mt5 random model

Browse files
config.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "MT5ForConditionalGeneration"
4
+ ],
5
+ "d_ff": 256,
6
+ "d_kv": 8,
7
+ "d_model": 64,
8
+ "decoder_start_token_id": 0,
9
+ "dropout_rate": 0.1,
10
+ "eos_token_id": 1,
11
+ "feed_forward_proj": "gated-gelu",
12
+ "initializer_factor": 1.0,
13
+ "is_encoder_decoder": true,
14
+ "layer_norm_epsilon": 1e-06,
15
+ "model_type": "mt5",
16
+ "num_decoder_layers": 8,
17
+ "num_heads": 4,
18
+ "num_layers": 8,
19
+ "pad_token_id": 0,
20
+ "relative_attention_num_buckets": 32,
21
+ "tie_word_embeddings": false,
22
+ "tokenizer_class": "T5Tokenizer",
23
+ "transformers_version": "4.6.0.dev0",
24
+ "use_cache": true,
25
+ "vocab_size": 5100
26
+ }
mt5-make-tiny-model.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ # This script creates a smallish random model, with a few layers to test things like MP/PP, where
18
+ # tiny and tiner models are too too small
19
+ #
20
+ # It will be used then as "stas/mt5-tiny-random"
21
+
22
+ # To build:
23
+ # 1. clone sentencepiece into this dir
24
+ # git clone https://github.com/google/sentencepiece
25
+ #
26
+ # 2. run this script
27
+
28
+ from pathlib import Path
29
+ import json
30
+ import tempfile
31
+
32
+ from transformers import MT5Tokenizer, MT5TokenizerFast, MT5Config, MT5ForConditionalGeneration
33
+ from transformers.models.t5.tokenization_t5 import VOCAB_FILES_NAMES
34
+
35
+ mname_from = "google/mt5-small"
36
+ mname_very_small = "mt5-tiny-random"
37
+
38
+ tokenizer = MT5Tokenizer.from_pretrained(mname_from)
39
+ config = MT5Config.from_pretrained(mname_from)
40
+ #tokenizer_fast = MT5TokenizerFast.from_pretrained(mname_from)
41
+
42
+ # Shrink the vocab of mt5-small
43
+ import sys
44
+ # HACK: need the sentencepiece source to get sentencepiece_model_pb2, as it doesn't get installed
45
+ sys.path.append("./sentencepiece/python/src/sentencepiece")
46
+ import sentencepiece_model_pb2 as model
47
+
48
+ tmp_dir = "/tmp/mt5-small"
49
+ tokenizer.save_pretrained(tmp_dir)
50
+ file = tmp_dir + "/spiece.model"
51
+ with open(file, 'rb') as f: data = f.read()
52
+
53
+ # adapted from https://blog.ceshine.net/post/trim-down-sentencepiece-vocabulary/
54
+ m = model.ModelProto()
55
+ m.ParseFromString(data)
56
+
57
+ keep_items = 5000
58
+
59
+ print("Shrinking vocab")
60
+ print(f"original dict {len(m.pieces)}")
61
+ for i in range(len(m.pieces)-keep_items): _ = m.pieces.pop()
62
+ print(f"new dict {len(m.pieces)}")
63
+
64
+ with open(tmp_dir + "/spiece-short.model", 'wb') as f:
65
+ f.write(m.SerializeToString())
66
+
67
+ tokenizer = MT5Tokenizer(vocab_file=tmp_dir + "/spiece-short.model")
68
+
69
+ config.update(dict(
70
+ vocab_size=keep_items+12,
71
+ d_model=64,
72
+ d_ff=256,
73
+ d_kv=8,
74
+ num_layers=8,
75
+ num_decoder_layers=8,
76
+ num_heads=4,
77
+ relative_attention_num_buckets=32,
78
+ ))
79
+ print("new config", config)
80
+
81
+ very_small_model = MT5ForConditionalGeneration(config)
82
+ print(f"num of params {very_small_model.num_parameters()}")
83
+ very_small_model.resize_token_embeddings(len(tokenizer))
84
+
85
+ # Test
86
+ src_texts = ["A long paragraph for summarization.", "Another paragraph for summarization."]
87
+ tgt_texts = ["Summary of the text.", "Another summary."]
88
+
89
+ batch = tokenizer.prepare_seq2seq_batch(src_texts, tgt_texts, return_tensors="pt")
90
+ outputs = very_small_model(**batch)
91
+
92
+ print("test output:", len(outputs.logits[0]))
93
+
94
+ # Save
95
+ very_small_model.half() # makes it smaller
96
+ very_small_model.save_pretrained(mname_very_small)
97
+ config.save_pretrained(mname_very_small)
98
+ tokenizer.save_pretrained(mname_very_small)
99
+ #tokenizer_fast.save_pretrained(mname_very_small)
100
+
101
+ print(f"Generated {mname_very_small}")
102
+
103
+ # Upload
104
+ # transformers-cli repo create mt5-tiny-random
105
+ # clone and add files
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e47ef1ea6344ab7c8e41bb0001ff32ded49a2c17c6539994e242e30acb15fd85
3
+ size 3342941
special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"eos_token": "</s>", "unk_token": "<unk>", "pad_token": "<pad>", "additional_special_tokens": ["<extra_id_0>", "<extra_id_1>", "<extra_id_2>", "<extra_id_3>", "<extra_id_4>", "<extra_id_5>", "<extra_id_6>", "<extra_id_7>", "<extra_id_8>", "<extra_id_9>", "<extra_id_10>", "<extra_id_11>", "<extra_id_12>", "<extra_id_13>", "<extra_id_14>", "<extra_id_15>", "<extra_id_16>", "<extra_id_17>", "<extra_id_18>", "<extra_id_19>", "<extra_id_20>", "<extra_id_21>", "<extra_id_22>", "<extra_id_23>", "<extra_id_24>", "<extra_id_25>", "<extra_id_26>", "<extra_id_27>", "<extra_id_28>", "<extra_id_29>", "<extra_id_30>", "<extra_id_31>", "<extra_id_32>", "<extra_id_33>", "<extra_id_34>", "<extra_id_35>", "<extra_id_36>", "<extra_id_37>", "<extra_id_38>", "<extra_id_39>", "<extra_id_40>", "<extra_id_41>", "<extra_id_42>", "<extra_id_43>", "<extra_id_44>", "<extra_id_45>", "<extra_id_46>", "<extra_id_47>", "<extra_id_48>", "<extra_id_49>", "<extra_id_50>", "<extra_id_51>", "<extra_id_52>", "<extra_id_53>", "<extra_id_54>", "<extra_id_55>", "<extra_id_56>", "<extra_id_57>", "<extra_id_58>", "<extra_id_59>", "<extra_id_60>", "<extra_id_61>", "<extra_id_62>", "<extra_id_63>", "<extra_id_64>", "<extra_id_65>", "<extra_id_66>", "<extra_id_67>", "<extra_id_68>", "<extra_id_69>", "<extra_id_70>", "<extra_id_71>", "<extra_id_72>", "<extra_id_73>", "<extra_id_74>", "<extra_id_75>", "<extra_id_76>", "<extra_id_77>", "<extra_id_78>", "<extra_id_79>", "<extra_id_80>", "<extra_id_81>", "<extra_id_82>", "<extra_id_83>", "<extra_id_84>", "<extra_id_85>", "<extra_id_86>", "<extra_id_87>", "<extra_id_88>", "<extra_id_89>", "<extra_id_90>", "<extra_id_91>", "<extra_id_92>", "<extra_id_93>", "<extra_id_94>", "<extra_id_95>", "<extra_id_96>", "<extra_id_97>", "<extra_id_98>", "<extra_id_99>"]}
spiece.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6859b5aa593827a9593cc7c313eb5bc86444a971387dae19ae4a3d2ba389bbae
3
+ size 312867
tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"eos_token": "</s>", "unk_token": "<unk>", "pad_token": "<pad>", "extra_ids": 100, "additional_special_tokens": ["<extra_id_0>", "<extra_id_1>", "<extra_id_2>", "<extra_id_3>", "<extra_id_4>", "<extra_id_5>", "<extra_id_6>", "<extra_id_7>", "<extra_id_8>", "<extra_id_9>", "<extra_id_10>", "<extra_id_11>", "<extra_id_12>", "<extra_id_13>", "<extra_id_14>", "<extra_id_15>", "<extra_id_16>", "<extra_id_17>", "<extra_id_18>", "<extra_id_19>", "<extra_id_20>", "<extra_id_21>", "<extra_id_22>", "<extra_id_23>", "<extra_id_24>", "<extra_id_25>", "<extra_id_26>", "<extra_id_27>", "<extra_id_28>", "<extra_id_29>", "<extra_id_30>", "<extra_id_31>", "<extra_id_32>", "<extra_id_33>", "<extra_id_34>", "<extra_id_35>", "<extra_id_36>", "<extra_id_37>", "<extra_id_38>", "<extra_id_39>", "<extra_id_40>", "<extra_id_41>", "<extra_id_42>", "<extra_id_43>", "<extra_id_44>", "<extra_id_45>", "<extra_id_46>", "<extra_id_47>", "<extra_id_48>", "<extra_id_49>", "<extra_id_50>", "<extra_id_51>", "<extra_id_52>", "<extra_id_53>", "<extra_id_54>", "<extra_id_55>", "<extra_id_56>", "<extra_id_57>", "<extra_id_58>", "<extra_id_59>", "<extra_id_60>", "<extra_id_61>", "<extra_id_62>", "<extra_id_63>", "<extra_id_64>", "<extra_id_65>", "<extra_id_66>", "<extra_id_67>", "<extra_id_68>", "<extra_id_69>", "<extra_id_70>", "<extra_id_71>", "<extra_id_72>", "<extra_id_73>", "<extra_id_74>", "<extra_id_75>", "<extra_id_76>", "<extra_id_77>", "<extra_id_78>", "<extra_id_79>", "<extra_id_80>", "<extra_id_81>", "<extra_id_82>", "<extra_id_83>", "<extra_id_84>", "<extra_id_85>", "<extra_id_86>", "<extra_id_87>", "<extra_id_88>", "<extra_id_89>", "<extra_id_90>", "<extra_id_91>", "<extra_id_92>", "<extra_id_93>", "<extra_id_94>", "<extra_id_95>", "<extra_id_96>", "<extra_id_97>", "<extra_id_98>", "<extra_id_99>"]}