diff --git a/src/transformers/models/llama/convert_llama_weights_to_hf.py b/src/transformers/models/llama/convert_llama_weights_to_hf.py index a0fbe4680..50c7ed738 100644 --- a/src/transformers/models/llama/convert_llama_weights_to_hf.py +++ b/src/transformers/models/llama/convert_llama_weights_to_hf.py @@ -17,10 +17,10 @@ import json import os import shutil import warnings - +from typing import List import torch -from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer, PreTrainedTokenizerFast +from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer, PreTrainedTokenizerFast, GenerationConfig from transformers.convert_slow_tokenizer import TikTokenConverter @@ -85,8 +85,12 @@ NUM_SHARDS = { "65B": 8, "70B": 8, "70Bf": 8, + "405B": 8, + "405B-MP16": 16, } +CONTEXT_LENGTH_FOR_VERSION = {"3.1": 131072, "3": 8192, "2": 4096, "1": 2048} + def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256): return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of) @@ -107,9 +111,10 @@ def write_model( input_base_path, model_size=None, safe_serialization=True, - llama_version=1, + llama_version="1", vocab_size=None, num_shards=None, + instruct=False, ): os.makedirs(model_path, exist_ok=True) tmp_model_path = os.path.join(model_path, "tmp") @@ -125,18 +130,11 @@ def write_model( dims_per_head = dim // n_heads base = params.get("rope_theta", 10000.0) inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)) - if base > 10000.0 and llama_version != 3: + if base > 10000.0 and float(llama_version) < 3: max_position_embeddings = 16384 else: - # Depending on the Llama version, the default max_position_embeddings has different values. - if llama_version == 1: - max_position_embeddings = 2048 - elif llama_version == 2: - max_position_embeddings = 4096 - elif llama_version == 3: - max_position_embeddings = 8192 - - vocab_size = vocab_size if vocab_size is not None else 32000 + max_position_embeddings = CONTEXT_LENGTH_FOR_VERSION[llama_version] + if params.get("n_kv_heads", None) is not None: num_key_value_heads = params["n_kv_heads"] # for GQA / MQA num_key_value_heads_per_shard = num_key_value_heads // num_shards @@ -144,8 +142,7 @@ def write_model( else: # compatibility with other checkpoints num_key_value_heads = n_heads num_key_value_heads_per_shard = n_heads_per_shard - key_value_dim = dims_per_head * num_key_value_heads - print(num_shards, num_key_value_heads, num_key_value_heads_per_shard, key_value_dim) + key_value_dim = dim # permute for sliced rotary def permute(w, n_heads, dim1=dim, dim2=dim): @@ -159,11 +156,9 @@ def write_model( loaded = torch.load(os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu") else: # Sharded - loaded = [ - torch.load(os.path.join(input_base_path, file), map_location="cpu") - for file in os.listdir(input_base_path) - if file.endswith(".pth") - ] + checkpoint_list = sorted([file for file in os.listdir(input_base_path) if file.endswith(".pth")]) + print("Loading in order:", checkpoint_list) + loaded = [torch.load(os.path.join(input_base_path, file), map_location="cpu") for file in checkpoint_list] param_count = 0 index_dict = {"weight_map": {}} for layer_i in range(n_layers): @@ -263,7 +258,7 @@ def write_model( "lm_head.weight": loaded["output.weight"], } else: - concat_dim = 0 if llama_version == 3 else 1 + concat_dim = 0 if llama_version in ['3', '3.1'] else 1 state_dict = { "model.norm.weight": loaded[0]["norm.weight"], "model.embed_tokens.weight": torch.cat( @@ -282,6 +277,18 @@ def write_model( write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json")) ffn_dim_multiplier = params["ffn_dim_multiplier"] if "ffn_dim_multiplier" in params else 1 multiple_of = params["multiple_of"] if "multiple_of" in params else 256 + + if llama_version in ['3', '3.1']: + bos_token_id = 128000 + + if instruct: + eos_token_id = [128001, 128009] + else: + eos_token_id = 128001 + else: + bos_token_id = 1 + eos_token_id = 2 + config = LlamaConfig( hidden_size=dim, intermediate_size=compute_intermediate_size(dim, ffn_dim_multiplier, multiple_of), @@ -292,11 +299,21 @@ def write_model( vocab_size=vocab_size, rope_theta=base, max_position_embeddings=max_position_embeddings, - bos_token_id=128000 if llama_version == 3 else 1, - eos_token_id=128001 if llama_version == 3 else 2, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, ) config.save_pretrained(tmp_model_path) + if instruct: + generation_config = GenerationConfig( + do_sample=True, + temperature=0.6, + top_p=0.9, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + ) + generation_config.save_pretrained(tmp_model_path) + # Make space so we can load the model properly now. del state_dict del loaded @@ -313,7 +330,7 @@ def write_model( class Llama3Converter(TikTokenConverter): - def __init__(self, vocab_file, num_reserved_special_tokens=256, **kwargs): + def __init__(self, vocab_file, special_tokens=None, instruct=False, model_max_length=None, **kwargs): super().__init__(vocab_file, **kwargs) tokenizer = self.converted() chat_template = ( @@ -327,34 +344,29 @@ class Llama3Converter(TikTokenConverter): "{% endfor %}" "{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}" ) - num_reserved_special_tokens = 256 - special_tokens = [ - "<|begin_of_text|>", - "<|end_of_text|>", - "<|reserved_special_token_0|>", - "<|reserved_special_token_1|>", - "<|reserved_special_token_2|>", - "<|reserved_special_token_3|>", - "<|start_header_id|>", - "<|end_header_id|>", - "<|reserved_special_token_4|>", - "<|eot_id|>", # end of turn - ] + [f"<|reserved_special_token_{i}|>" for i in range(5, num_reserved_special_tokens - 5)] tokenizer.add_special_tokens(special_tokens) + print("MODEL MAX LENGTH", model_max_length) + self.tokenizer = PreTrainedTokenizerFast( tokenizer_object=tokenizer, bos_token="<|begin_of_text|>", - eos_token="<|end_of_text|>", + eos_token="<|end_of_text|>" if not instruct else "<|eot_id|>", chat_template=chat_template, model_input_names=["input_ids", "attention_mask"], + model_max_length=model_max_length, ) -def write_tokenizer(tokenizer_path, input_tokenizer_path, llama_version=2): +def write_tokenizer(tokenizer_path, input_tokenizer_path, llama_version="2", special_tokens=None, instruct=False): tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast - if llama_version == 3: - tokenizer = Llama3Converter(input_tokenizer_path).tokenizer + if llama_version in ["3", "3.1"]: + tokenizer = Llama3Converter( + input_tokenizer_path, + special_tokens, + instruct, + model_max_length=CONTEXT_LENGTH_FOR_VERSION[llama_version] + ).tokenizer else: tokenizer = tokenizer_class(input_tokenizer_path) print(f"Saving a {tokenizer_class.__name__} to {tokenizer_path}.") @@ -362,6 +374,37 @@ def write_tokenizer(tokenizer_path, input_tokenizer_path, llama_version=2): return tokenizer +DEFAULT_LLAMA_SPECIAL_TOKENS = { + "3": [ + "<|begin_of_text|>", + "<|end_of_text|>", + "<|reserved_special_token_0|>", + "<|reserved_special_token_1|>", + "<|reserved_special_token_2|>", + "<|reserved_special_token_3|>", + "<|start_header_id|>", + "<|end_header_id|>", + "<|reserved_special_token_4|>", + "<|eot_id|>", # end of turn + ] + + [f"<|reserved_special_token_{i}|>" for i in range(5, 256 - 5)], + "3.1": [ + "<|begin_of_text|>", + "<|end_of_text|>", + "<|reserved_special_token_0|>", + "<|reserved_special_token_1|>", + "<|finetune_right_pad_id|>", + "<|reserved_special_token_2|>", + "<|start_header_id|>", + "<|end_header_id|>", + "<|eom_id|>", # end of message + "<|eot_id|>", # end of turn + "<|python_tag|>", + ] + + [f"<|reserved_special_token_{i}|>" for i in range(3, 256 - 8)], +} + + def main(): parser = argparse.ArgumentParser() parser.add_argument( @@ -383,9 +426,9 @@ def main(): # Different Llama versions used different default values for max_position_embeddings, hence the need to be able to specify which version is being used. parser.add_argument( "--llama_version", - choices=[1, 2, 3], - default=1, - type=int, + choices=["1", "2", "3", "3.1"], + default="1", + type=str, help="Version of the Llama model to convert. Currently supports Llama1 and Llama2. Controls the context size", ) parser.add_argument( @@ -394,11 +437,34 @@ def main(): type=int, help="The number of individual shards used for the model. Does not have to be the same as the number of consolidated_xx.pth", ) + parser.add_argument( + "--special_tokens", + default=None, + type=List[str], + help="The list of special tokens that should be added to the model.", + ) + parser.add_argument( + "--instruct", + default=False, + type=bool, + help="Whether the model is an instruct model or not. Will affect special tokens for llama 3.1.", + ) args = parser.parse_args() if args.model_size is None and args.num_shards is None: raise ValueError("You have to set at least `num_shards` if you are not giving the `model_size`") + if args.special_tokens is None: + args.special_tokens = DEFAULT_LLAMA_SPECIAL_TOKENS[str(args.llama_version)] + spm_path = os.path.join(args.input_dir, "tokenizer.model") - vocab_size = len(write_tokenizer(args.output_dir, spm_path, llama_version=args.llama_version)) + vocab_size = len( + write_tokenizer( + args.output_dir, + spm_path, + llama_version=args.llama_version, + special_tokens=args.special_tokens, + instruct=args.instruct + ) + ) if args.model_size != "tokenizer_only": write_model( model_path=args.output_dir, @@ -408,6 +474,7 @@ def main(): llama_version=args.llama_version, vocab_size=vocab_size, num_shards=args.num_shards, + instruct=args.instruct ) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 8cbe8fe35..65b4bb56b 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -90,6 +90,29 @@ class LlamaRMSNorm(nn.Module): ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm) +def apply_scaling(freqs: torch.Tensor): + # Values obtained from grid search + scale_factor = 8 + low_freq_factor = 1 + high_freq_factor = 4 + old_context_len = 8192 # original llama3 length + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + new_freqs = [] + for freq in freqs: + wavelen = 2 * math.pi / freq + if wavelen < high_freq_wavelen: + new_freqs.append(freq) + elif wavelen > low_freq_wavelen: + new_freqs.append(freq / scale_factor) + else: + assert low_freq_wavelen != high_freq_wavelen + smooth = (old_context_len / wavelen - low_freq_factor) / ( + high_freq_factor - low_freq_factor + ) + new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq) + return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) class LlamaRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): @@ -99,6 +122,7 @@ class LlamaRotaryEmbedding(nn.Module): self.max_position_embeddings = max_position_embeddings self.base = base inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + inv_freq = apply_scaling(inv_freq) self.register_buffer("inv_freq", inv_freq, persistent=False) # For BC we register cos and sin cached self.max_seq_len_cached = max_position_embeddings