{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['Einstein-v6.1-Llama3-8B',\n", " 'L3-TheSpice-8b-v0.8.3',\n", " 'dolphin-2.9-llama3-8b',\n", " 'Configurable-Hermes-2-Pro-Llama-3-8B',\n", " 'MAmmoTH2-8B-Plus',\n", " 'badger',\n", " 'Pantheon-RP-1.0-8b-Llama-3',\n", " 'Tiamat-8b-1.2-Llama-3-DPO',\n", " 'Buzz-8b-Large-v0.5',\n", " 'Kei_Llama3_8B',\n", " 'Llama-3-Lumimaid-8B-v0.1',\n", " 'llama-3-cat-8b-instruct-pytorch',\n", " 'Llama-3SOME-8B-v1',\n", " 'Roleplay-Llama-3-8B',\n", " 'Llama-3-LewdPlay-8B-evo',\n", " 'meta-llama-3-8b-instruct-hf-ortho-baukit-5fail-3000total-bf16',\n", " 'Llama-3-8B-Instruct-Coder',\n", " 'opus-v1.2-llama-3-8b-instruct-run3.5-epoch2.5',\n", " 'Poppy_Porpoise-0.72-L3-8B',\n", " 'Llama-3-8B-Instruct-norefusal',\n", " 'Meta-Llama-3-8B-Instruct-DPO',\n", " 'meta-llama-3-8b-instruct-hf-ortho-baukit-34fail-3000total-bf16',\n", " 'llama-3-lumimaid-habib-v7',\n", " 'Llama-3-Refueled',\n", " 'Llama-3-8B-Instruct-DPO-v0.4',\n", " 'Llama-3-8B-Instruct-Gradient-1048k',\n", " 'Mahou-1.0-llama3-8B',\n", " 'Llama-3-SauerkrautLM-8b-Instruct',\n", " 'Llama-3-Soliloquy-8B-v2']" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from glob import glob\n", "\n", "import os\n", "\n", "[ os.path.basename(f) for f in glob('../../llama/raw/delta/**/*')]" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import sys\n", "sys.path.append(\"../mergekit\")" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "from ztrain.io import list_for_path\n", "\n", "models = [\n", " 'Einstein-v6.1-Llama3-8B',\n", " 'L3-TheSpice-8b-v0.8.3',\n", " 'dolphin-2.9-llama3-8b',\n", " 'Configurable-Hermes-2-Pro-Llama-3-8B',\n", " 'MAmmoTH2-8B-Plus',\n", " 'Pantheon-RP-1.0-8b-Llama-3',\n", " 'Tiamat-8b-1.2-Llama-3-DPO',\n", " 'Buzz-8b-Large-v0.5',\n", " 'Kei_Llama3_8B',\n", " 
'Llama-3-Lumimaid-8B-v0.1',\n", " 'llama-3-cat-8b-instruct-pytorch',\n", " 'Llama-3SOME-8B-v1',\n", " 'Roleplay-Llama-3-8B',\n", " 'Llama-3-LewdPlay-8B-evo',\n", " #'Llama-3-8B-Instruct-Coder',\n", " 'opus-v1.2-llama-3-8b-instruct-run3.5-epoch2.5',\n", " 'meta-llama-3-8b-instruct-hf-ortho-baukit-5fail-3000total-bf16',\n", " 'Poppy_Porpoise-0.72-L3-8B',\n", " 'Llama-3-8B-Instruct-norefusal',\n", " 'Meta-Llama-3-8B-Instruct-DPO',\n", " 'badger',\n", " #'llama-3-lumimaid-habib-v7',\n", " 'Llama-3-Refueled',\n", " 'Llama-3-8B-Instruct-DPO-v0.4',\n", " 'Llama-3-8B-Instruct-Gradient-1048k',\n", " 'Mahou-1.0-llama3-8B',\n", " 'Llama-3-SauerkrautLM-8b-Instruct',\n", " 'Llama-3-Soliloquy-8B-v2'\n", "]\n", "\n", "out_model = \"opus-v1.2-llama-3-8b-instruct-run3.5-epoch2.5\"\n", "in_model = \"Llama-3-8B-Instruct-Gradient-1048k\"\n", "model_names, subtypes, model_list, group_idx = list_for_path(path=\"../../llama/raw/delta/\", include_folders=models, search=\"/**/*\")" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "594cae0000f54f1eb06184da18f6a46f", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/4 [00:00 embedding_size:\n", " # we need to resize the embeddings\n", " print(f\"Resizing embeddings from {m1.shape[0]} to {embedding_size}. If this is in an input layer, it may be catastrophic\")\n", " m1 = m1[:embedding_size]\n", " if m1.shape[0] < embedding_size:\n", " print(f\"Padding embeddings from {m1.shape[0]} to {embedding_size}. 
If this is in an input layer, it may be catastrophic\")\n", " m1 = torch.cat([m1, torch.zeros(embedding_size - m1.shape[0], m1.shape[1], dtype=m1.dtype, device=m1.device)], dim=0)  # padding rows must share m1's dtype/device, otherwise torch.cat fails across devices or upcasts the delta\n", " elif m1.shape != m0.shape:\n", " print(f\"Mismatch {m} for {k} {m1.shape} {m0.shape}\")\n", " continue\n", " if subtype == 'base':\n", " m1 = m1 + mi\n", " \n", " m1 = m1 - m0\n", " m1_norm = torch.norm(m1.to(torch.float32)).item()\n", " mean_norms.append(m1_norm)\n", " layer_stack.append([f\"{m}-{i}\", m1.to(cold)])\n", " del m1\n", " m0 = m0.to(cold)\n", " del mi\n", "\n", " mean_norm = torch.tensor(mean_norms).median().item()\n", " del mean_norms\n", " \n", " correlation = correlate_pairs(torch.stack([s[1] for s in layer_stack], dim=0))\n", " while len(layer_stack) > 1:\n", " next_stack = []\n", " for x, y, corr in least_correlated_pairs(correlation):\n", " if y >= 0:\n", " a = layer_stack[x][1].to(primary)\n", " b = layer_stack[y][1].to(primary)\n", " a_key = layer_stack[x][0]\n", " b_key = layer_stack[y][0]\n", " if len(a_key) < len(b_key):\n", " a, b = b, a\n", " a_key, b_key = b_key, a_key\n", "\n", " merge, norm_a, norm_b = merge_tensors_fft2_autoscale(a, b, .5)\n", " merge = merge * mean_norm  # rescale by the median per-model delta norm; NOTE(review): assumes the merge helper returns a normalised tensor - confirm\n", " post_norm = torch.norm(merge.to(torch.float32))\n", " print(f\"Pair {mean_norm:.3f} {post_norm:.3f} {corr:.7f} [ratio: {norm_a:.3f} {norm_b:.3f}] {merge.sum():.3f} ] {a_key} :: {b_key}\")\n", " del a, b, post_norm\n", " next_stack.append([f\"{a_key}_{b_key}\", merge.to(cold).detach()])\n", " else:\n", " next_stack.append(layer_stack[x])\n", " print(f\"Skipping {layer_stack[x][0]}\")\n", "\n", " del layer_stack\n", " layer_stack = next_stack\n", " correlation = correlate_pairs(torch.stack([s[1] for s in layer_stack], dim=0))\n", "\n", " result = layer_stack[0][1]\n", " writer.save_tensor(k, (m0.to(work_device) + result.to(work_device)).to(torch.bfloat16), clone=True)\n", " print(f\"Saving {k} {result.shape} {result.dtype}\")\n", " del result, m0, mean_norm\n", " if j % 30 == 0:\n", " writer.flush_current_shard()\n", "\n", "writer.finalize()" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from transformers import AutoModelForCausalLM, AutoTokenizer\n", "\n", "out_folder = \"./maldv/badger-augment\"\n", "in_folder = \"./models/delta/base/badger\"\n", "tokenizer = AutoTokenizer.from_pretrained(in_folder)\n", "tokenizer.save_pretrained(out_folder)" ] } ], "metadata": { "kernelspec": { "display_name": "jupyterenv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 2 }