# Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved.
# Copyright 2023 Xinyang Geng
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts a LLaMA model checkpoint trained with EasyLM to the
# HuggingFace transformers LLaMA PyTorch format, which can then be loaded
# by HuggingFace transformers.
import gc
import json
import math
import os
import shutil
import numpy as np
import mlxu
import jax
import jax.numpy as jnp
import flax
from flax.traverse_util import flatten_dict
import torch
from transformers import LlamaConfig, LlamaForCausalLM
from EasyLM.checkpoint import StreamingCheckpointer
from EasyLM.jax_utils import float_tensor_to_dtype
# Command-line flags for the conversion script, with their default values.
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
    # Path of the EasyLM train-state checkpoint to convert (required by main).
    load_checkpoint='',
    # Path of the tokenizer model file; only referenced by the tokenizer
    # export, which main does not currently run.
    tokenizer_path='',
    # Key into LLAMA_STANDARD_CONFIGS selecting the architecture.
    model_size='13b',
    # Directory that receives the converted HuggingFace model (required by main).
    output_dir='',
)
# Architecture hyper-parameters for every supported model size.
#
# Every size shares the same vocabulary size; the remaining fields are listed
# per size in the column order given by _LLAMA_CONFIG_FIELDS.
_LLAMA_VOCAB_SIZE = 64256
_LLAMA_CONFIG_FIELDS = ('dim', 'intermediate_size', 'n_layers', 'n_heads', 'norm_eps')
_LLAMA_CONFIG_TABLE = {
    #          dim   intermediate  layers  heads  norm_eps
    'small':  (768,  3072,         12,     12,    1e-6),
    'medium': (1024, 4096,         24,     16,    1e-6),
    'large':  (1536, 6144,         24,     16,    1e-6),
    'xlarge': (2048, 8192,         24,     32,    1e-6),
    '1b':     (2048, 5504,         22,     16,    1e-6),
    '3b':     (3200, 8640,         26,     32,    1e-6),
    '7b':     (4096, 11008,        32,     32,    1e-6),
    '13b':    (5120, 13824,        40,     40,    1e-6),
    '30b':    (6656, 17920,        60,     52,    1e-6),
    '65b':    (8192, 22016,        80,     64,    1e-5),
}

# Maps a model-size name to a dict with the keys 'vocab_size', 'dim',
# 'intermediate_size', 'n_layers', 'n_heads' and 'norm_eps' (in that order).
LLAMA_STANDARD_CONFIGS = {
    size_name: {
        'vocab_size': _LLAMA_VOCAB_SIZE,
        **dict(zip(_LLAMA_CONFIG_FIELDS, row)),
    }
    for size_name, row in _LLAMA_CONFIG_TABLE.items()
}
def match_keywords(string, positives, negatives):
    """Report whether ``string`` satisfies a keyword filter.

    Args:
        string: Text to inspect.
        positives: Substrings that must all occur in ``string``.
        negatives: Substrings of which none may occur in ``string``.

    Returns:
        True when every positive keyword and no negative keyword is a
        substring of ``string``; False otherwise. Empty keyword collections
        impose no constraint.
    """
    has_all_required = all(keyword in string for keyword in positives)
    has_forbidden = any(keyword in string for keyword in negatives)
    return has_all_required and not has_forbidden
def load_and_convert_checkpoint(path):
    """Load an EasyLM train-state checkpoint as a dict of fp16 torch tensors.

    Args:
        path: Checkpoint location handed to
            ``StreamingCheckpointer.load_trainstate_checkpoint``.

    Returns:
        A dict keyed by the dot-joined parameter path (the nested ``'params'``
        tree flattened with ``sep='.'``) whose values are ``torch.float16``
        tensors.
    """
    _, train_state_params = StreamingCheckpointer.load_trainstate_checkpoint(path)
    flat_params = flatten_dict(train_state_params['params'], sep='.')

    def to_torch(name, array):
        # Transpose every "kernel" entry except the norm / ln_f ones.
        # NOTE(review): presumably this converts Flax's (in, out) dense kernel
        # layout to torch's (out, in) Linear weight layout -- confirm.
        if match_keywords(name, ["kernel"], ["norm", 'ln_f']):
            array = array.T
        return torch.tensor(
            float_tensor_to_dtype(array, 'fp32'), dtype=torch.float16
        )

    return {name: to_torch(name, array) for name, array in flat_params.items()}
def read_json(path):
    """Parse the JSON document stored at ``path`` and return the result."""
    with open(path, "r") as json_file:
        parsed = json.load(json_file)
    return parsed
def write_json(text, path):
    """Serialize ``text`` as JSON into the file at ``path``.

    Args:
        text: Any JSON-serializable object (despite the name, not only strings).
        path: Destination file; created or truncated.
    """
    with open(path, "w") as json_file:
        json.dump(text, json_file)
def write_model(loaded, model_path, model_size):
    """Write converted EasyLM LLaMA weights to ``model_path`` in HF format.

    The weights are first saved under ``<model_path>/tmp`` as one ``.bin``
    shard per transformer layer plus a final shard holding the embedding,
    final norm and LM head, together with a weight index file and a
    ``LlamaConfig``. That temporary checkpoint is then reloaded with
    ``LlamaForCausalLM`` in fp16 and re-saved into ``model_path`` with
    ``safe_serialization=True``. The temporary directory is removed
    afterwards, whether or not the conversion succeeded.

    Args:
        loaded: Mapping from flattened EasyLM parameter names (for example
            ``transformer.h.0.attention.wq.kernel``) to torch tensors, such
            as the result of ``load_and_convert_checkpoint``.
        model_path: Output directory; created if it does not exist.
        model_size: Key into ``LLAMA_STANDARD_CONFIGS``.

    Raises:
        KeyError: If ``model_size`` is not a known configuration or an
            expected parameter name is missing from ``loaded``.
    """
    # Resolve the configuration before touching the file system so an unknown
    # size does not leave empty directories behind.
    params = LLAMA_STANDARD_CONFIGS[model_size]
    n_layers = params["n_layers"]
    n_heads = params["n_heads"]
    dim = params["dim"]
    dims_per_head = dim // n_heads
    base = 10000.0
    inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))

    # permute for sliced rotary
    def permute(w):
        # Within each head, reorder the rows from a (pair, 2) layout to a
        # (2, pair) layout; the result is again a (dim, dim) matrix.
        return w.view(n_heads, dims_per_head // 2, 2, dim).transpose(1, 2).reshape(dim, dim)

    # One shard per layer plus one for the non-layer weights.
    n_shards = n_layers + 1

    os.makedirs(model_path, exist_ok=True)
    tmp_model_path = os.path.join(model_path, "tmp")
    os.makedirs(tmp_model_path, exist_ok=True)
    try:
        param_count = 0
        index_dict = {"weight_map": {}}

        for layer_i in range(n_layers):
            filename = f"pytorch_model-{layer_i + 1}-of-{n_shards}.bin"
            src = f"transformer.h.{layer_i}"
            dst = f"model.layers.{layer_i}"
            state_dict = {
                f"{dst}.self_attn.q_proj.weight": permute(loaded[f"{src}.attention.wq.kernel"]),
                f"{dst}.self_attn.k_proj.weight": permute(loaded[f"{src}.attention.wk.kernel"]),
                f"{dst}.self_attn.v_proj.weight": loaded[f"{src}.attention.wv.kernel"],
                f"{dst}.self_attn.o_proj.weight": loaded[f"{src}.attention.wo.kernel"],
                f"{dst}.mlp.gate_proj.weight": loaded[f"{src}.feed_forward.w1.kernel"],
                f"{dst}.mlp.down_proj.weight": loaded[f"{src}.feed_forward.w2.kernel"],
                f"{dst}.mlp.up_proj.weight": loaded[f"{src}.feed_forward.w3.kernel"],
                f"{dst}.input_layernorm.weight": loaded[f"{src}.attention_norm.kernel"],
                f"{dst}.post_attention_layernorm.weight": loaded[f"{src}.ffn_norm.kernel"],
            }
            state_dict[f"{dst}.self_attn.rotary_emb.inv_freq"] = inv_freq
            for k, v in state_dict.items():
                index_dict["weight_map"][k] = filename
                param_count += v.numel()
            torch.save(state_dict, os.path.join(tmp_model_path, filename))

        # Final shard: weights that do not belong to a transformer layer.
        filename = f"pytorch_model-{n_shards}-of-{n_shards}.bin"
        state_dict = {
            "model.embed_tokens.weight": loaded["transformer.wte.embedding"],
            "model.norm.weight": loaded["transformer.ln_f.kernel"],
            "lm_head.weight": loaded["lm_head.kernel"],
        }
        for k, v in state_dict.items():
            index_dict["weight_map"][k] = filename
            param_count += v.numel()
        torch.save(state_dict, os.path.join(tmp_model_path, filename))

        # Write configs. The size estimate assumes 2 bytes per element; the
        # inv_freq buffers counted above are float32, so it is approximate.
        index_dict["metadata"] = {"total_size": param_count * 2}
        write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
        config = LlamaConfig(
            vocab_size=params["vocab_size"],
            hidden_size=dim,
            intermediate_size=params["intermediate_size"],
            num_attention_heads=n_heads,
            num_hidden_layers=n_layers,
            rms_norm_eps=params["norm_eps"],
        )
        config.save_pretrained(tmp_model_path)

        # Make space so we can load the model properly now.
        del state_dict
        del loaded
        gc.collect()

        print("Loading the checkpoint in a Llama model.")
        model = LlamaForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.float16)
        print("Model parameter count", model.num_parameters())
        # Avoid saving the temporary path as part of the config.
        del model.config._name_or_path

        print("Saving in the Transformers format.")
        model.save_pretrained(model_path, safe_serialization=True)
    finally:
        # Always drop the intermediate shards, even when conversion fails, so
        # a failed run does not leave a full copy of the weights on disk.
        shutil.rmtree(tmp_model_path)
def write_tokenizer(tokenizer_path, input_tokenizer_path):
    """Write the HuggingFace tokenizer files into ``tokenizer_path``.

    Creates ``special_tokens_map.json`` and ``tokenizer_config.json`` and
    copies the tokenizer model file to ``tokenizer.model``.

    Args:
        tokenizer_path: Output directory; created if it does not exist.
        input_tokenizer_path: Path of the tokenizer model file to copy.
    """
    print(f"Fetching the tokenizer from {input_tokenizer_path}.")
    os.makedirs(tokenizer_path, exist_ok=True)

    def token_entry():
        # NOTE(review): every special token is written with empty content.
        # That looks suspicious for a LlamaTokenizer (the usual values would
        # be "<s>", "</s>" and "<unk>") -- confirm before relying on the
        # generated files.
        return {
            "content": "",
            "lstrip": False,
            "normalized": True,
            "rstrip": False,
            "single_word": False,
        }

    token_names = ("bos_token", "eos_token", "unk_token")

    special_tokens_map = {name: token_entry() for name in token_names}
    write_json(
        special_tokens_map,
        os.path.join(tokenizer_path, "special_tokens_map.json")
    )

    tokenizer_config = {
        "add_bos_token": True,
        "add_eos_token": False,
        "model_max_length": 2048,
        "pad_token": None,
        "sp_model_kwargs": {},
        "tokenizer_class": "LlamaTokenizer",
        "clean_up_tokenization_spaces": False,
    }
    # The config repeats the special tokens, tagged as AddedToken entries.
    for name in token_names:
        tokenizer_config[name] = {"__type": "AddedToken", **token_entry()}
    write_json(
        tokenizer_config,
        os.path.join(tokenizer_path, "tokenizer_config.json"),
    )

    shutil.copyfile(input_tokenizer_path, os.path.join(tokenizer_path, "tokenizer.model"))
def main(argv):
    """Convert the checkpoint named by the command-line flags.

    Loads ``FLAGS.load_checkpoint``, converts it, and writes the HuggingFace
    model for ``FLAGS.model_size`` into ``FLAGS.output_dir``.

    Args:
        argv: Arguments passed in by ``mlxu.run``; not used.

    Raises:
        ValueError: If ``load_checkpoint`` or ``output_dir`` is empty, or if
            ``model_size`` is not a key of ``LLAMA_STANDARD_CONFIGS``.
    """
    # Validate with explicit exceptions rather than ``assert`` so the checks
    # still run under ``python -O`` and report what is wrong.
    if not FLAGS.load_checkpoint or not FLAGS.output_dir:
        raise ValueError(
            "Both --load_checkpoint and --output_dir must be provided."
        )
    if FLAGS.model_size not in LLAMA_STANDARD_CONFIGS:
        raise ValueError(
            f"Unknown --model_size {FLAGS.model_size!r}; "
            f"expected one of {list(LLAMA_STANDARD_CONFIGS)}."
        )

    # Tokenizer export is currently disabled, so --tokenizer_path is not
    # required. To re-enable it, call
    # write_tokenizer(tokenizer_path=FLAGS.output_dir,
    #                 input_tokenizer_path=FLAGS.tokenizer_path)
    # here and also require FLAGS.tokenizer_path above.
    write_model(
        load_and_convert_checkpoint(FLAGS.load_checkpoint),
        model_path=FLAGS.output_dir,
        model_size=FLAGS.model_size,
    )


if __name__ == "__main__":
    mlxu.run(main)