diff --git "a/llama-merge-delta.ipynb" "b/llama-merge-delta.ipynb" new file mode 100644--- /dev/null +++ "b/llama-merge-delta.ipynb" @@ -0,0 +1,9043 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['Einstein-v6.1-Llama3-8B',\n", + " 'L3-TheSpice-8b-v0.8.3',\n", + " 'dolphin-2.9-llama3-8b',\n", + " 'Configurable-Hermes-2-Pro-Llama-3-8B',\n", + " 'MAmmoTH2-8B-Plus',\n", + " 'badger',\n", + " 'Pantheon-RP-1.0-8b-Llama-3',\n", + " 'Tiamat-8b-1.2-Llama-3-DPO',\n", + " 'Buzz-8b-Large-v0.5',\n", + " 'Kei_Llama3_8B',\n", + " 'Llama-3-Lumimaid-8B-v0.1',\n", + " 'llama-3-cat-8b-instruct-pytorch',\n", + " 'Llama-3SOME-8B-v1',\n", + " 'Roleplay-Llama-3-8B',\n", + " 'Llama-3-LewdPlay-8B-evo',\n", + " 'meta-llama-3-8b-instruct-hf-ortho-baukit-5fail-3000total-bf16',\n", + " 'Llama-3-8B-Instruct-Coder',\n", + " 'opus-v1.2-llama-3-8b-instruct-run3.5-epoch2.5',\n", + " 'Poppy_Porpoise-0.72-L3-8B',\n", + " 'Llama-3-8B-Instruct-norefusal',\n", + " 'Meta-Llama-3-8B-Instruct-DPO',\n", + " 'meta-llama-3-8b-instruct-hf-ortho-baukit-34fail-3000total-bf16',\n", + " 'llama-3-lumimaid-habib-v7',\n", + " 'Llama-3-Refueled',\n", + " 'Llama-3-8B-Instruct-DPO-v0.4',\n", + " 'Llama-3-8B-Instruct-Gradient-1048k',\n", + " 'Mahou-1.0-llama3-8B',\n", + " 'Llama-3-SauerkrautLM-8b-Instruct',\n", + " 'Llama-3-Soliloquy-8B-v2']" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from glob import glob\n", + "\n", + "import os\n", + "\n", + "[ os.path.basename(f) for f in glob('../../llama/raw/delta/**/*')]" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "sys.path.append(\"../mergekit\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "from ztrain.io import list_for_path\n", + "\n", + "models = [\n", + " 'Einstein-v6.1-Llama3-8B',\n", + " 'L3-TheSpice-8b-v0.8.3',\n", + " 'dolphin-2.9-llama3-8b',\n", + " 'Configurable-Hermes-2-Pro-Llama-3-8B',\n", + " 'MAmmoTH2-8B-Plus',\n", + " 'Pantheon-RP-1.0-8b-Llama-3',\n", + " 'Tiamat-8b-1.2-Llama-3-DPO',\n", + " 'Buzz-8b-Large-v0.5',\n", + " 'Kei_Llama3_8B',\n", + " 'Llama-3-Lumimaid-8B-v0.1',\n", + " 'llama-3-cat-8b-instruct-pytorch',\n", + " 'Llama-3SOME-8B-v1',\n", + " 'Roleplay-Llama-3-8B',\n", + " 'Llama-3-LewdPlay-8B-evo',\n", + " #'Llama-3-8B-Instruct-Coder',\n", + " 'opus-v1.2-llama-3-8b-instruct-run3.5-epoch2.5',\n", + " 'meta-llama-3-8b-instruct-hf-ortho-baukit-5fail-3000total-bf16',\n", + " 'Poppy_Porpoise-0.72-L3-8B',\n", + " 'Llama-3-8B-Instruct-norefusal',\n", + " 'Meta-Llama-3-8B-Instruct-DPO',\n", + " 'badger',\n", + " #'llama-3-lumimaid-habib-v7',\n", + " 'Llama-3-Refueled',\n", + " 'Llama-3-8B-Instruct-DPO-v0.4',\n", + " 'Llama-3-8B-Instruct-Gradient-1048k',\n", + " 'Mahou-1.0-llama3-8B',\n", + " 'Llama-3-SauerkrautLM-8b-Instruct',\n", + " 'Llama-3-Soliloquy-8B-v2'\n", + "]\n", + "\n", + "out_model = \"opus-v1.2-llama-3-8b-instruct-run3.5-epoch2.5\"\n", + "in_model = \"Llama-3-8B-Instruct-Gradient-1048k\"\n", + "model_names, subtypes, model_list, group_idx = list_for_path(path=\"../../llama/raw/delta/\", include_folders=models, search=\"/**/*\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "594cae0000f54f1eb06184da18f6a46f", + "version_major": 2, + 
"version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/4 [00:00 embedding_size:\n", + " # we need to resize the embeddings\n", + " print(f\"Resizing embeddings from {m1.shape[0]} to {embedding_size}. If this is in an input layer, it may be catastrophic\")\n", + " m1 = m1[:embedding_size]\n", + " if m1.shape[0] < embedding_size:\n", + " print(f\"Padding embeddings from {m1.shape[0]} to {embedding_size}. If this is in an input layer, it may be catastrophic\")\n", + " m1 = torch.cat([m1, torch.zeros(embedding_size - m1.shape[0], m1.shape[1])], dim=0)\n", + " elif m1.shape != m0.shape:\n", + " print(f\"ismatch {m} for {k} {m1.shape} {m0.shape}\")\n", + " continue\n", + " if subtype == 'base':\n", + " m1 = m1 + mi\n", + " \n", + " m1 = m1 - m0\n", + " m1_norm = torch.norm(m1.to(torch.float32)).item()\n", + " mean_norms.append(m1_norm)\n", + " layer_stack.append([f\"{m}-{i}\", m1.to(cold)])\n", + " del m1\n", + " m0 = m0.to(cold)\n", + " del mi\n", + "\n", + " mean_norm = torch.tensor(mean_norms).median().item()\n", + " del mean_norms\n", + " \n", + " correlation = correlate_pairs(torch.stack([s[1] for s in layer_stack], dim=0))\n", + " while len(layer_stack) > 1:\n", + " next_stack = []\n", + " for x, y, corr in least_correlated_pairs(correlation):\n", + " if y >= 0:\n", + " a = layer_stack[x][1].to(primary)\n", + " b = layer_stack[y][1].to(primary)\n", + " a_key = layer_stack[x][0]\n", + " b_key = layer_stack[y][0]\n", + " if len(a_key) < len(b_key):\n", + " a, b = b, a\n", + " a_key, b_key = b_key, a_key\n", + " remaining_keys = [i[0] for i in layer_stack if i[0] not in [a_key, b_key]]\n", + "\n", + " merge, norm_a, norm_b = merge_tensors_fft2_autoscale(a, b, .5)\n", + " merge = merge * mean_norm\n", + " post_norm = torch.norm(merge.to(torch.float32))\n", + " print(f\"Pair {mean_norm:.3f} {post_norm:.3f} {corr:.7f} [ratio: {norm_a:.3f} {norm_b:.3f}] {merge.sum():.3f} ] {a_key} :: {b_key}\") # [left: {', '.join(remaining_keys)}]\")\n", + " del a, b, post_norm\n", + " next_stack.append([f\"{a_key}_{b_key}\", merge.to(cold).detach()])\n", + " else:\n", + " next_stack.append(layer_stack[x])\n", + " print(f\"Skipping {layer_stack[x][0]}\")\n", + "\n", + " del layer_stack\n", + " layer_stack = next_stack\n", + " correlation = correlate_pairs(torch.stack([s[1] for s in layer_stack], dim=0))\n", + "\n", + " result = layer_stack[0][1]\n", + " writer.save_tensor(k, (m0.to(work_device) + result.to(work_device)).to(torch.bfloat16), clone=True)\n", + " print(f\"Saving {k} {result.shape} {result.dtype}\")\n", + " del result, m0, mean_norm\n", + " if j % 30 == 0:\n", + " writer.flush_current_shard()\n", + "\n", + "writer.finalize()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import AutoModelForCausalLM, AutoTokenizer\n", + "\n", + "out_folder = \"./maldv/badger-augment\"\n", + "in_folder = \"./models/delta/base/badger\"\n", + "tokenizer = AutoTokenizer.from_pretrained(in_folder)\n", + "tokenizer.save_pretrained(out_folder)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "jupyterenv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}