ngxson commited on
Commit
0495e49
0 Parent(s):
.gitattributes ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.gguf filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ .ipynb_checkpoints
2
+ wandb
3
+
README.md ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ ---
4
+
5
+ # stories15M_MOE
6
+
7
+ This model is [ModelCloud/tinyllama-15M-stories](https://huggingface.co/ModelCloud/tinyllama-15M-stories) repeated 4 times to make 4 experts.
8
+
9
+ The model is used for testing, not intended to be used in production (unless your product is some kind of bedtime story teller)
10
+
11
+ Weight of router is initialized randomly
config.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "MixtralForCausalLM"
4
+ ],
5
+ "attention_dropout": 0.0,
6
+ "bos_token_id": 1,
7
+ "eos_token_id": 2,
8
+ "hidden_act": "silu",
9
+ "hidden_size": 288,
10
+ "initializer_range": 0.02,
11
+ "intermediate_size": 768,
12
+ "max_position_embeddings": 256,
13
+ "model_type": "mixtral",
14
+ "num_attention_heads": 6,
15
+ "num_experts_per_tok": 2,
16
+ "num_hidden_layers": 6,
17
+ "num_key_value_heads": 6,
18
+ "num_local_experts": 4,
19
+ "output_router_logits": false,
20
+ "rms_norm_eps": 1e-05,
21
+ "rope_theta": 1000000.0,
22
+ "router_aux_loss_coef": 0.02,
23
+ "sliding_window": null,
24
+ "tie_word_embeddings": false,
25
+ "torch_dtype": "float32",
26
+ "transformers_version": "4.36.0.dev0",
27
+ "use_cache": true,
28
+ "vocab_size": 32000
29
+ }
data.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ From fairest creatures we desire increase,
2
+ That thereby beauty's rose might never die,
3
+ But as the riper should by time decease,
4
+ His tender heir might bear his memory:
5
+ But thou contracted to thine own bright eyes,
6
+ Feed'st thy light's flame with self-substantial fuel,
7
+ Making a famine where abundance lies,
8
+ Thy self thy foe, to thy sweet self too cruel:
9
+ Thou that art now the world's fresh ornament,
10
+ And only herald to the gaudy spring,
11
+ Within thine own bud buriest thy content,
12
+ And tender churl mak'st waste in niggarding:
13
+ Pity the world, or else this glutton be,
14
+ To eat the world's due, by the grave and thee.
finetune.ipynb ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "a41f141c-b6a8-40d1-b72d-127d028c0592",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import os\n",
11
+ "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
12
+ "\n",
13
+ "model_path = os.getcwd()\n",
14
+ "print(model_path)\n",
15
+ "tokenizer = AutoTokenizer.from_pretrained(model_path, legacy=False)\n",
16
+ "model = AutoModelForCausalLM.from_pretrained(model_path, use_safetensors=True, local_files_only=True)"
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": null,
22
+ "id": "93e9ec6a-4a57-484f-a1a5-ecb6674e8f77",
23
+ "metadata": {},
24
+ "outputs": [],
25
+ "source": [
26
+ "#inputs = tokenizer('', return_tensors=\"pt\")\n",
27
+ "#outputs = model.generate(inputs['input_ids'], max_new_tokens=20, temperature=0)\n",
28
+ "#print(tokenizer.decode(outputs[0], skip_special_tokens=True))"
29
+ ]
30
+ },
31
+ {
32
+ "cell_type": "code",
33
+ "execution_count": null,
34
+ "id": "e570b6db-efa8-4c9f-ac71-573479b00711",
35
+ "metadata": {},
36
+ "outputs": [],
37
+ "source": [
38
+ "model.gradient_checkpointing_enable()"
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "code",
43
+ "execution_count": null,
44
+ "id": "9345e74b-5bef-4cc9-982e-342af69b290a",
45
+ "metadata": {},
46
+ "outputs": [],
47
+ "source": [
48
+ "from peft import LoraConfig\n",
49
+ "\n",
50
+ "config = LoraConfig(\n",
51
+ " r=32,\n",
52
+ " lora_alpha=64,\n",
53
+ " target_modules=[\n",
54
+ " \"q_proj\",\n",
55
+ " \"k_proj\",\n",
56
+ " \"v_proj\",\n",
57
+ " \"o_proj\",\n",
58
+ " \"w1\",\n",
59
+ " \"w2\",\n",
60
+ " \"w3\",\n",
61
+ " \"lm_head\",\n",
62
+ " ],\n",
63
+ " bias=\"none\",\n",
64
+ " lora_dropout=0.05, # Conventional\n",
65
+ " task_type=\"CAUSAL_LM\",\n",
66
+ ")\n",
67
+ "\n",
68
+ "#print(model)"
69
+ ]
70
+ },
71
+ {
72
+ "cell_type": "code",
73
+ "execution_count": null,
74
+ "id": "09dd4848-9c7a-4a3b-9887-59652c915cc3",
75
+ "metadata": {},
76
+ "outputs": [],
77
+ "source": [
78
+ "import transformers\n",
79
+ "from datetime import datetime\n",
80
+ "\n",
81
+ "project = \"moe_shakespeare15M\"\n",
82
+ "run_name = project\n",
83
+ "output_dir = \"./\" + run_name\n",
84
+ "\n",
85
+ "with open(\"data.txt\", \"r\") as f:\n",
86
+ " content = f.read()\n",
87
+ " tokenized_train_dataset = [\n",
88
+ " tokenizer(content)['input_ids']\n",
89
+ " ]\n",
90
+ "\n",
91
+ "trainer = transformers.Trainer(\n",
92
+ " model=model,\n",
93
+ " train_dataset=tokenized_train_dataset,\n",
94
+ " args=transformers.TrainingArguments(\n",
95
+ " output_dir=output_dir,\n",
96
+ " warmup_steps=10,\n",
97
+ " per_device_train_batch_size=2,\n",
98
+ " gradient_accumulation_steps=1,\n",
99
+ " gradient_checkpointing=True,\n",
100
+ " max_steps=300,\n",
101
+ " learning_rate=2.5e-5, # Want a small lr for finetuning\n",
102
+ " # fp16=True, \n",
103
+ " optim=\"paged_adamw_8bit\",\n",
104
+ " # logging_steps=25, # When to start reporting loss\n",
105
+ " # logging_dir=\"./logs\", # Directory for storing logs\n",
106
+ " save_strategy=\"steps\", # Save the model checkpoint every logging step\n",
107
+ " save_steps=50, # Save checkpoints every 50 steps\n",
108
+ " # evaluation_strategy=\"steps\", # Evaluate the model every logging step\n",
109
+ " # eval_steps=25, # Evaluate and save checkpoints every 50 steps\n",
110
+ " # do_eval=True, # Perform evaluation at the end of training\n",
111
+ " report_to=\"none\", # Comment this out if you don't want to use weights & baises\n",
112
+ " run_name=f\"{run_name}-{datetime.now().strftime('%Y-%m-%d-%H-%M')}\" # Name of the W&B run (optional)\n",
113
+ " ),\n",
114
+ " data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),\n",
115
+ ")\n",
116
+ "\n",
117
+ "model.config.use_cache = False # silence the warnings. Please re-enable for inference!\n",
118
+ "trainer.train()"
119
+ ]
120
+ },
121
+ {
122
+ "cell_type": "code",
123
+ "execution_count": null,
124
+ "id": "7f0ad783-3f3e-4812-bc4e-026f9aad1435",
125
+ "metadata": {},
126
+ "outputs": [],
127
+ "source": []
128
+ }
129
+ ],
130
+ "metadata": {
131
+ "kernelspec": {
132
+ "display_name": "Python 3 (ipykernel)",
133
+ "language": "python",
134
+ "name": "python3"
135
+ },
136
+ "language_info": {
137
+ "codemirror_mode": {
138
+ "name": "ipython",
139
+ "version": 3
140
+ },
141
+ "file_extension": ".py",
142
+ "mimetype": "text/x-python",
143
+ "name": "python",
144
+ "nbconvert_exporter": "python",
145
+ "pygments_lexer": "ipython3",
146
+ "version": "3.10.12"
147
+ }
148
+ },
149
+ "nbformat": 4,
150
+ "nbformat_minor": 5
151
+ }
generate_moe.ipynb ADDED
@@ -0,0 +1,328 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 4,
6
+ "id": "66851a9c-d852-4a25-8cc7-1b7c03d1b3c2",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "from safetensors.torch import load_file\n",
11
+ "import torch\n",
12
+ "\n",
13
+ "model = load_file(\"model_original.safetensors\", device=\"cpu\")"
14
+ ]
15
+ },
16
+ {
17
+ "cell_type": "code",
18
+ "execution_count": 6,
19
+ "id": "6775e2ae-a543-401d-9f81-c450f3eb5910",
20
+ "metadata": {},
21
+ "outputs": [
22
+ {
23
+ "name": "stdout",
24
+ "output_type": "stream",
25
+ "text": [
26
+ "model.embed_tokens.weight\n",
27
+ "model.layers.0.input_layernorm.weight\n",
28
+ "model.layers.0.mlp.down_proj.weight\n",
29
+ "model.layers.0.mlp.gate_proj.weight\n",
30
+ "model.layers.0.mlp.up_proj.weight\n",
31
+ "model.layers.0.post_attention_layernorm.weight\n",
32
+ "model.layers.0.self_attn.k_proj.weight\n",
33
+ "model.layers.0.self_attn.o_proj.weight\n",
34
+ "model.layers.0.self_attn.q_proj.weight\n",
35
+ "model.layers.0.self_attn.v_proj.weight\n",
36
+ "model.layers.1.input_layernorm.weight\n",
37
+ "model.layers.1.mlp.down_proj.weight\n",
38
+ "model.layers.1.mlp.gate_proj.weight\n",
39
+ "model.layers.1.mlp.up_proj.weight\n",
40
+ "model.layers.1.post_attention_layernorm.weight\n",
41
+ "model.layers.1.self_attn.k_proj.weight\n",
42
+ "model.layers.1.self_attn.o_proj.weight\n",
43
+ "model.layers.1.self_attn.q_proj.weight\n",
44
+ "model.layers.1.self_attn.v_proj.weight\n",
45
+ "model.layers.2.input_layernorm.weight\n",
46
+ "model.layers.2.mlp.down_proj.weight\n",
47
+ "model.layers.2.mlp.gate_proj.weight\n",
48
+ "model.layers.2.mlp.up_proj.weight\n",
49
+ "model.layers.2.post_attention_layernorm.weight\n",
50
+ "model.layers.2.self_attn.k_proj.weight\n",
51
+ "model.layers.2.self_attn.o_proj.weight\n",
52
+ "model.layers.2.self_attn.q_proj.weight\n",
53
+ "model.layers.2.self_attn.v_proj.weight\n",
54
+ "model.layers.3.input_layernorm.weight\n",
55
+ "model.layers.3.mlp.down_proj.weight\n",
56
+ "model.layers.3.mlp.gate_proj.weight\n",
57
+ "model.layers.3.mlp.up_proj.weight\n",
58
+ "model.layers.3.post_attention_layernorm.weight\n",
59
+ "model.layers.3.self_attn.k_proj.weight\n",
60
+ "model.layers.3.self_attn.o_proj.weight\n",
61
+ "model.layers.3.self_attn.q_proj.weight\n",
62
+ "model.layers.3.self_attn.v_proj.weight\n",
63
+ "model.layers.4.input_layernorm.weight\n",
64
+ "model.layers.4.mlp.down_proj.weight\n",
65
+ "model.layers.4.mlp.gate_proj.weight\n",
66
+ "model.layers.4.mlp.up_proj.weight\n",
67
+ "model.layers.4.post_attention_layernorm.weight\n",
68
+ "model.layers.4.self_attn.k_proj.weight\n",
69
+ "model.layers.4.self_attn.o_proj.weight\n",
70
+ "model.layers.4.self_attn.q_proj.weight\n",
71
+ "model.layers.4.self_attn.v_proj.weight\n",
72
+ "model.layers.5.input_layernorm.weight\n",
73
+ "model.layers.5.mlp.down_proj.weight\n",
74
+ "model.layers.5.mlp.gate_proj.weight\n",
75
+ "model.layers.5.mlp.up_proj.weight\n",
76
+ "model.layers.5.post_attention_layernorm.weight\n",
77
+ "model.layers.5.self_attn.k_proj.weight\n",
78
+ "model.layers.5.self_attn.o_proj.weight\n",
79
+ "model.layers.5.self_attn.q_proj.weight\n",
80
+ "model.layers.5.self_attn.v_proj.weight\n",
81
+ "model.norm.weight\n"
82
+ ]
83
+ }
84
+ ],
85
+ "source": [
86
+ "for name, tensor in model.items():\n",
87
+ " print(name)"
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "code",
92
+ "execution_count": 25,
93
+ "id": "8b06f3c7-927d-4148-950c-5e1c93a54b75",
94
+ "metadata": {},
95
+ "outputs": [
96
+ {
97
+ "name": "stdout",
98
+ "output_type": "stream",
99
+ "text": [
100
+ "model.embed_tokens.weight torch.Size([32000, 288])\n",
101
+ "model.norm.weight torch.Size([288])\n",
102
+ "lm_head.weight torch.Size([32000, 288])\n",
103
+ "model.layers.0.input_layernorm.weight torch.Size([288])\n",
104
+ "model.layers.0.post_attention_layernorm.weight torch.Size([288])\n",
105
+ "model.layers.0.self_attn.k_proj.weight torch.Size([288, 288])\n",
106
+ "model.layers.0.self_attn.o_proj.weight torch.Size([288, 288])\n",
107
+ "model.layers.0.self_attn.q_proj.weight torch.Size([288, 288])\n",
108
+ "model.layers.0.self_attn.v_proj.weight torch.Size([288, 288])\n",
109
+ "model.layers.0.block_sparse_moe.gate.weight torch.Size([4, 288])\n",
110
+ "model.layers.0.block_sparse_moe.experts.0.w1.weight torch.Size([768, 288])\n",
111
+ "model.layers.0.block_sparse_moe.experts.0.w2.weight torch.Size([288, 768])\n",
112
+ "model.layers.0.block_sparse_moe.experts.0.w3.weight torch.Size([768, 288])\n",
113
+ "model.layers.0.block_sparse_moe.experts.1.w1.weight torch.Size([768, 288])\n",
114
+ "model.layers.0.block_sparse_moe.experts.1.w2.weight torch.Size([288, 768])\n",
115
+ "model.layers.0.block_sparse_moe.experts.1.w3.weight torch.Size([768, 288])\n",
116
+ "model.layers.0.block_sparse_moe.experts.2.w1.weight torch.Size([768, 288])\n",
117
+ "model.layers.0.block_sparse_moe.experts.2.w2.weight torch.Size([288, 768])\n",
118
+ "model.layers.0.block_sparse_moe.experts.2.w3.weight torch.Size([768, 288])\n",
119
+ "model.layers.0.block_sparse_moe.experts.3.w1.weight torch.Size([768, 288])\n",
120
+ "model.layers.0.block_sparse_moe.experts.3.w2.weight torch.Size([288, 768])\n",
121
+ "model.layers.0.block_sparse_moe.experts.3.w3.weight torch.Size([768, 288])\n",
122
+ "model.layers.1.input_layernorm.weight torch.Size([288])\n",
123
+ "model.layers.1.post_attention_layernorm.weight torch.Size([288])\n",
124
+ "model.layers.1.self_attn.k_proj.weight torch.Size([288, 288])\n",
125
+ "model.layers.1.self_attn.o_proj.weight torch.Size([288, 288])\n",
126
+ "model.layers.1.self_attn.q_proj.weight torch.Size([288, 288])\n",
127
+ "model.layers.1.self_attn.v_proj.weight torch.Size([288, 288])\n",
128
+ "model.layers.1.block_sparse_moe.gate.weight torch.Size([4, 288])\n",
129
+ "model.layers.1.block_sparse_moe.experts.0.w1.weight torch.Size([768, 288])\n",
130
+ "model.layers.1.block_sparse_moe.experts.0.w2.weight torch.Size([288, 768])\n",
131
+ "model.layers.1.block_sparse_moe.experts.0.w3.weight torch.Size([768, 288])\n",
132
+ "model.layers.1.block_sparse_moe.experts.1.w1.weight torch.Size([768, 288])\n",
133
+ "model.layers.1.block_sparse_moe.experts.1.w2.weight torch.Size([288, 768])\n",
134
+ "model.layers.1.block_sparse_moe.experts.1.w3.weight torch.Size([768, 288])\n",
135
+ "model.layers.1.block_sparse_moe.experts.2.w1.weight torch.Size([768, 288])\n",
136
+ "model.layers.1.block_sparse_moe.experts.2.w2.weight torch.Size([288, 768])\n",
137
+ "model.layers.1.block_sparse_moe.experts.2.w3.weight torch.Size([768, 288])\n",
138
+ "model.layers.1.block_sparse_moe.experts.3.w1.weight torch.Size([768, 288])\n",
139
+ "model.layers.1.block_sparse_moe.experts.3.w2.weight torch.Size([288, 768])\n",
140
+ "model.layers.1.block_sparse_moe.experts.3.w3.weight torch.Size([768, 288])\n",
141
+ "model.layers.2.input_layernorm.weight torch.Size([288])\n",
142
+ "model.layers.2.post_attention_layernorm.weight torch.Size([288])\n",
143
+ "model.layers.2.self_attn.k_proj.weight torch.Size([288, 288])\n",
144
+ "model.layers.2.self_attn.o_proj.weight torch.Size([288, 288])\n",
145
+ "model.layers.2.self_attn.q_proj.weight torch.Size([288, 288])\n",
146
+ "model.layers.2.self_attn.v_proj.weight torch.Size([288, 288])\n",
147
+ "model.layers.2.block_sparse_moe.gate.weight torch.Size([4, 288])\n",
148
+ "model.layers.2.block_sparse_moe.experts.0.w1.weight torch.Size([768, 288])\n",
149
+ "model.layers.2.block_sparse_moe.experts.0.w2.weight torch.Size([288, 768])\n",
150
+ "model.layers.2.block_sparse_moe.experts.0.w3.weight torch.Size([768, 288])\n",
151
+ "model.layers.2.block_sparse_moe.experts.1.w1.weight torch.Size([768, 288])\n",
152
+ "model.layers.2.block_sparse_moe.experts.1.w2.weight torch.Size([288, 768])\n",
153
+ "model.layers.2.block_sparse_moe.experts.1.w3.weight torch.Size([768, 288])\n",
154
+ "model.layers.2.block_sparse_moe.experts.2.w1.weight torch.Size([768, 288])\n",
155
+ "model.layers.2.block_sparse_moe.experts.2.w2.weight torch.Size([288, 768])\n",
156
+ "model.layers.2.block_sparse_moe.experts.2.w3.weight torch.Size([768, 288])\n",
157
+ "model.layers.2.block_sparse_moe.experts.3.w1.weight torch.Size([768, 288])\n",
158
+ "model.layers.2.block_sparse_moe.experts.3.w2.weight torch.Size([288, 768])\n",
159
+ "model.layers.2.block_sparse_moe.experts.3.w3.weight torch.Size([768, 288])\n",
160
+ "model.layers.3.input_layernorm.weight torch.Size([288])\n",
161
+ "model.layers.3.post_attention_layernorm.weight torch.Size([288])\n",
162
+ "model.layers.3.self_attn.k_proj.weight torch.Size([288, 288])\n",
163
+ "model.layers.3.self_attn.o_proj.weight torch.Size([288, 288])\n",
164
+ "model.layers.3.self_attn.q_proj.weight torch.Size([288, 288])\n",
165
+ "model.layers.3.self_attn.v_proj.weight torch.Size([288, 288])\n",
166
+ "model.layers.3.block_sparse_moe.gate.weight torch.Size([4, 288])\n",
167
+ "model.layers.3.block_sparse_moe.experts.0.w1.weight torch.Size([768, 288])\n",
168
+ "model.layers.3.block_sparse_moe.experts.0.w2.weight torch.Size([288, 768])\n",
169
+ "model.layers.3.block_sparse_moe.experts.0.w3.weight torch.Size([768, 288])\n",
170
+ "model.layers.3.block_sparse_moe.experts.1.w1.weight torch.Size([768, 288])\n",
171
+ "model.layers.3.block_sparse_moe.experts.1.w2.weight torch.Size([288, 768])\n",
172
+ "model.layers.3.block_sparse_moe.experts.1.w3.weight torch.Size([768, 288])\n",
173
+ "model.layers.3.block_sparse_moe.experts.2.w1.weight torch.Size([768, 288])\n",
174
+ "model.layers.3.block_sparse_moe.experts.2.w2.weight torch.Size([288, 768])\n",
175
+ "model.layers.3.block_sparse_moe.experts.2.w3.weight torch.Size([768, 288])\n",
176
+ "model.layers.3.block_sparse_moe.experts.3.w1.weight torch.Size([768, 288])\n",
177
+ "model.layers.3.block_sparse_moe.experts.3.w2.weight torch.Size([288, 768])\n",
178
+ "model.layers.3.block_sparse_moe.experts.3.w3.weight torch.Size([768, 288])\n",
179
+ "model.layers.4.input_layernorm.weight torch.Size([288])\n",
180
+ "model.layers.4.post_attention_layernorm.weight torch.Size([288])\n",
181
+ "model.layers.4.self_attn.k_proj.weight torch.Size([288, 288])\n",
182
+ "model.layers.4.self_attn.o_proj.weight torch.Size([288, 288])\n",
183
+ "model.layers.4.self_attn.q_proj.weight torch.Size([288, 288])\n",
184
+ "model.layers.4.self_attn.v_proj.weight torch.Size([288, 288])\n",
185
+ "model.layers.4.block_sparse_moe.gate.weight torch.Size([4, 288])\n",
186
+ "model.layers.4.block_sparse_moe.experts.0.w1.weight torch.Size([768, 288])\n",
187
+ "model.layers.4.block_sparse_moe.experts.0.w2.weight torch.Size([288, 768])\n",
188
+ "model.layers.4.block_sparse_moe.experts.0.w3.weight torch.Size([768, 288])\n",
189
+ "model.layers.4.block_sparse_moe.experts.1.w1.weight torch.Size([768, 288])\n",
190
+ "model.layers.4.block_sparse_moe.experts.1.w2.weight torch.Size([288, 768])\n",
191
+ "model.layers.4.block_sparse_moe.experts.1.w3.weight torch.Size([768, 288])\n",
192
+ "model.layers.4.block_sparse_moe.experts.2.w1.weight torch.Size([768, 288])\n",
193
+ "model.layers.4.block_sparse_moe.experts.2.w2.weight torch.Size([288, 768])\n",
194
+ "model.layers.4.block_sparse_moe.experts.2.w3.weight torch.Size([768, 288])\n",
195
+ "model.layers.4.block_sparse_moe.experts.3.w1.weight torch.Size([768, 288])\n",
196
+ "model.layers.4.block_sparse_moe.experts.3.w2.weight torch.Size([288, 768])\n",
197
+ "model.layers.4.block_sparse_moe.experts.3.w3.weight torch.Size([768, 288])\n",
198
+ "model.layers.5.input_layernorm.weight torch.Size([288])\n",
199
+ "model.layers.5.post_attention_layernorm.weight torch.Size([288])\n",
200
+ "model.layers.5.self_attn.k_proj.weight torch.Size([288, 288])\n",
201
+ "model.layers.5.self_attn.o_proj.weight torch.Size([288, 288])\n",
202
+ "model.layers.5.self_attn.q_proj.weight torch.Size([288, 288])\n",
203
+ "model.layers.5.self_attn.v_proj.weight torch.Size([288, 288])\n",
204
+ "model.layers.5.block_sparse_moe.gate.weight torch.Size([4, 288])\n",
205
+ "model.layers.5.block_sparse_moe.experts.0.w1.weight torch.Size([768, 288])\n",
206
+ "model.layers.5.block_sparse_moe.experts.0.w2.weight torch.Size([288, 768])\n",
207
+ "model.layers.5.block_sparse_moe.experts.0.w3.weight torch.Size([768, 288])\n",
208
+ "model.layers.5.block_sparse_moe.experts.1.w1.weight torch.Size([768, 288])\n",
209
+ "model.layers.5.block_sparse_moe.experts.1.w2.weight torch.Size([288, 768])\n",
210
+ "model.layers.5.block_sparse_moe.experts.1.w3.weight torch.Size([768, 288])\n",
211
+ "model.layers.5.block_sparse_moe.experts.2.w1.weight torch.Size([768, 288])\n",
212
+ "model.layers.5.block_sparse_moe.experts.2.w2.weight torch.Size([288, 768])\n",
213
+ "model.layers.5.block_sparse_moe.experts.2.w3.weight torch.Size([768, 288])\n",
214
+ "model.layers.5.block_sparse_moe.experts.3.w1.weight torch.Size([768, 288])\n",
215
+ "model.layers.5.block_sparse_moe.experts.3.w2.weight torch.Size([288, 768])\n",
216
+ "model.layers.5.block_sparse_moe.experts.3.w3.weight torch.Size([768, 288])\n"
217
+ ]
218
+ }
219
+ ],
220
+ "source": [
221
+ "N_EXPERTS = 4\n",
222
+ "N_LAYERS = 6\n",
223
+ "N_FF = 768\n",
224
+ "N_EMBD = 288\n",
225
+ "\n",
226
+ "moe_model = dict()\n",
227
+ "def copy_tensor(name, new_name = None):\n",
228
+ " new_name = name if new_name is None else new_name\n",
229
+ " moe_model[new_name] = torch.clone(model[name])\n",
230
+ "\n",
231
+ "copy_tensor('model.embed_tokens.weight')\n",
232
+ "copy_tensor('model.norm.weight')\n",
233
+ "copy_tensor('model.embed_tokens.weight', 'lm_head.weight')\n",
234
+ "\n",
235
+ "torch.manual_seed(0)\n",
236
+ "for il in range(N_LAYERS):\n",
237
+ " copy_tensor(f'model.layers.{il}.input_layernorm.weight')\n",
238
+ " copy_tensor(f'model.layers.{il}.post_attention_layernorm.weight')\n",
239
+ " copy_tensor(f'model.layers.{il}.self_attn.k_proj.weight')\n",
240
+ " copy_tensor(f'model.layers.{il}.self_attn.o_proj.weight')\n",
241
+ " copy_tensor(f'model.layers.{il}.self_attn.q_proj.weight')\n",
242
+ " copy_tensor(f'model.layers.{il}.self_attn.v_proj.weight')\n",
243
+ " moe_model[f'model.layers.{il}.block_sparse_moe.gate.weight'] = torch.rand(N_EXPERTS, N_EMBD)\n",
244
+ " for ex in range(N_EXPERTS):\n",
245
+ " copy_tensor(f'model.layers.{il}.mlp.gate_proj.weight', f'model.layers.{il}.block_sparse_moe.experts.{ex}.w1.weight')\n",
246
+ " copy_tensor(f'model.layers.{il}.mlp.down_proj.weight', f'model.layers.{il}.block_sparse_moe.experts.{ex}.w2.weight')\n",
247
+ " copy_tensor(f'model.layers.{il}.mlp.up_proj.weight', f'model.layers.{il}.block_sparse_moe.experts.{ex}.w3.weight')\n",
248
+ "\n",
249
+ "for name, tensor in moe_model.items():\n",
250
+ " print(name, tensor.shape)"
251
+ ]
252
+ },
253
+ {
254
+ "cell_type": "code",
255
+ "execution_count": 26,
256
+ "id": "19817bec-448f-4619-8772-2b3c77f0a1c2",
257
+ "metadata": {},
258
+ "outputs": [],
259
+ "source": [
260
+ "from safetensors.torch import save_file\n",
261
+ "\n",
262
+ "save_file(moe_model, \"model.safetensors\", metadata={\"format\": \"pt\"})"
263
+ ]
264
+ },
265
+ {
266
+ "cell_type": "code",
267
+ "execution_count": null,
268
+ "id": "e5bfd2cb-f53b-4285-bf5d-52a6c23779e0",
269
+ "metadata": {},
270
+ "outputs": [],
271
+ "source": []
272
+ },
273
+ {
274
+ "cell_type": "code",
275
+ "execution_count": 15,
276
+ "id": "e29a4b7e-e390-4d69-857c-02fc6065e33d",
277
+ "metadata": {},
278
+ "outputs": [],
279
+ "source": [
280
+ "import os\n",
281
+ "import json\n",
282
+ "\n",
283
+ "index_json = {\n",
284
+ " \"metadata\": {\n",
285
+ " \"total_size\": os.path.getsize(\"model.safetensors\"),\n",
286
+ " \"format\": \"safetensors\"\n",
287
+ " },\n",
288
+ " \"weight_map\": {}\n",
289
+ "}\n",
290
+ "\n",
291
+ "for name, _ in moe_model.items():\n",
292
+ " index_json[\"weight_map\"][name] = \"model.safetensors\"\n",
293
+ "\n",
294
+ "#with open(\"model.safetensors.index.json\", 'w') as json_file:\n",
295
+ "# json.dump(index_json, json_file, indent=2)"
296
+ ]
297
+ },
298
+ {
299
+ "cell_type": "code",
300
+ "execution_count": null,
301
+ "id": "c7e0736c-0139-4808-8943-c9eba5dcfc76",
302
+ "metadata": {},
303
+ "outputs": [],
304
+ "source": []
305
+ }
306
+ ],
307
+ "metadata": {
308
+ "kernelspec": {
309
+ "display_name": "Python 3 (ipykernel)",
310
+ "language": "python",
311
+ "name": "python3"
312
+ },
313
+ "language_info": {
314
+ "codemirror_mode": {
315
+ "name": "ipython",
316
+ "version": 3
317
+ },
318
+ "file_extension": ".py",
319
+ "mimetype": "text/x-python",
320
+ "name": "python",
321
+ "nbconvert_exporter": "python",
322
+ "pygments_lexer": "ipython3",
323
+ "version": "3.10.12"
324
+ }
325
+ },
326
+ "nbformat": 4,
327
+ "nbformat_minor": 5
328
+ }
generation_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "transformers_version": "4.36.0.dev0"
6
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dbfa0289f68a8dd721d10eb12d8bd82e098455682027f6f9986ba548913f9082
3
+ size 72744704
model_original.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b9e8d4614e24c89e99502000294af0aab73e9266029357377578e0a504b7f8d9
3
+ size 30389560
special_tokens_map.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": true,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "</s>",
11
+ "lstrip": false,
12
+ "normalized": true,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "unk_token": {
17
+ "content": "<unk>",
18
+ "lstrip": false,
19
+ "normalized": true,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ }
23
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "__type": "AddedToken",
4
+ "content": "<s>",
5
+ "lstrip": false,
6
+ "normalized": true,
7
+ "rstrip": false,
8
+ "single_word": false
9
+ },
10
+ "clean_up_tokenization_spaces": false,
11
+ "eos_token": {
12
+ "__type": "AddedToken",
13
+ "content": "</s>",
14
+ "lstrip": false,
15
+ "normalized": true,
16
+ "rstrip": false,
17
+ "single_word": false
18
+ },
19
+ "model_max_length": 2048,
20
+ "pad_token": null,
21
+ "sp_model_kwargs": {},
22
+ "tokenizer_class": "LlamaTokenizer",
23
+ "unk_token": {
24
+ "__type": "AddedToken",
25
+ "content": "<unk>",
26
+ "lstrip": false,
27
+ "normalized": true,
28
+ "rstrip": false,
29
+ "single_word": false
30
+ },
31
+ "use_default_system_prompt": true
32
+ }