"""Rewrite a sharded checkpoint in the current directory from MoE names to dense MLP names.

The script performs three steps, all relative to the working directory:

1. Every ``model-*-of-*.safetensors`` shard is re-saved as ``<shard>-x`` with
   the ``block_sparse_moe.experts.0.{w1,w2,w3}`` tensors renamed to
   ``mlp.{gate,down,up}_proj`` and the ``block_sparse_moe.gate.weight``
   tensors dropped.
2. ``config.json`` is backed up to ``config.json.old`` and rewritten without
   the MoE-specific settings, declaring a Mistral architecture instead.
3. ``model.safetensors.index.json`` is backed up to
   ``model.safetensors.index.json.old`` and rewritten with the same renames
   applied to its ``weight_map``.
"""
import glob
import json

import torch
from safetensors import safe_open
from safetensors.torch import save_file

# (MoE expert-0 tensor suffix, dense MLP tensor suffix), applied in this order.
_EXPERT_TO_DENSE = (
    ("block_sparse_moe.experts.0.w2.weight", "mlp.down_proj.weight"),
    ("block_sparse_moe.experts.0.w1.weight", "mlp.gate_proj.weight"),
    ("block_sparse_moe.experts.0.w3.weight", "mlp.up_proj.weight"),
)

# Tensors whose name contains this marker are dropped entirely.
_ROUTER_GATE = "block_sparse_moe.gate.weight"

# Config entries removed from config.json when present.
_MOE_CONFIG_KEYS = (
    "num_local_experts",
    "num_experts_per_tok",
    "output_router_logits",
    "router_aux_loss_coef",
    "router_jitter_noise",
)


def _dense_name(key):
    """Return *key* with each expert-0 projection name replaced by its dense MLP name."""
    for moe_suffix, dense_suffix in _EXPERT_TO_DENSE:
        key = key.replace(moe_suffix, dense_suffix)
    return key


def _densify_shards():
    """Re-save every shard as ``<shard>-x`` with renamed tensors and no router gates."""
    for mf in sorted(glob.glob("model-*-of-*.safetensors")):
        # The context manager releases the shard's file handle once it has
        # been processed; the original left every handle open.
        with safe_open(mf, framework="pt") as st:
            print(f"processing {mf}")
            densified = {}
            for key in st.keys():
                tensor = st.get_tensor(key)
                nk = _dense_name(key)
                if _ROUTER_GATE not in key:
                    densified[nk] = tensor
                    print(f"+ {key} -> {nk}")
                else:
                    print(f"- {key}")
            # Forward the input header metadata so the rewritten shard keeps
            # whatever the source shard declared (metadata() may be None,
            # which save_file accepts).
            save_file(densified, f"{mf}-x", metadata=st.metadata())


def _rewrite_config():
    """Back up ``config.json`` and rewrite it without the MoE-specific settings."""
    with open("config.json", "r", encoding="utf-8") as read_conf_file:
        conf = json.load(read_conf_file)

    # The backup is written before any key is modified.
    with open("config.json.old", "w", encoding="utf-8") as conf_file:
        json.dump(conf, conf_file)

    for moe_key in _MOE_CONFIG_KEYS:
        conf.pop(moe_key, None)

    conf['architectures'] = ['MistralForCausalLM']
    conf['model_type'] = 'mistral'
    conf['_name_or_path'] = 'lodrick-the-lafted/Densefin-291-Mistral-22B'

    with open("config.json", "w", encoding="utf-8") as write_conf_file:
        json.dump(conf, write_conf_file, indent=4)


def _rewrite_index():
    """Back up the shard index and rewrite its ``weight_map`` with dense names.

    NOTE(review): the map values (shard file names) are copied unchanged, while
    the rewritten shards are saved with a ``-x`` suffix; presumably the ``-x``
    files are meant to be renamed over the originals afterwards -- confirm.
    """
    with open("model.safetensors.index.json", "r", encoding="utf-8") as index_file:
        index = json.load(index_file)

    new_index = {
        'metadata': index['metadata'],
        'weight_map': {
            _dense_name(weight): shard
            for weight, shard in index['weight_map'].items()
            if _ROUTER_GATE not in weight
        },
    }

    with open("model.safetensors.index.json.old", "w", encoding="utf-8") as backup:
        json.dump(index, backup, indent=4)
    with open("model.safetensors.index.json", "w", encoding="utf-8") as new_st_index:
        json.dump(new_index, new_st_index, indent=4)


def main():
    """Densify the shards, then the config, then the shard index, in the working directory."""
    print("densifying...")
    _densify_shards()
    _rewrite_config()
    _rewrite_index()


if __name__ == "__main__":
    main()