pere committed on
Commit
b025910
1 Parent(s): 740aab3
Files changed (2)
  1. README.md +32 -0
  2. convert_t5_checkpoint_to_flax.py +144 -0
README.md ADDED
@@ -0,0 +1,32 @@
+ ---
+ language:
+ - en
+ - fr
+ - ro
+ - de
+ datasets:
+ - c4
+ tags:
+ - summarization
+ - translation
+
+ license: apache-2.0
+ ---
+
+ ## Test T5 small conversion
+ This is a test repo for the conversion of T5X checkpoints to HuggingFace Flax.
+
+ The model is first converted from the original Mesh TensorFlow (MTF) checkpoint to T5X using the conversion script included in the T5X library:
+
+ ```bash
+ python3 -m t5x.scripts.convert_tf_checkpoint \
+   --gin_file=t5x/examples/t5/t5_1_0/small.gin \
+   --gin.convert_checkpoint.model=%MODEL \
+   --gin.convert_checkpoint.tf_checkpoint_path=\"gs://t5-data/pretrained_models/small/model.ckpt-1000000\" \
+   --gin.convert_checkpoint.output_dir=\"/tmp/t5x_checkpoints/t5_small\" \
+   --logtostderr
+ ```
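+
+ To see what the resulting T5X checkpoint contains, its parameter tree can be inspected with the same `load_t5x_checkpoint` helper that the conversion script below relies on. A minimal sketch (the `walk` helper is just for illustration):
+
+ ```python
+ from t5x import checkpoints
+
+ # Load the checkpoint produced by the command above as a nested dict of arrays.
+ t5x_model = checkpoints.load_t5x_checkpoint("/tmp/t5x_checkpoints/t5_small/checkpoint_1000000/")
+
+ def walk(tree, prefix=""):
+     # Recursively print parameter paths and shapes,
+     # e.g. target/encoder/layers_0/attention/key/kernel.
+     for name, value in tree.items():
+         if hasattr(value, "shape"):
+             print(f"{prefix}{name}: {value.shape}")
+         else:
+             walk(value, f"{prefix}{name}/")
+
+ walk(t5x_model)
+ ```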
+
+ After creating the T5X model, it is converted to HuggingFace Flax by a modified version of the script from @stefan-it (https://gist.githubusercontent.com/stefan-it/30e4998ef159f33696e377a46f699d9f/raw/c19da5d067dc9d31d0b8115a79e8626186e11daa/convert_t5x_checkpoint_to_flax.py). The modified version is included in this repo. The main modification is that the two MLP input projections `wi_0` and `wi_1` are replaced by a single `wi` projection; this might be a difference between t5_1_0 and t5_1_1 (see the sketch below).
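+
+ A sketch of how the extraction could support both layouts, assuming the nested-dict structure returned by `load_t5x_checkpoint` (with `layer_name` as in the script below):
+
+ ```python
+ mlp = t5x_model["target"]["encoder"][layer_name]["mlp"]
+
+ if "wi" in mlp:
+     # t5_1_0: a single input projection feeding the MLP.
+     t5x_mlp_wi = mlp["wi"]["kernel"]
+ else:
+     # t5_1_1: gated activation with two input projections.
+     t5x_mlp_wi_0 = mlp["wi_0"]["kernel"]
+     t5x_mlp_wi_1 = mlp["wi_1"]["kernel"]
+ ```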
+
+ ```bash
+ python3 convert_t5_checkpoint_to_flax.py \
+   --t5x_checkpoint_path /tmp/t5x_checkpoints/t5_small/checkpoint_1000000/ \
+   --flax_dump_folder_path /tmp/flax_dump_folder/ \
+   --config_name t5-small
+ ```
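+
+ A quick sanity check is to reload the dumped model and compare one tensor against the T5X checkpoint. A minimal sketch, assuming the paths used above (the token embeddings are copied over unchanged, so they should match):
+
+ ```python
+ import numpy as np
+ from t5x import checkpoints
+ from transformers import FlaxT5Model
+
+ flax_model = FlaxT5Model.from_pretrained("/tmp/flax_dump_folder/")
+ t5x_model = checkpoints.load_t5x_checkpoint("/tmp/t5x_checkpoints/t5_small/checkpoint_1000000/")
+
+ # The shared token embeddings should be identical after conversion.
+ assert np.allclose(
+     np.asarray(flax_model.params["shared"]["embedding"]),
+     t5x_model["target"]["token_embedder"]["embedding"],
+ )
+ ```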
convert_t5_checkpoint_to_flax.py ADDED
@@ -0,0 +1,144 @@
+ import argparse
+
+ from t5x import checkpoints
+ from transformers import T5Config, FlaxT5Model
+
+
+ def convert_t5x_checkpoint_to_flax(t5x_checkpoint_path, config_name, flax_dump_folder_path):
+     # Build an untrained HF Flax T5 model from the named config, then load the
+     # T5X checkpoint as a nested dict of arrays and copy the weights over.
+     config = T5Config.from_pretrained(config_name)
+     flax_model = FlaxT5Model(config=config)
+     t5x_model = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path)
+
+     # Encoder
+     for layer_index in range(config.num_layers):
+         layer_name = f"layers_{str(layer_index)}"
+
+         # Self-Attention
+         t5x_attention_key = t5x_model["target"]["encoder"][layer_name]["attention"]["key"]["kernel"]
+         t5x_attention_out = t5x_model["target"]["encoder"][layer_name]["attention"]["out"]["kernel"]
+         t5x_attention_query = t5x_model["target"]["encoder"][layer_name]["attention"]["query"]["kernel"]
+         t5x_attention_value = t5x_model["target"]["encoder"][layer_name]["attention"]["value"]["kernel"]
+
+         ## Layer Normalization
+         t5x_attention_layer_norm = t5x_model["target"]["encoder"][layer_name]["pre_attention_layer_norm"]["scale"]
+
+         # MLP
+         # t5_1_1 checkpoints have two gated input projections instead of a single wi:
+         # t5x_mlp_wi_0 = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi_0"]["kernel"]
+         # t5x_mlp_wi_1 = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi_1"]["kernel"]
+         t5x_mlp_wi = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi"]["kernel"]
+         t5x_mlp_wo = t5x_model["target"]["encoder"][layer_name]["mlp"]["wo"]["kernel"]
+
+         ## Layer Normalization
+         t5x_mlp_layer_norm = t5x_model["target"]["encoder"][layer_name]["pre_mlp_layer_norm"]["scale"]
+
+         # Assigning
+         flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["k"]["kernel"] = t5x_attention_key
+         flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["o"]["kernel"] = t5x_attention_out
+         flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["q"]["kernel"] = t5x_attention_query
+         flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["v"]["kernel"] = t5x_attention_value
+
+         flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["layer_norm"]["weight"] = t5x_attention_layer_norm
+
+         # flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi_0"]["kernel"] = t5x_mlp_wi_0
+         # flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi_1"]["kernel"] = t5x_mlp_wi_1
+         flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi"]["kernel"] = t5x_mlp_wi
+         flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wo"]["kernel"] = t5x_mlp_wo
+         flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["layer_norm"]["weight"] = t5x_mlp_layer_norm
+
+     t5x_encoder_norm = t5x_model["target"]["encoder"]["encoder_norm"]["scale"]
+
+     # Only for layer 0: the relative position bias is shared across all layers.
+     # T5X stores it as (num_heads, num_buckets) while Flax expects (num_buckets, num_heads),
+     # so it is transposed (a reshape would scramble the values).
+     t5x_encoder_rel_embedding = t5x_model["target"]["encoder"]["relpos_bias"]["rel_embedding"]
+
+     # Assigning
+     flax_model.params["encoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"]["embedding"] = t5x_encoder_rel_embedding.T
+     flax_model.params["encoder"]["final_layer_norm"]["weight"] = t5x_encoder_norm
+     # Decoder
+     for layer_index in range(config.num_layers):
+         layer_name = f"layers_{str(layer_index)}"
+
+         # Self-Attention
+         t5x_attention_key = t5x_model["target"]["decoder"][layer_name]["self_attention"]["key"]["kernel"]
+         t5x_attention_out = t5x_model["target"]["decoder"][layer_name]["self_attention"]["out"]["kernel"]
+         t5x_attention_query = t5x_model["target"]["decoder"][layer_name]["self_attention"]["query"]["kernel"]
+         t5x_attention_value = t5x_model["target"]["decoder"][layer_name]["self_attention"]["value"]["kernel"]
+
+         ## Layer Normalization
+         t5x_pre_attention_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_self_attention_layer_norm"]["scale"]
+
+         # Encoder-Decoder-Attention
+         t5x_enc_dec_attention_key = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["key"]["kernel"]
+         t5x_enc_dec_attention_out = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["out"]["kernel"]
+         t5x_enc_dec_attention_query = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["query"]["kernel"]
+         t5x_enc_dec_attention_value = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["value"]["kernel"]
+
+         ## Layer Normalization
+         t5x_cross_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_cross_attention_layer_norm"]["scale"]
+
+         # MLP
+         # t5_1_1 checkpoints have two gated input projections instead of a single wi:
+         # t5x_mlp_wi_0 = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi_0"]["kernel"]
+         # t5x_mlp_wi_1 = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi_1"]["kernel"]
+         t5x_mlp_wi = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi"]["kernel"]
+         t5x_mlp_wo = t5x_model["target"]["decoder"][layer_name]["mlp"]["wo"]["kernel"]
+
+         ## Layer Normalization
+         t5x_mlp_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_mlp_layer_norm"]["scale"]
+
+         # Assigning
+         flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["k"]["kernel"] = t5x_attention_key
+         flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["o"]["kernel"] = t5x_attention_out
+         flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["q"]["kernel"] = t5x_attention_query
+         flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["v"]["kernel"] = t5x_attention_value
+
+         flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["layer_norm"]["weight"] = t5x_pre_attention_layer_norm
+
+         flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["k"]["kernel"] = t5x_enc_dec_attention_key
+         flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["o"]["kernel"] = t5x_enc_dec_attention_out
+         flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["q"]["kernel"] = t5x_enc_dec_attention_query
+         flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["v"]["kernel"] = t5x_enc_dec_attention_value
+
+         flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["layer_norm"]["weight"] = t5x_cross_layer_norm
+
+         # flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi_0"]["kernel"] = t5x_mlp_wi_0
+         # flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi_1"]["kernel"] = t5x_mlp_wi_1
+         flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi"]["kernel"] = t5x_mlp_wi
+         flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wo"]["kernel"] = t5x_mlp_wo
+
+         flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["layer_norm"]["weight"] = t5x_mlp_layer_norm
+
+     # Decoder Normalization
+     t5x_decoder_norm = t5x_model["target"]["decoder"]["decoder_norm"]["scale"]
+     flax_model.params["decoder"]["final_layer_norm"]["weight"] = t5x_decoder_norm
+
+     # Only for layer 0: shared relative position bias, transposed as for the encoder.
+     t5x_decoder_rel_embedding = t5x_model["target"]["decoder"]["relpos_bias"]["rel_embedding"]
+     flax_model.params["decoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"]["embedding"] = t5x_decoder_rel_embedding.T
+
+     # Token Embeddings (shared between encoder and decoder)
+     t5x_token_embeddings = t5x_model["target"]["token_embedder"]["embedding"]
+     flax_model.params["shared"]["embedding"] = t5x_token_embeddings
+
+     flax_model.save_pretrained(flax_dump_folder_path)
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     # Required parameters
+     parser.add_argument(
+         "--t5x_checkpoint_path", default=None, type=str, required=True, help="Path to the T5X checkpoint."
+     )
+     parser.add_argument(
+         "--config_name", default=None, type=str, required=True, help="Config name of the T5 model."
+     )
+     parser.add_argument(
+         "--flax_dump_folder_path", default=None, type=str, required=True, help="Path to the output Flax model."
+     )
+     args = parser.parse_args()
+     convert_t5x_checkpoint_to_flax(args.t5x_checkpoint_path, args.config_name, args.flax_dump_folder_path)