pere committed on
Commit 1ab611f
1 Parent(s): 925577d

new attempt

config.json CHANGED
@@ -1,7 +1,6 @@
 {
-  "_name_or_path": ".",
   "architectures": [
-    "T5ForConditionalGeneration"
+    "T5Model"
   ],
   "d_ff": 2048,
   "d_kv": 64,
@@ -20,6 +19,7 @@
   "num_layers": 6,
   "output_past": true,
   "pad_token_id": 0,
+  "relative_attention_max_distance": 128,
   "relative_attention_num_buckets": 32,
   "task_specific_params": {
     "summarization": {
@@ -50,8 +50,7 @@
       "prefix": "translate English to Romanian: "
     }
   },
-  "torch_dtype": "float32",
-  "transformers_version": "4.16.2",
+  "transformers_version": "4.18.0.dev0",
   "use_cache": true,
   "vocab_size": 32128
 }
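Not part of the commit itself, but a minimal sanity check for the updated root config (assuming transformers >= 4.18.0.dev0 and a local clone of this repository):

from transformers import T5Config

# Load the repo-level config.json as it stands after this commit.
config = T5Config.from_pretrained(".")

# The commit switches the declared architecture to T5Model and picks up the
# relative_attention_max_distance field serialized by newer transformers.
print(config.architectures)                    # ['T5Model']
print(config.relative_attention_max_distance)  # 128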
convert_t5_checkpoint_to_flax.py DELETED
@@ -1,144 +0,0 @@
-import argparse
-
-from t5x import checkpoints
-from transformers import T5Config, FlaxT5Model
-
-
-def convert_t5x_checkpoint_to_flax(t5x_checkpoint_path, config_name, flax_dump_folder_path):
-    config = T5Config.from_pretrained(config_name)
-    flax_model = FlaxT5Model(config=config)
-    t5x_model = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path)
-
-
-    #breakpoint()
-    # Encoder
-    for layer_index in range(config.num_layers):
-        layer_name = f"layers_{str(layer_index)}"
-
-        # Self-Attention
-        t5x_attention_key = t5x_model["target"]["encoder"][layer_name]["attention"]["key"]["kernel"]
-        t5x_attention_out = t5x_model["target"]["encoder"][layer_name]["attention"]["out"]["kernel"]
-        t5x_attention_query = t5x_model["target"]["encoder"][layer_name]["attention"]["query"]["kernel"]
-        t5x_attention_value = t5x_model["target"]["encoder"][layer_name]["attention"]["value"]["kernel"]
-
-        ## Layer Normalization
-        t5x_attention_layer_norm = t5x_model["target"]["encoder"][layer_name]["pre_attention_layer_norm"]["scale"]
-
-        # MLP
-        #t5x_mlp_wi_0 = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi_0"]["kernel"]
-        #t5x_mlp_wi_1 = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi_1"]["kernel"]
-        t5x_mlp_wi = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi"]["kernel"]
-        t5x_mlp_wo = t5x_model["target"]["encoder"][layer_name]["mlp"]["wo"]["kernel"]
-
-        ## Layer Normalization
-        t5x_mlp_layer_norm = t5x_model["target"]["encoder"][layer_name]["pre_mlp_layer_norm"]["scale"]
-
-        # Assigning
-        flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["k"]["kernel"] = t5x_attention_key
-        flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["o"]["kernel"] = t5x_attention_out
-        flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["q"]["kernel"] = t5x_attention_query
-        flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["v"]["kernel"] = t5x_attention_value
-
-        flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["layer_norm"]["weight"] = t5x_attention_layer_norm
-
-        #flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi_0"]["kernel"] = t5x_mlp_wi_0
-        #flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi_1"]["kernel"] = t5x_mlp_wi_1
-        flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi"]["kernel"] = t5x_mlp_wi
-        flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wo"]["kernel"] = t5x_mlp_wo
-        flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["layer_norm"]["weight"] = t5x_mlp_layer_norm
-
-    t5x_encoder_norm = t5x_model["target"]["encoder"]["encoder_norm"]["scale"]
-
-    # Only for layer 0:
-    t5x_encoder_rel_embedding = t5x_model["target"]["encoder"]["relpos_bias"]["rel_embedding"]
-    x, y = t5x_encoder_rel_embedding.shape
-
-    # Assigning
-    flax_model.params["encoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"]["embedding"] = t5x_encoder_rel_embedding.reshape(y, x)
-    flax_model.params["encoder"]["final_layer_norm"]["weight"] = t5x_encoder_norm
-
-    # Decoder
-    for layer_index in range(config.num_layers):
-        layer_name = f"layers_{str(layer_index)}"
-
-        # Self-Attention
-        t5x_attention_key = t5x_model["target"]["decoder"][layer_name]["self_attention"]["key"]["kernel"]
-        t5x_attention_out = t5x_model["target"]["decoder"][layer_name]["self_attention"]["out"]["kernel"]
-        t5x_attention_query = t5x_model["target"]["decoder"][layer_name]["self_attention"]["query"]["kernel"]
-        t5x_attention_value = t5x_model["target"]["decoder"][layer_name]["self_attention"]["value"]["kernel"]
-
-        ## Layer Normalization
-        t5x_pre_attention_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_self_attention_layer_norm"]["scale"]
-
-        # Encoder-Decoder-Attention
-        t5x_enc_dec_attention_key = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["key"]["kernel"]
-        t5x_enc_dec_attention_out = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["out"]["kernel"]
-        t5x_enc_dec_attention_query = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["query"]["kernel"]
-        t5x_enc_dec_attention_value = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["value"]["kernel"]
-
-        ## Layer Normalization
-        t5x_cross_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_cross_attention_layer_norm"]["scale"]
-
-        # MLP
-        #t5x_mlp_wi_0 = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi_0"]["kernel"]
-        #t5x_mlp_wi_1 = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi_1"]["kernel"]
-        t5x_mlp_wi = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi"]["kernel"]
-        t5x_mlp_wo = t5x_model["target"]["decoder"][layer_name]["mlp"]["wo"]["kernel"]
-
-        ## Layer Normalization
-        tx5_mlp_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_mlp_layer_norm"]["scale"]
-
-        #Assigning
-        flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["k"]["kernel"] = t5x_attention_key
-        flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["o"]["kernel"] = t5x_attention_out
-        flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["q"]["kernel"] = t5x_attention_query
-        flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["v"]["kernel"] = t5x_attention_value
-
-        flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["layer_norm"]["weight"] = t5x_pre_attention_layer_norm
-
-        flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["k"]["kernel"] = t5x_enc_dec_attention_key
-        flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["o"]["kernel"] = t5x_enc_dec_attention_out
-        flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["q"]["kernel"] = t5x_enc_dec_attention_query
-        flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["v"]["kernel"] = t5x_enc_dec_attention_value
-
-        flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["layer_norm"]["weight"] = t5x_cross_layer_norm
-
-        #flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi_0"]["kernel"] = t5x_mlp_wi_0
-        #flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi_1"]["kernel"] = t5x_mlp_wi_1
-        flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi"]["kernel"] = t5x_mlp_wi
-        flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wo"]["kernel"] = t5x_mlp_wo
-
-        flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["layer_norm"]["weight"] = tx5_mlp_layer_norm
-
-    # Decoder Normalization
-    tx5_decoder_norm = t5x_model["target"]["decoder"]["decoder_norm"]["scale"]
-    flax_model.params["decoder"]["final_layer_norm"]["weight"] = tx5_decoder_norm
-
-    # Only for layer 0:
-    t5x_decoder_rel_embedding = t5x_model["target"]["decoder"]["relpos_bias"]["rel_embedding"]
-    x, y = t5x_decoder_rel_embedding.shape
-
-    flax_model.params["decoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"] = t5x_decoder_rel_embedding.reshape(y, x)
-
-    # Token Embeddings
-    tx5_token_embeddings = t5x_model["target"]["token_embedder"]["embedding"]
-    flax_model.params["shared"]["embedding"] = tx5_token_embeddings
-
-    flax_model.save_pretrained(flax_dump_folder_path)
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    # Required parameters
-    parser.add_argument(
-        "--t5x_checkpoint_path", default=None, type=str, required=True, help="Path the TX5 checkpoint."
-    )
-    parser.add_argument(
-        "--config_name", default=None, type=str, required=True, help="Config name of T5 model."
-    )
-    parser.add_argument(
-        "--flax_dump_folder_path", default=None, type=str, required=True, help="Path to the output FLAX model."
-    )
-    args = parser.parse_args()
-    convert_t5x_checkpoint_to_flax(args.t5x_checkpoint_path, args.config_name, args.flax_dump_folder_path)
-
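For context, the deleted script was driven through its argparse CLI; a purely hypothetical invocation (the t5x checkpoint path below is a placeholder, not a file tracked in this repository) is equivalent to the sketch below:

# Hypothetical driver for the deleted conversion script, shown for reference only.
from convert_t5_checkpoint_to_flax import convert_t5x_checkpoint_to_flax

convert_t5x_checkpoint_to_flax(
    t5x_checkpoint_path="/path/to/t5x/checkpoint_1000000",  # placeholder t5x checkpoint dir
    config_name="./config.json",                            # this repo's config
    flax_dump_folder_path="./",                             # writes flax_model.msgpack here
)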
directly_from_t5x/config.json ADDED
@@ -0,0 +1,27 @@
+{
+  "architectures": [
+    "T5Model"
+  ],
+  "d_ff": 1024,
+  "d_kv": 64,
+  "d_model": 512,
+  "decoder_start_token_id": 0,
+  "dropout_rate": 0.1,
+  "eos_token_id": 1,
+  "feed_forward_proj": "gated-gelu",
+  "initializer_factor": 1.0,
+  "is_encoder_decoder": true,
+  "layer_norm_epsilon": 1e-06,
+  "model_type": "t5",
+  "num_decoder_layers": 8,
+  "num_heads": 6,
+  "num_layers": 8,
+  "pad_token_id": 0,
+  "relative_attention_max_distance": 128,
+  "relative_attention_num_buckets": 32,
+  "tie_word_embeddings": false,
+  "tokenizer_class": "T5Tokenizer",
+  "transformers_version": "4.18.0.dev0",
+  "use_cache": true,
+  "vocab_size": 250112
+}
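As a hedged sketch (not part of the commit), the t5x-derived config added above can be instantiated directly; a FlaxT5Model built from a config alone gets random weights, which is enough to confirm the architecture it describes:

from transformers import T5Config, FlaxT5Model

# Build the config exported directly from t5x (path relative to a local clone).
config = T5Config.from_json_file("directly_from_t5x/config.json")

# Random initialisation; useful only for checking shapes: d_model=512,
# 8 encoder and decoder layers, gated-gelu MLPs, 250112-token vocabulary.
model = FlaxT5Model(config=config)
print(model.params["shared"]["embedding"].shape)  # (250112, 512)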
directly_from_t5x/flax_model.msgpack ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0598b1d8dba89fd539c9a1776cf0b4f7b3e45c4ac8ec2f498f99c7184c57baa0
+size 688485886
flax_model.msgpack CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:1706cb09571e921f3f28634913e628ce7808dc26cca330bb3e319e27db23c9d1
-size 242032191
+oid sha256:8e3f34ed523c4ea968bd610811cf91e9f68553eceebb03c7cbe8eae03be023f9
+size 242032202
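The git-lfs pointers above only record new hashes and sizes; the actual weights load the usual way. A minimal sketch, assuming a local clone with the LFS files pulled:

import jax
from transformers import FlaxT5Model

# Load the converted Flax weights from the repository root; config.json now
# declares T5Model, so FlaxT5Model is the matching class.
model = FlaxT5Model.from_pretrained(".")

# Rough size check: float32 parameters at 4 bytes each should roughly match
# the ~242 MB flax_model.msgpack recorded above.
n_params = sum(p.size for p in jax.tree_util.tree_leaves(model.params))
print(f"{n_params:,} parameters")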