{ "cells": [ { "cell_type": "code", "execution_count": 4, "id": "66851a9c-d852-4a25-8cc7-1b7c03d1b3c2", "metadata": {}, "outputs": [], "source": [ "from safetensors.torch import load_file\n", "import torch\n", "\n", "model = load_file(\"model_original.safetensors\", device=\"cpu\")" ] }, { "cell_type": "code", "execution_count": 6, "id": "6775e2ae-a543-401d-9f81-c450f3eb5910", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "model.embed_tokens.weight\n", "model.layers.0.input_layernorm.weight\n", "model.layers.0.mlp.down_proj.weight\n", "model.layers.0.mlp.gate_proj.weight\n", "model.layers.0.mlp.up_proj.weight\n", "model.layers.0.post_attention_layernorm.weight\n", "model.layers.0.self_attn.k_proj.weight\n", "model.layers.0.self_attn.o_proj.weight\n", "model.layers.0.self_attn.q_proj.weight\n", "model.layers.0.self_attn.v_proj.weight\n", "model.layers.1.input_layernorm.weight\n", "model.layers.1.mlp.down_proj.weight\n", "model.layers.1.mlp.gate_proj.weight\n", "model.layers.1.mlp.up_proj.weight\n", "model.layers.1.post_attention_layernorm.weight\n", "model.layers.1.self_attn.k_proj.weight\n", "model.layers.1.self_attn.o_proj.weight\n", "model.layers.1.self_attn.q_proj.weight\n", "model.layers.1.self_attn.v_proj.weight\n", "model.layers.2.input_layernorm.weight\n", "model.layers.2.mlp.down_proj.weight\n", "model.layers.2.mlp.gate_proj.weight\n", "model.layers.2.mlp.up_proj.weight\n", "model.layers.2.post_attention_layernorm.weight\n", "model.layers.2.self_attn.k_proj.weight\n", "model.layers.2.self_attn.o_proj.weight\n", "model.layers.2.self_attn.q_proj.weight\n", "model.layers.2.self_attn.v_proj.weight\n", "model.layers.3.input_layernorm.weight\n", "model.layers.3.mlp.down_proj.weight\n", "model.layers.3.mlp.gate_proj.weight\n", "model.layers.3.mlp.up_proj.weight\n", "model.layers.3.post_attention_layernorm.weight\n", "model.layers.3.self_attn.k_proj.weight\n", "model.layers.3.self_attn.o_proj.weight\n", "model.layers.3.self_attn.q_proj.weight\n", "model.layers.3.self_attn.v_proj.weight\n", "model.layers.4.input_layernorm.weight\n", "model.layers.4.mlp.down_proj.weight\n", "model.layers.4.mlp.gate_proj.weight\n", "model.layers.4.mlp.up_proj.weight\n", "model.layers.4.post_attention_layernorm.weight\n", "model.layers.4.self_attn.k_proj.weight\n", "model.layers.4.self_attn.o_proj.weight\n", "model.layers.4.self_attn.q_proj.weight\n", "model.layers.4.self_attn.v_proj.weight\n", "model.layers.5.input_layernorm.weight\n", "model.layers.5.mlp.down_proj.weight\n", "model.layers.5.mlp.gate_proj.weight\n", "model.layers.5.mlp.up_proj.weight\n", "model.layers.5.post_attention_layernorm.weight\n", "model.layers.5.self_attn.k_proj.weight\n", "model.layers.5.self_attn.o_proj.weight\n", "model.layers.5.self_attn.q_proj.weight\n", "model.layers.5.self_attn.v_proj.weight\n", "model.norm.weight\n" ] } ], "source": [ "for name, tensor in model.items():\n", " print(name)" ] }, { "cell_type": "code", "execution_count": 25, "id": "8b06f3c7-927d-4148-950c-5e1c93a54b75", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "model.embed_tokens.weight torch.Size([32000, 288])\n", "model.norm.weight torch.Size([288])\n", "lm_head.weight torch.Size([32000, 288])\n", "model.layers.0.input_layernorm.weight torch.Size([288])\n", "model.layers.0.post_attention_layernorm.weight torch.Size([288])\n", "model.layers.0.self_attn.k_proj.weight torch.Size([288, 288])\n", "model.layers.0.self_attn.o_proj.weight torch.Size([288, 288])\n", 
"model.layers.0.self_attn.q_proj.weight torch.Size([288, 288])\n", "model.layers.0.self_attn.v_proj.weight torch.Size([288, 288])\n", "model.layers.0.block_sparse_moe.gate.weight torch.Size([4, 288])\n", "model.layers.0.block_sparse_moe.experts.0.w1.weight torch.Size([768, 288])\n", "model.layers.0.block_sparse_moe.experts.0.w2.weight torch.Size([288, 768])\n", "model.layers.0.block_sparse_moe.experts.0.w3.weight torch.Size([768, 288])\n", "model.layers.0.block_sparse_moe.experts.1.w1.weight torch.Size([768, 288])\n", "model.layers.0.block_sparse_moe.experts.1.w2.weight torch.Size([288, 768])\n", "model.layers.0.block_sparse_moe.experts.1.w3.weight torch.Size([768, 288])\n", "model.layers.0.block_sparse_moe.experts.2.w1.weight torch.Size([768, 288])\n", "model.layers.0.block_sparse_moe.experts.2.w2.weight torch.Size([288, 768])\n", "model.layers.0.block_sparse_moe.experts.2.w3.weight torch.Size([768, 288])\n", "model.layers.0.block_sparse_moe.experts.3.w1.weight torch.Size([768, 288])\n", "model.layers.0.block_sparse_moe.experts.3.w2.weight torch.Size([288, 768])\n", "model.layers.0.block_sparse_moe.experts.3.w3.weight torch.Size([768, 288])\n", "model.layers.1.input_layernorm.weight torch.Size([288])\n", "model.layers.1.post_attention_layernorm.weight torch.Size([288])\n", "model.layers.1.self_attn.k_proj.weight torch.Size([288, 288])\n", "model.layers.1.self_attn.o_proj.weight torch.Size([288, 288])\n", "model.layers.1.self_attn.q_proj.weight torch.Size([288, 288])\n", "model.layers.1.self_attn.v_proj.weight torch.Size([288, 288])\n", "model.layers.1.block_sparse_moe.gate.weight torch.Size([4, 288])\n", "model.layers.1.block_sparse_moe.experts.0.w1.weight torch.Size([768, 288])\n", "model.layers.1.block_sparse_moe.experts.0.w2.weight torch.Size([288, 768])\n", "model.layers.1.block_sparse_moe.experts.0.w3.weight torch.Size([768, 288])\n", "model.layers.1.block_sparse_moe.experts.1.w1.weight torch.Size([768, 288])\n", "model.layers.1.block_sparse_moe.experts.1.w2.weight torch.Size([288, 768])\n", "model.layers.1.block_sparse_moe.experts.1.w3.weight torch.Size([768, 288])\n", "model.layers.1.block_sparse_moe.experts.2.w1.weight torch.Size([768, 288])\n", "model.layers.1.block_sparse_moe.experts.2.w2.weight torch.Size([288, 768])\n", "model.layers.1.block_sparse_moe.experts.2.w3.weight torch.Size([768, 288])\n", "model.layers.1.block_sparse_moe.experts.3.w1.weight torch.Size([768, 288])\n", "model.layers.1.block_sparse_moe.experts.3.w2.weight torch.Size([288, 768])\n", "model.layers.1.block_sparse_moe.experts.3.w3.weight torch.Size([768, 288])\n", "model.layers.2.input_layernorm.weight torch.Size([288])\n", "model.layers.2.post_attention_layernorm.weight torch.Size([288])\n", "model.layers.2.self_attn.k_proj.weight torch.Size([288, 288])\n", "model.layers.2.self_attn.o_proj.weight torch.Size([288, 288])\n", "model.layers.2.self_attn.q_proj.weight torch.Size([288, 288])\n", "model.layers.2.self_attn.v_proj.weight torch.Size([288, 288])\n", "model.layers.2.block_sparse_moe.gate.weight torch.Size([4, 288])\n", "model.layers.2.block_sparse_moe.experts.0.w1.weight torch.Size([768, 288])\n", "model.layers.2.block_sparse_moe.experts.0.w2.weight torch.Size([288, 768])\n", "model.layers.2.block_sparse_moe.experts.0.w3.weight torch.Size([768, 288])\n", "model.layers.2.block_sparse_moe.experts.1.w1.weight torch.Size([768, 288])\n", "model.layers.2.block_sparse_moe.experts.1.w2.weight torch.Size([288, 768])\n", "model.layers.2.block_sparse_moe.experts.1.w3.weight torch.Size([768, 288])\n", 
"model.layers.2.block_sparse_moe.experts.2.w1.weight torch.Size([768, 288])\n", "model.layers.2.block_sparse_moe.experts.2.w2.weight torch.Size([288, 768])\n", "model.layers.2.block_sparse_moe.experts.2.w3.weight torch.Size([768, 288])\n", "model.layers.2.block_sparse_moe.experts.3.w1.weight torch.Size([768, 288])\n", "model.layers.2.block_sparse_moe.experts.3.w2.weight torch.Size([288, 768])\n", "model.layers.2.block_sparse_moe.experts.3.w3.weight torch.Size([768, 288])\n", "model.layers.3.input_layernorm.weight torch.Size([288])\n", "model.layers.3.post_attention_layernorm.weight torch.Size([288])\n", "model.layers.3.self_attn.k_proj.weight torch.Size([288, 288])\n", "model.layers.3.self_attn.o_proj.weight torch.Size([288, 288])\n", "model.layers.3.self_attn.q_proj.weight torch.Size([288, 288])\n", "model.layers.3.self_attn.v_proj.weight torch.Size([288, 288])\n", "model.layers.3.block_sparse_moe.gate.weight torch.Size([4, 288])\n", "model.layers.3.block_sparse_moe.experts.0.w1.weight torch.Size([768, 288])\n", "model.layers.3.block_sparse_moe.experts.0.w2.weight torch.Size([288, 768])\n", "model.layers.3.block_sparse_moe.experts.0.w3.weight torch.Size([768, 288])\n", "model.layers.3.block_sparse_moe.experts.1.w1.weight torch.Size([768, 288])\n", "model.layers.3.block_sparse_moe.experts.1.w2.weight torch.Size([288, 768])\n", "model.layers.3.block_sparse_moe.experts.1.w3.weight torch.Size([768, 288])\n", "model.layers.3.block_sparse_moe.experts.2.w1.weight torch.Size([768, 288])\n", "model.layers.3.block_sparse_moe.experts.2.w2.weight torch.Size([288, 768])\n", "model.layers.3.block_sparse_moe.experts.2.w3.weight torch.Size([768, 288])\n", "model.layers.3.block_sparse_moe.experts.3.w1.weight torch.Size([768, 288])\n", "model.layers.3.block_sparse_moe.experts.3.w2.weight torch.Size([288, 768])\n", "model.layers.3.block_sparse_moe.experts.3.w3.weight torch.Size([768, 288])\n", "model.layers.4.input_layernorm.weight torch.Size([288])\n", "model.layers.4.post_attention_layernorm.weight torch.Size([288])\n", "model.layers.4.self_attn.k_proj.weight torch.Size([288, 288])\n", "model.layers.4.self_attn.o_proj.weight torch.Size([288, 288])\n", "model.layers.4.self_attn.q_proj.weight torch.Size([288, 288])\n", "model.layers.4.self_attn.v_proj.weight torch.Size([288, 288])\n", "model.layers.4.block_sparse_moe.gate.weight torch.Size([4, 288])\n", "model.layers.4.block_sparse_moe.experts.0.w1.weight torch.Size([768, 288])\n", "model.layers.4.block_sparse_moe.experts.0.w2.weight torch.Size([288, 768])\n", "model.layers.4.block_sparse_moe.experts.0.w3.weight torch.Size([768, 288])\n", "model.layers.4.block_sparse_moe.experts.1.w1.weight torch.Size([768, 288])\n", "model.layers.4.block_sparse_moe.experts.1.w2.weight torch.Size([288, 768])\n", "model.layers.4.block_sparse_moe.experts.1.w3.weight torch.Size([768, 288])\n", "model.layers.4.block_sparse_moe.experts.2.w1.weight torch.Size([768, 288])\n", "model.layers.4.block_sparse_moe.experts.2.w2.weight torch.Size([288, 768])\n", "model.layers.4.block_sparse_moe.experts.2.w3.weight torch.Size([768, 288])\n", "model.layers.4.block_sparse_moe.experts.3.w1.weight torch.Size([768, 288])\n", "model.layers.4.block_sparse_moe.experts.3.w2.weight torch.Size([288, 768])\n", "model.layers.4.block_sparse_moe.experts.3.w3.weight torch.Size([768, 288])\n", "model.layers.5.input_layernorm.weight torch.Size([288])\n", "model.layers.5.post_attention_layernorm.weight torch.Size([288])\n", "model.layers.5.self_attn.k_proj.weight torch.Size([288, 288])\n", 
"model.layers.5.self_attn.o_proj.weight torch.Size([288, 288])\n", "model.layers.5.self_attn.q_proj.weight torch.Size([288, 288])\n", "model.layers.5.self_attn.v_proj.weight torch.Size([288, 288])\n", "model.layers.5.block_sparse_moe.gate.weight torch.Size([4, 288])\n", "model.layers.5.block_sparse_moe.experts.0.w1.weight torch.Size([768, 288])\n", "model.layers.5.block_sparse_moe.experts.0.w2.weight torch.Size([288, 768])\n", "model.layers.5.block_sparse_moe.experts.0.w3.weight torch.Size([768, 288])\n", "model.layers.5.block_sparse_moe.experts.1.w1.weight torch.Size([768, 288])\n", "model.layers.5.block_sparse_moe.experts.1.w2.weight torch.Size([288, 768])\n", "model.layers.5.block_sparse_moe.experts.1.w3.weight torch.Size([768, 288])\n", "model.layers.5.block_sparse_moe.experts.2.w1.weight torch.Size([768, 288])\n", "model.layers.5.block_sparse_moe.experts.2.w2.weight torch.Size([288, 768])\n", "model.layers.5.block_sparse_moe.experts.2.w3.weight torch.Size([768, 288])\n", "model.layers.5.block_sparse_moe.experts.3.w1.weight torch.Size([768, 288])\n", "model.layers.5.block_sparse_moe.experts.3.w2.weight torch.Size([288, 768])\n", "model.layers.5.block_sparse_moe.experts.3.w3.weight torch.Size([768, 288])\n" ] } ], "source": [ "N_EXPERTS = 4\n", "N_LAYERS = 6\n", "N_FF = 768\n", "N_EMBD = 288\n", "\n", "moe_model = dict()\n", "def copy_tensor(name, new_name = None):\n", " new_name = name if new_name is None else new_name\n", " moe_model[new_name] = torch.clone(model[name])\n", "\n", "copy_tensor('model.embed_tokens.weight')\n", "copy_tensor('model.norm.weight')\n", "copy_tensor('model.embed_tokens.weight', 'lm_head.weight')\n", "\n", "torch.manual_seed(0)\n", "for il in range(N_LAYERS):\n", " copy_tensor(f'model.layers.{il}.input_layernorm.weight')\n", " copy_tensor(f'model.layers.{il}.post_attention_layernorm.weight')\n", " copy_tensor(f'model.layers.{il}.self_attn.k_proj.weight')\n", " copy_tensor(f'model.layers.{il}.self_attn.o_proj.weight')\n", " copy_tensor(f'model.layers.{il}.self_attn.q_proj.weight')\n", " copy_tensor(f'model.layers.{il}.self_attn.v_proj.weight')\n", " moe_model[f'model.layers.{il}.block_sparse_moe.gate.weight'] = torch.rand(N_EXPERTS, N_EMBD)\n", " for ex in range(N_EXPERTS):\n", " copy_tensor(f'model.layers.{il}.mlp.gate_proj.weight', f'model.layers.{il}.block_sparse_moe.experts.{ex}.w1.weight')\n", " copy_tensor(f'model.layers.{il}.mlp.down_proj.weight', f'model.layers.{il}.block_sparse_moe.experts.{ex}.w2.weight')\n", " copy_tensor(f'model.layers.{il}.mlp.up_proj.weight', f'model.layers.{il}.block_sparse_moe.experts.{ex}.w3.weight')\n", "\n", "for name, tensor in moe_model.items():\n", " print(name, tensor.shape)" ] }, { "cell_type": "code", "execution_count": 26, "id": "19817bec-448f-4619-8772-2b3c77f0a1c2", "metadata": {}, "outputs": [], "source": [ "from safetensors.torch import save_file\n", "\n", "save_file(moe_model, \"model.safetensors\", metadata={\"format\": \"pt\"})" ] }, { "cell_type": "code", "execution_count": null, "id": "e5bfd2cb-f53b-4285-bf5d-52a6c23779e0", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 15, "id": "e29a4b7e-e390-4d69-857c-02fc6065e33d", "metadata": {}, "outputs": [], "source": [ "import os\n", "import json\n", "\n", "index_json = {\n", " \"metadata\": {\n", " \"total_size\": os.path.getsize(\"model.safetensors\"),\n", " \"format\": \"safetensors\"\n", " },\n", " \"weight_map\": {}\n", "}\n", "\n", "for name, _ in moe_model.items():\n", " index_json[\"weight_map\"][name] = 
\"model.safetensors\"\n", "\n", "#with open(\"model.safetensors.index.json\", 'w') as json_file:\n", "# json.dump(index_json, json_file, indent=2)" ] }, { "cell_type": "code", "execution_count": null, "id": "c7e0736c-0139-4808-8943-c9eba5dcfc76", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 5 }