Files changed (3) hide show
  1. Yi_logo.svg +7 -0
  2. convert_llama_megatron_hf.py +382 -0
  3. m-a-p.png +0 -0
Yi_logo.svg ADDED
convert_llama_megatron_hf.py ADDED
@@ -0,0 +1,382 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ from collections import OrderedDict
4
+
5
+ import torch
6
+ from transformers import LlamaConfig, LlamaForCausalLM
7
+ from transformers.models.llama.modeling_llama import LlamaDecoderLayer
8
+ import accelerate
9
+
10
# For each logical weight group of a decoder layer, the candidate parameter
# names (relative to the "layers.<i>." prefix) under which that weight may be
# stored in the Megatron encoder state dict. check_get / check_assign try
# every candidate and use whichever is present.
# NOTE(review): the second spelling in most groups looks like a fused-module
# naming scheme -- confirm which training code produces it.
transformer_layer_name_list = dict(
    input_norm=[
        "input_norm.weight",
        "self_attention.norm_qkv.layer_norm_weight",
    ],
    query_key_value=[
        "self_attention.query_key_value.weight",
        "self_attention.norm_qkv.weight",
    ],
    query=["self_attention.query.weight"],
    key_value=["self_attention.key_value.weight"],
    o_proj=["self_attention.dense.weight", "self_attention.proj.weight"],
    mlp_gate_up=["mlp.dense_h_to_4h.weight", "norm_mlp.fc1_weight"],
    mlp_down=["mlp.dense_4h_to_h.weight", "norm_mlp.fc2_weight"],
    post_attention_norm=[
        "post_attention_norm.weight",
        "norm_mlp.layer_norm_weight",
    ],
)
29
+
30
+
31
def recursive_print(name, val, spaces=0):
    """Pretty-print a nested dict (e.g. a checkpoint state dict) as a tree.

    Args:
        name: Label of the current node. ``None`` suppresses the header line
            when ``val`` is a dict (used for the root call).
        val: The node itself. Dicts are recursed into, tensors are shown by
            their size, anything else is printed as-is.
        spaces: Current indentation in columns; grows by 2 per nesting level.
    """
    if name is None:
        msg = None
    else:
        # Pad the label so the values line up in one column. The width is
        # clamped at zero: building a "{:<N>s}" format spec from a negative
        # width raises ValueError once nesting exceeds 50 columns. str() lets
        # non-string keys (e.g. integer ids) be printed too.
        label = str(name).ljust(max(0, 50 - spaces))
        msg = "." * max(0, spaces - 2) + "# " + label

    if isinstance(val, dict):
        if msg is not None:
            print(msg)
        for child_name, child in val.items():
            recursive_print(child_name, child, spaces + 2)
    elif isinstance(val, torch.Tensor):
        print(msg, ":", val.size())
    else:
        print(msg, ":", val)
49
+
50
+
51
def get(dicts, key):
    """Return ``d[key]`` for every mapping ``d`` in ``dicts``, in order.

    Raises:
        KeyError: if any mapping lacks ``key``.
    """
    # The loop variable is deliberately not called ``dict`` so the builtin
    # stays usable inside the comprehension.
    return [mapping[key] for mapping in dicts]
53
+
54
+
55
def check_get(dicts, prefix, key_list):
    """Collect the values stored under any of several candidate keys.

    For each mapping in ``dicts`` (outer order) and each candidate in
    ``key_list`` (inner order), the value at ``prefix + candidate`` is
    appended to the result if that key exists. Missing keys are skipped
    silently, so the result may be shorter than ``dicts`` or empty.

    Args:
        dicts: Sequence of mappings to search.
        prefix: String prepended to every candidate key.
        key_list: Candidate key suffixes.

    Returns:
        List of the values that were found.
    """
    found = []
    for mapping in dicts:
        for key in key_list:
            full_key = prefix + key
            if full_key in mapping:
                found.append(mapping[full_key])
    return found
59
+
60
+
61
def check_assign(encoder, this_layer_index, this_encoder, layer_index, key_list):
    """Copy one layer weight from ``this_encoder`` into ``encoder``.

    The first name in ``key_list`` that exists in ``this_encoder`` under
    ``layers.<layer_index>.`` is stored in ``encoder`` under
    ``layers.<this_layer_index>.`` with the same name. If none of the
    candidates exists, ``encoder`` is left untouched.

    Returns:
        ``encoder`` (the same object, modified in place).
    """
    source_prefix = f"layers.{layer_index}."
    target_prefix = f"layers.{this_layer_index}."

    # Only the first matching candidate is copied.
    hit = next(
        (name for name in key_list if source_prefix + name in this_encoder),
        None,
    )
    if hit is not None:
        encoder[target_prefix + hit] = this_encoder[source_prefix + hit]
    return encoder
68
+
69
+
70
def merge_col(tensors):
    """Concatenate weight shards along dim 0 (the row / output dimension).

    Args:
        tensors: Sequence whose items are either tensors or state-dict style
            mappings holding the tensor under the ``"weight"`` key.

    Returns:
        A single tensor made by ``torch.cat(..., dim=0)``.
    """
    # isinstance (rather than an exact type check against OrderedDict) also
    # accepts plain dicts and OrderedDict subclasses.
    return torch.cat(
        [
            shard["weight"] if isinstance(shard, dict) else shard
            for shard in tensors
        ],
        dim=0,
    )
78
+
79
+
80
def merge_row(tensors):
    """Concatenate weight shards along dim 1 (the column / input dimension).

    Args:
        tensors: Sequence whose items are either tensors or state-dict style
            mappings holding the tensor under the ``"weight"`` key.

    Returns:
        A single tensor made by ``torch.cat(..., dim=1)``.
    """
    # isinstance (rather than an exact type check against OrderedDict) also
    # accepts plain dicts and OrderedDict subclasses.
    return torch.cat(
        [
            shard["weight"] if isinstance(shard, dict) else shard
            for shard in tensors
        ],
        dim=1,
    )
88
+
89
+
90
def convert_megatron_checkpoint(hf_model, state_dicts, model_config: LlamaConfig):
    """Copy the weights of a Megatron checkpoint into ``hf_model`` in place.

    Args:
        hf_model: Model to fill. It must expose ``model.embed_tokens``,
            ``model.layers``, ``model.norm`` and ``lm_head``.
        state_dicts: Loaded checkpoint shards as returned by
            ``load_state_dicts``. Each must contain
            ``["model"]["language_model"]`` with ``"embedding"``,
            ``"encoder"`` and ``"output_layer"`` entries. The order of the
            shards determines the concatenation order of the split weights.
        model_config: Supplies the target sizes (vocab, hidden, heads, ...).

    Raises:
        IndexError: if a norm weight is not found under any known name.
        RuntimeError: from ``view`` / ``load_state_dict`` on shape mismatch.
    """
    names = transformer_layer_name_list
    lms = get(get(state_dicts, "model"), "language_model")

    # Word embeddings: merge the shards along the vocab dimension, then drop
    # the padding rows beyond vocab_size.
    word_embeddings = get(get(lms, "embedding"), "word_embeddings")
    merged_padded_word_embeddings = merge_col(word_embeddings)
    hf_model.model.embed_tokens.load_state_dict(
        {"weight": merged_padded_word_embeddings[: model_config.vocab_size, :]},
        strict=True,
    )

    # The transformer stack (one encoder state dict per shard).
    transformers = get(lms, "encoder")

    # Sizes are the same for every layer, so compute them once.
    hidden_size = model_config.hidden_size
    inter_size = model_config.intermediate_size
    num_heads = model_config.num_attention_heads
    kv_heads = model_config.num_key_value_heads
    head_dim = hidden_size // num_heads
    kv_hidden_size = head_dim * kv_heads
    num_queries_per_key_value = num_heads // kv_heads

    for i in range(model_config.num_hidden_layers):
        print("Converting layer", i)
        prefix = f"layers.{i}."
        layer: LlamaDecoderLayer = hf_model.model.layers[i]

        # Norm weights are taken from the first shard that has them.
        layer.input_layernorm.load_state_dict(
            {"weight": check_get(transformers, prefix, names["input_norm"])[0]},
            strict=True,
        )

        # Fused QKV. The view below treats the merged rows as kv_heads
        # groups, each holding that group's query heads followed by one key
        # head and one value head. Plain multi-head attention is the case
        # num_queries_per_key_value == 1, so one code path covers both.
        qkv = merge_col(check_get(transformers, prefix, names["query_key_value"]))
        qkv = qkv.view(
            kv_heads,
            num_queries_per_key_value + 2,
            head_dim,
            hidden_size,
        )
        q, k, v = torch.split(qkv, [num_queries_per_key_value, 1, 1], dim=1)
        q = q.reshape(hidden_size, hidden_size)
        k = k.reshape(kv_hidden_size, hidden_size)
        v = v.reshape(kv_hidden_size, hidden_size)

        layer.self_attn.q_proj.load_state_dict({"weight": q}, strict=True)
        layer.self_attn.k_proj.load_state_dict({"weight": k}, strict=True)
        layer.self_attn.v_proj.load_state_dict({"weight": v}, strict=True)

        # The attention output projection is merged along dim 1.
        layer.self_attn.o_proj.load_state_dict(
            {"weight": merge_row(check_get(transformers, prefix, names["o_proj"]))},
            strict=True,
        )

        # Fused gate/up projection: the view treats each shard as a gate
        # half followed by an up half, so split per shard before flattening.
        gate_up = merge_col(check_get(transformers, prefix, names["mlp_gate_up"]))
        gate, up = gate_up.view(len(state_dicts), 2, -1, hidden_size).chunk(2, dim=1)
        gate = gate.reshape(inter_size, hidden_size)
        up = up.reshape(inter_size, hidden_size)
        layer.mlp.gate_proj.load_state_dict({"weight": gate}, strict=True)
        layer.mlp.up_proj.load_state_dict({"weight": up}, strict=True)

        # The MLP down projection is merged along dim 1.
        layer.mlp.down_proj.load_state_dict(
            {"weight": merge_row(check_get(transformers, prefix, names["mlp_down"]))},
            strict=True,
        )

        layer.post_attention_layernorm.load_state_dict(
            {
                "weight": check_get(
                    transformers, prefix, names["post_attention_norm"]
                )[0]
            },
            strict=True,
        )

    # The final norm, read from the first shard.
    hf_model.model.norm.load_state_dict(
        {"weight": transformers[0]["final_norm.weight"]}, strict=True
    )

    # LM head: merge the output-layer shards along the vocab dimension and
    # drop the padding rows, exactly like the input embeddings.
    merged_padded_output_layers = merge_col(get(lms, "output_layer"))
    hf_model.lm_head.load_state_dict(
        {"weight": merged_padded_output_layers[: model_config.vocab_size, :]},
        strict=True,
    )
229
+
230
+
231
def check_padded_vocab_size(train_args, orig_vocab_size):
    """Verify that the checkpoint's padded vocab size matches the tokenizer.

    The expected padded size is ``orig_vocab_size`` rounded up to the next
    multiple of ``make_vocab_size_divisible_by * tensor_model_parallel_size``.

    Args:
        train_args: Training arguments stored in the checkpoint; must provide
            ``make_vocab_size_divisible_by``, ``tensor_model_parallel_size``
            and ``padded_vocab_size``.
        orig_vocab_size: Unpadded tokenizer vocab size.

    Raises:
        ValueError: if ``train_args.padded_vocab_size`` differs from the
            expected value. An explicit raise is used because ``assert`` is
            stripped under ``python -O``.
    """
    multiple = (
        train_args.make_vocab_size_divisible_by * train_args.tensor_model_parallel_size
    )
    # Ceiling division, written with floor division on the negated value.
    expected = -(-orig_vocab_size // multiple) * multiple
    if train_args.padded_vocab_size != expected:
        raise ValueError(
            "Mismatched vocab size and padded vocab size: "
            f"vocab size {orig_vocab_size} pads to {expected}, but the "
            f"checkpoint was trained with {train_args.padded_vocab_size}."
        )
244
+
245
+
246
def get_train_args(state_dict):
    """Return the training arguments stored under ``"args"`` in a checkpoint.

    Raises:
        ValueError: if the entry is missing or ``None``. An explicit raise is
            used because ``assert`` is stripped under ``python -O``.
    """
    args = state_dict.get("args")
    if args is None:
        raise ValueError('Checkpoint state dict has no "args" entry.')
    return args
250
+
251
+
252
def get_model_config(train_args, vocab_size):
    """Build a ``LlamaConfig`` from the Megatron training arguments.

    Args:
        train_args: Training arguments stored in the checkpoint.
        vocab_size: Unpadded tokenizer vocab size. It is checked against the
            checkpoint's padded vocab size before being used.

    Returns:
        A ``LlamaConfig`` with the model dimensions copied from
        ``train_args``.
    """
    config = LlamaConfig()
    check_padded_vocab_size(train_args, vocab_size)

    # The unpadded size is used; the conversion slices embedding and output
    # matrices down to this many rows.
    config.vocab_size = vocab_size

    # (config field, train_args field) pairs copied over unchanged.
    copied_fields = (
        ("max_position_embeddings", "max_position_embeddings"),
        ("hidden_size", "hidden_size"),
        ("num_hidden_layers", "num_layers"),
        ("num_attention_heads", "num_attention_heads"),
        ("num_key_value_heads", "num_query_groups"),
        ("intermediate_size", "ffn_hidden_size"),
    )
    for config_field, args_field in copied_fields:
        setattr(config, config_field, getattr(train_args, args_field))

    # rope_base is optional in the training arguments.
    if hasattr(train_args, "rope_base"):
        config.rope_theta = train_args.rope_base

    config.pad_token_id = 0
    config.torch_dtype = train_args.params_dtype
    return config
268
+
269
+
270
def load_state_dicts(input_dir):
    """Load all shards of a Megatron checkpoint from ``input_dir``.

    Without pipeline parallelism, one state dict per sub-directory is
    returned, ordered by directory name. With pipeline parallelism, the
    pipeline stages of each tensor-parallel rank are folded into that rank's
    stage-0 state dict, so the result again has one entry per tensor-parallel
    rank.

    Args:
        input_dir: Directory whose sub-directories each contain a
            ``model_optim_rng.pt`` file.

    Returns:
        ``(state_dicts, args)`` where ``args`` are the training arguments
        read from the first shard.

    Raises:
        FileNotFoundError: if ``input_dir`` has no sub-directories.
    """
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only run this on checkpoints from a trusted source.

    # os.scandir yields entries in arbitrary order, but the shard order
    # decides how split weights are concatenated, so sort explicitly.
    with os.scandir(input_dir) as entries:
        shard_dirs = sorted(entry.path for entry in entries if entry.is_dir())
    if not shard_dirs:
        raise FileNotFoundError(
            f"No checkpoint shard directories found in {input_dir!r}"
        )

    # Load one shard first: its args tell us whether the rest are needed.
    first_shard = torch.load(
        os.path.join(shard_dirs[0], "model_optim_rng.pt"), map_location="cpu"
    )
    args = get_train_args(first_shard)
    if args.transformer_pipeline_model_parallel_size == 1:
        state_dicts = [first_shard]
        for shard_dir in shard_dirs[1:]:
            state_dicts.append(
                torch.load(
                    os.path.join(shard_dir, "model_optim_rng.pt"),
                    map_location="cpu",
                )
            )
        return state_dicts, args

    # With pipeline parallelism the shards are reloaded rank by rank below,
    # so release this one to keep peak memory down.
    del first_shard

    tp_size = args.tensor_model_parallel_size
    pp_size = args.transformer_pipeline_model_parallel_size
    # Assumes num_layers is evenly divisible by pp_size.
    num_layers_per_pile = args.num_layers // pp_size

    # Weight groups to copy per layer: the attention projections first (only
    # the fused form when every head has its own key/value), then the rest.
    qkv_groups = ("query", "key_value", "query_key_value")
    if args.num_attention_heads == args.num_query_groups:
        groups_to_copy = ["query_key_value"]
    else:
        groups_to_copy = list(qkv_groups)
    groups_to_copy += [
        group for group in transformer_layer_name_list if group not in qkv_groups
    ]

    state_dicts = []
    for tp_index in range(tp_size):
        model_file = f"{input_dir}/mp_rank_{tp_index:02d}_000/model_optim_rng.pt"
        print(f"loading {model_file}")
        state_dict = torch.load(model_file, map_location="cpu")
        lm = state_dict["model"]["language_model"]
        encoder = lm["encoder"]

        for pp_index in range(1, pp_size):
            model_file = f"{input_dir}/mp_rank_{tp_index:02d}_{pp_index:03d}/model_optim_rng.pt"
            print(f"loading {model_file}")
            this_state_dict = torch.load(model_file, map_location="cpu")
            this_lm = this_state_dict["model"]["language_model"]
            this_encoder = this_lm["encoder"]

            # The output layer and the final norm are taken from the last
            # pipeline stage.
            if pp_index == pp_size - 1:
                lm["output_layer"] = this_lm["output_layer"]
                encoder["final_norm.weight"] = this_encoder["final_norm.weight"]

            # Layers are numbered from 0 within each stage; renumber them to
            # their global position while copying into the stage-0 encoder.
            for layer_index in range(num_layers_per_pile):
                this_layer_index = layer_index + num_layers_per_pile * pp_index
                for group in groups_to_copy:
                    encoder = check_assign(
                        encoder,
                        this_layer_index,
                        this_encoder,
                        layer_index,
                        key_list=transformer_layer_name_list[group],
                    )
        state_dicts.append(state_dict)

    return state_dicts, args
340
+
341
+
342
def main():
    """Command-line entry point: convert a Megatron checkpoint to HF format.

    Reads the checkpoint from ``--input-dir``, builds a ``LlamaForCausalLM``
    from the stored training arguments, copies the weights into it and saves
    the result to ``--output-dir``.
    """
    parser = argparse.ArgumentParser()
    # Both directories are required so a missing option fails at parse time
    # instead of passing None to the loading / saving code.
    parser.add_argument(
        "--input-dir",
        type=str,
        required=True,
        help="Path to the megatron checkpoint dir",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        required=True,
        help="Path to the huggingface checkpoint dir",
    )
    parser.add_argument(
        "--vocab-size",
        type=int,
        default=64000,
        help="unpadded tokenizer vocab size",
    )
    args = parser.parse_args()

    print("Load megatron checkpoint")
    state_dicts, train_args = load_state_dicts(args.input_dir)

    model_config = get_model_config(train_args, args.vocab_size)
    print(f"Model config: {model_config}", flush=True)

    print("Create hf model", flush=True)
    # NOTE(review): the model is built with real weights; an
    # accelerate.init_empty_weights() context was tried here and left
    # disabled -- presumably because load_state_dict needs allocated
    # parameters to copy into. Confirm before re-enabling.
    hf_model = LlamaForCausalLM(model_config)
    hf_model = hf_model.to(torch.bfloat16)

    print("convert megatron to hf", flush=True)
    convert_megatron_checkpoint(hf_model, state_dicts, model_config)

    print("save hf model", flush=True)
    hf_model.save_pretrained(args.output_dir, safe_serialization=False)


if __name__ == "__main__":
    main()
m-a-p.png ADDED