# ruff: noqa: E402
"""
This module converts a transformers MistralForCausalLM to a brrr model

Command:
    torchrun --nproc_per_node=1 convert_trfrs_to_brrr.py \
        --model_name mistralai/Mistral-7B-v0.1 \
        --save_path ./pretrained/Mistral-7B-v0.1
"""
import argparse
import sys
from dataclasses import asdict
from pathlib import Path
from typing import Dict, List

import torch
from brrr.trainer import DistributedTrainer

sys.path.append(Path(__file__).parent.parent.as_posix())

import os

from nanotron.parallel.parameters import NanotronParameter, sanity_check
from nanotron.parallel.pipeline_parallel.engine import (
    AllForwardAllBackwardPipelineEngine,
)
from nanotron.parallel.tensor_parallel.nn import TensorParallelLinearMode
from transformers import MistralConfig as MistralConfig_trfs, MistralForCausalLM

import nanotron.distributed as dist
from nanotron.config import ParallelismArgs, RecomputeGranularity
from nanotron.parallel.context import ParallelContext
from nanotron.models import build_model
from nanotron.trainer import mark_tied_parameters
from nanotron.serialize import save_meta, save_weights, save

from modeling_mistral import MistralForTraining
from config_mistral_7b import PARALLELISM as PARALLELISM_BRRR, CONFIG as CONFIG_BRRR


def get_args():
    parser = argparse.ArgumentParser(description="Convert transformers weights to brrr weights")
    parser.add_argument("--model_name", type=str, default="mistralai/Mistral-7B-v0.1")
    parser.add_argument("--save_path", type=str, default="pretrained/Mistral-7B-v0.1")
    parser.add_argument("--dp", type=int, default=1)
    parser.add_argument("--pp", type=int, default=1)
    parser.add_argument("--tp", type=int, default=1)
    return parser.parse_args()


def permute_for_rotary(tensor, num_heads, per_head_hidden_size, hidden_size):
    # Re-order the rows of a q/k projection so the rotary-embedding convention of the
    # source checkpoint matches the target implementation.
    return (
        tensor.view(num_heads, 2, per_head_hidden_size // 2, hidden_size)
        .transpose(1, 2)
        .contiguous()
        .view(num_heads * per_head_hidden_size, hidden_size)
    )


def get_transformers_weight(
    name: str, ref_module_state_dict: Dict[str, torch.Tensor], ref_module: MistralForCausalLM, get_grad: bool = False
) -> torch.Tensor:
    """From our brrr implementation, get the equivalent tensor in the transformers implementation."""
    config = ref_module.config
    brrr_prefix = "model."
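    # Nanotron parameter names look like "model.decoder.<i>.pp_block.attn.qkv_proj.weight".
    # Strip the "model." prefix and the "pp_block" segment, then translate the remaining
    # path into the transformers layout ("model.layers.<i>.self_attn.q_proj.weight", ...).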
    assert name.startswith(brrr_prefix)
    name = name[len(brrr_prefix) :]

    path = name.split(".")
    path.remove("pp_block")
    name = ".".join(path)

    if get_grad is False:

        def get_tensor(path: str):
            return ref_module_state_dict[path]

        def get_tensors(path: List[str]):
            return [get_tensor(p) for p in path]

    else:

        def get_tensor(path: str):
            weight = ref_module.get_parameter(path)
            return weight.grad

        def get_tensors(path: List[str]):
            return [get_tensor(p) for p in path]

    if name == "token_position_embeddings.token_embedding.weight":
        return get_tensor("model.embed_tokens.weight")

    elif name == "lm_head.weight":
        # This is only used when weights are not shared
        return get_tensor("lm_head.weight")

    elif name == "final_layer_norm.weight":
        return get_tensor("model.norm.weight")

    if path[0] == "decoder":
        transformer_path = ["model"] + ["layers"] + [path[1]]

        if path[2] == "attn":
            path[2] = "self_attn"
        if path[2] == "ff":
            path[2] = "mlp"

        if path[3] == "qkv_proj":
            # brrr fuses q/k/v into a single projection: fetch the three transformers tensors and concatenate them
            proj_names = ["q_proj", "k_proj", "v_proj"]
            tensor_list = get_tensors(
                [".".join(transformer_path + path[2:3] + [proj_name] + path[4:]) for proj_name in proj_names]
            )
            # Permute q/k
            per_head_hidden_size = config.hidden_size // config.num_attention_heads
            # Permute q
            print(f"Permuting q {tensor_list[0].shape}")
            tensor_list[0] = permute_for_rotary(
                tensor=tensor_list[0],
                num_heads=config.num_attention_heads,
                per_head_hidden_size=per_head_hidden_size,
                hidden_size=config.hidden_size,
            )
            # Permute k
            print(f"Permuting k {tensor_list[1].shape}")
            tensor_list[1] = permute_for_rotary(
                tensor=tensor_list[1],
                num_heads=config.num_key_value_heads,
                per_head_hidden_size=per_head_hidden_size,
                hidden_size=config.hidden_size,
            )
            return torch.cat(tensor_list, dim=0)

        if path[3] == "gate_up_proj":
            tensor_list = get_tensors(
                [
                    ".".join(transformer_path + path[2:3] + [proj_name] + path[4:])
                    for proj_name in ["gate_proj", "up_proj"]
                ]
            )
            return torch.cat(tensor_list, dim=0)

        return get_tensor(".".join(transformer_path + path[2:]))

    else:
        raise ValueError(f"Couldn't find transformer equivalent of {name}")


def convert_trfrs_to_brrr(dp, pp, tp, model_name="mistralai/Mistral-7B-v0.1", save_path="pretrained/Mistral-7B-v0.1"):
    # check save_path doesn't exist or is empty
    save_path = Path(save_path)
    # assert not save_path.exists() or len(list(save_path.iterdir())) == 0, f"save_path {save_path} is not empty"

    parallel_config = PARALLELISM_BRRR
    parallel_config.dp = dp
    parallel_config.pp = pp
    parallel_config.tp = tp

    # Initialise all process groups
    parallel_context = ParallelContext(
        data_parallel_size=parallel_config.dp,
        pipeline_parallel_size=parallel_config.pp,
        tensor_parallel_size=parallel_config.tp,
    )

    # params
    dtype = torch.bfloat16  # Flash attention doesn't support fp32

    # Initialise brrr model
    model_config_brrr = CONFIG_BRRR.model.model_config
    model = build_model(
        model_builder=lambda: MistralForTraining(
            config=model_config_brrr,
            parallel_context=parallel_context,
            parallel_config=parallel_config,
            random_states=None,
        ),
        dtype=dtype,
        parallel_context=parallel_context,
        device=torch.device("cpu"),
    )

    # Initialise transformers model
    device_map = {}
    current_pp_rank = dist.get_rank(group=parallel_context.pp_pg)
    device_map["model.embed_tokens"] = (
        model.model.token_position_embeddings.rank
        if current_pp_rank == model.model.token_position_embeddings.rank
        else "meta"
    )
    for i in range(model_config_brrr.num_hidden_layers):
        device_map[f"model.layers.{i}"] = (
            model.model.decoder[i].rank if current_pp_rank == model.model.decoder[i].rank else "meta"
        )
    device_map["model.norm"] = (
        model.model.final_layer_norm.rank if current_pp_rank == model.model.final_layer_norm.rank else "meta"
    )
    device_map["lm_head"] = model.model.lm_head.rank if current_pp_rank == model.model.lm_head.rank else "meta"
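
    # Blocks owned by other pipeline-parallel ranks are mapped to the "meta" device, so
    # from_pretrained does not materialise their weights here; only the blocks this pp
    # rank owns are actually loaded.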
    model_ref = MistralForCausalLM.from_pretrained(model_name, torch_dtype=dtype, device_map=device_map)

    # Copy weights from trfrs to brrr
    ref_state_dict = model_ref.state_dict()
    for name, param in model.named_parameters():
        print(f"Syncing {name}")
        ref_param = get_transformers_weight(name=name, ref_module_state_dict=ref_state_dict, ref_module=model_ref)

        param_is_tp_sharded = (
            isinstance(param, NanotronParameter)
            and param.is_sharded
            and parallel_context.world_ranks_to_pg[param.get_sharded_info().global_ranks] == parallel_context.tp_pg
        )

        if param_is_tp_sharded:
            sharded_info = param.get_sharded_info()
            # copy param data (not just the reference)
            with torch.no_grad():
                for local_global_slices_pair in sharded_info.local_global_slices_pairs:
                    local_slices = local_global_slices_pair.local_slices
                    global_slices = local_global_slices_pair.global_slices
                    param[local_slices].copy_(ref_param[global_slices])
        else:
            assert (
                ref_param.shape == param.shape
            ), f"Parameter shapes don't match for {name}\n{ref_param.shape} != {param.shape}"
            # copy param data (not just the reference)
            with torch.no_grad():
                param.copy_(ref_param)

        ref_param = None
        # torch.cuda.empty_cache()

    # TODO @nouamanetazi: assert weights are the same

    # Mark tied parameters
    mark_tied_parameters(model=model, parallel_context=parallel_context, parallel_config=parallel_config)

    sanity_check(root_module=model)

    checkpoint_metadata = {
        "last_train_step": 0,
        "consumed_train_samples": 0,
    }
    save(
        config=CONFIG_BRRR,
        model=model,
        optimizer=None,
        lr_scheduler=None,
        parallel_context=parallel_context,
        root_folder=save_path,
        should_save_optimizer=False,
        should_save_lr_scheduler=False,
        checkpoint_metadata=checkpoint_metadata,
        sanity_checks=False,
    )
    # save_weights(model=model, parallel_context=parallel_context, root_folder=save_path)
    # save_meta(root_folder=save_path, parallel_context=parallel_context, checkpoint_metadata=checkpoint_metadata)

    if dist.get_rank(parallel_context.world_pg) == 0:
        print(save_path)
        import json

        with open(save_path / "model_config.json", mode="w") as fo:
            fo.write(json.dumps(asdict(CONFIG_BRRR.model.model_config), indent=4))


def main():
    args = get_args()
    convert_trfrs_to_brrr(**vars(args))


if __name__ == "__main__":
    main()