""" Usage: python convert_hf_to_easylm.py \ --checkpoint_dir /path/hf_format_dir/ \ --output_file /path/easylm_format.stream \ --model_size 7b \ --streaming """ import time from pathlib import Path import argparse import mlxu import torch import flax from EasyLM.checkpoint import StreamingCheckpointer LLAMA_STANDARD_CONFIGS = { '1b': { 'dim': 2048, 'intermediate_size': 5504, 'n_layers': 22, 'n_heads': 16, 'norm_eps': 1e-6, }, '3b': { 'dim': 3200, 'intermediate_size': 8640, 'n_layers': 26, 'n_heads': 32, 'norm_eps': 1e-6, }, "7b": { "dim": 4096, "intermediate_size": 11008, "n_layers": 32, "n_heads": 32, "norm_eps": 1e-6, }, "13b": { "dim": 5120, "intermediate_size": 13824, "n_layers": 40, "n_heads": 40, "norm_eps": 1e-6, }, "30b": { "dim": 6656, "intermediate_size": 17920, "n_layers": 60, "n_heads": 52, "norm_eps": 1e-6, }, "65b": { "dim": 8192, "intermediate_size": 22016, "n_layers": 80, "n_heads": 64, "norm_eps": 1e-5, }, } def inverse_permute(params, w): n_layers = params["n_layers"] n_heads = params["n_heads"] dim = params["dim"] reshaped_w = w.reshape(n_heads, 2, dim // n_heads // 2, dim) transposed_w = reshaped_w.transpose(0, 2, 1, 3) inverted_w = transposed_w.reshape(dim, dim) return inverted_w def main(args): start = time.time() params = LLAMA_STANDARD_CONFIGS[args.model_size] ckpt_paths = sorted(Path(args.checkpoint_dir).glob("*.bin")) ckpt = {} for i, ckpt_path in enumerate(ckpt_paths): checkpoint = torch.load(ckpt_path, map_location="cpu") for k, v in checkpoint.items(): if k.startswith("model."): k = k[6:] ckpt[k] = v print(f"Start convert weight to easylm format...") jax_weights = { "transformer": { "wte": {"embedding": ckpt["embed_tokens.weight"].numpy()}, "ln_f": {"kernel": ckpt["norm.weight"].numpy()}, "h": { "%d" % (layer): { "attention": { "wq": { "kernel": inverse_permute( params, ckpt[f"layers.{layer}.self_attn.q_proj.weight"].numpy(), ).transpose() }, "wk": { "kernel": inverse_permute( params, ckpt[f"layers.{layer}.self_attn.k_proj.weight"].numpy(), ).transpose() }, "wv": { "kernel": ckpt[f"layers.{layer}.self_attn.v_proj.weight"] .numpy() .transpose() }, "wo": { "kernel": ckpt[f"layers.{layer}.self_attn.o_proj.weight"] .numpy() .transpose() }, }, "feed_forward": { "w1": { "kernel": ckpt[f"layers.{layer}.mlp.gate_proj.weight"] .numpy() .transpose() }, "w2": { "kernel": ckpt[f"layers.{layer}.mlp.down_proj.weight"] .numpy() .transpose() }, "w3": { "kernel": ckpt[f"layers.{layer}.mlp.up_proj.weight"] .numpy() .transpose() }, }, "attention_norm": { "kernel": ckpt[f"layers.{layer}.input_layernorm.weight"].numpy() }, "ffn_norm": { "kernel": ckpt[ f"layers.{layer}.post_attention_layernorm.weight" ].numpy() }, } for layer in range(params["n_layers"]) }, }, "lm_head": {"kernel": ckpt["lm_head.weight"].numpy().transpose()}, } print(f"Convert weight to easylm format finished...") print(f"Start to save...") if args.streaming: StreamingCheckpointer.save_train_state_to_file(jax_weights, args.output_file) else: with mlxu.open_file(args.output_file, "wb") as fout: fout.write(flax.serialization.msgpack_serialize(jax_weights, in_place=True)) print( f"Save finished!!! take time: {time.time() - start} save path: {args.output_file}" ) if __name__ == "__main__": parser = argparse.ArgumentParser(description="hf to easylm format script") parser.add_argument( "--checkpoint_dir", type=str, help="Need to be converted model weight dir. it is a dir", ) parser.add_argument( "--output_file", type=str, help="Save model weight file path, it is a file." ) parser.add_argument( "--model_size", type=str, default="7b", choices=["7b", "13b", "30b", "65b"], help="model size", ) parser.add_argument( "--streaming", action="store_true", default=True, help="whether is model weight saved stream format", ) args = parser.parse_args() print(f"checkpoint_dir: {args.checkpoint_dir}") print(f"output_file: {args.output_file}") print(f"model_size: {args.model_size}") print(f"streaming: {args.streaming}") main(args)