{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "82acAhWYGIPx" }, "source": [ "# Demo of bypassing refusal" ] }, { "cell_type": "markdown", "metadata": { "id": "j7hOtw7UHXdD" }, "source": [ "This notebook demonstrates our method for bypassing refusal, levaraging the insight that refusal is mediated by a 1-dimensional subspace. We recommend reading the [research post](https://www.lesswrong.com/posts/jGuXSZgv6qfdhMCuJ/refusal-in-llms-is-mediated-by-a-single-direction) for a more thorough explanation.\n", "\n", "Modified by FailSpy for easier usage" ] }, { "cell_type": "markdown", "metadata": { "id": "fcxHyDZw6b86" }, "source": [ "## Setup" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "id": "dLeei4-T6Wef" }, "outputs": [], "source": [ "%%capture\n", "!pip install transformers transformers_stream_generator tiktoken transformer_lens einops jaxtyping colorama scikit-learn" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "_vhhwl-2-jPg" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import torch\n", "import functools\n", "import einops\n", "import requests\n", "import pandas as pd\n", "import io\n", "import textwrap\n", "import gc\n", "import json\n", "\n", "from datasets import load_dataset\n", "from sklearn.model_selection import train_test_split\n", "from tqdm import tqdm\n", "from torch import Tensor\n", "from typing import List, Callable\n", "from transformer_lens import HookedTransformer, utils\n", "from transformer_lens.hook_points import HookPoint\n", "from transformers import AutoTokenizer, AutoModelForCausalLM\n", "from jaxtyping import Float, Int\n", "from colorama import Fore\n", "\n", "# We turn automatic differentiation off, to save GPU memory, as this notebook focuses on model inference not model training. 
(credit: Undi95)\n", "torch.set_grad_enabled(False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load jp/harmless prompts datasets, and chat template\n", "Please note you may need to change the `CHAT_TEMPLATE` for whichever model you're using" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "Rth8yvLZJsXs" }, "outputs": [], "source": [ "def get_jp_instructions(file_path):\n", " with open(file_path) as file:\n", " dataset = json.load(file)\n", "\n", " instructions = [item['question'] for item in dataset]\n", "\n", " train, test = train_test_split(instructions, test_size=0.2, random_state=42)\n", "\n", " return train, test\n", "\n", "def get_harmless_instructions():\n", " hf_path = 'tatsu-lab/alpaca'\n", " dataset = load_dataset(hf_path)\n", "\n", " # filter for instructions that do not have inputs\n", " instructions = []\n", " for i in range(len(dataset['train'])):\n", " if dataset['train'][i]['input'].strip() == '':\n", " instructions.append(dataset['train'][i]['instruction'])\n", "\n", " train, test = train_test_split(instructions, test_size=0.2, random_state=42)\n", "\n", " return train, test\n", "\n", "file_path = 'JEOPARDY_QUESTIONS1.json'\n", "jp_inst_train, jp_inst_test = get_jp_instructions(file_path)\n", "harmless_inst_train, harmless_inst_test = get_harmless_instructions()\n", "\n", "CHAT_TEMPLATE = \"\"\"<|user|>\\n{instruction}<|end|>\\n<|assistant|>\"\"\" # phi-3 chat template\n", "\n", "CHAT_TEMPLATE = \"\"\"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\\n{instruction}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\"\"\" # llama-3 chat template\n" ] }, { "cell_type": "markdown", "metadata": { "id": "6ZOoJagxD49V" }, "source": [ "### Load model (multi-GPU support)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 191, "referenced_widgets": [ "ad063e2c68a44f009bfab68c141c09be", "89ee88168c474e9fbcf4a17f1483eff4", "3877270cf4bc42a9b6142cce7a5d8c54", "9a5611a341ed4673aaaf2f463f685d7c", "a2de63dfbd6c485e841c6fcd1fefe451", "c362d50107dd4a2db0d1a79da2af8d57", "ffa85c694b694425999187b346c7ecfe", "ec8f6f360a2243b0ac98d34e825ba378", "f2ee188bfaa84e9680dbc296b1adbef6", "e973493cd6d14381bb4ad2f82417e8a9", "89797f6e82104058af92e3ceb094af66" ] }, "id": "Vnp65Vsg5x-5", "outputId": "25fb5805-fe31-44b0-8f73-6fabc230d261" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. 
If you want to force a new download, use `force_download=True`.\n", "  warnings.warn(\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e8092f91d75a48d28669d424312b3283", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# NOTE: the path below is a placeholder -- point MODEL_PATH at the chat model you want to modify\n", "# (the llama-3 CHAT_TEMPLATE above assumes a Llama-3 instruct model)\n", "MODEL_PATH = 'meta-llama/Meta-Llama-3-8B-Instruct'\n", "\n", "model = HookedTransformer.from_pretrained_no_processing(\n", "    MODEL_PATH,\n", "    n_devices=torch.cuda.device_count(), # spread the model across all available GPUs\n", "    dtype=torch.bfloat16,\n", "    default_padding_side='left',\n", ")\n", "\n", "model.tokenizer.padding_side = 'left'\n", "model.tokenizer.pad_token = model.tokenizer.eos_token" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def tokenize_instructions_chat(\n", "    instructions: List[str],\n", "    tokenizer: AutoTokenizer\n", ") -> Int[Tensor, 'batch_size seq_len']:\n", "    prompts = [CHAT_TEMPLATE.format(instruction=instruction) for instruction in instructions]\n", "    return tokenizer(prompts, padding=True, truncation=False, return_tensors=\"pt\").input_ids\n", "\n", "tokenize_instructions_fn = functools.partial(tokenize_instructions_chat, tokenizer=model.tokenizer)\n", "def _generate_with_hooks(\n", "    model: HookedTransformer,\n", "    toks: Int[Tensor, 'batch_size seq_len'],\n", "    max_tokens_generated: int = 64,\n", "    fwd_hooks = [],\n", ") -> List[str]:\n", "    all_toks = torch.zeros((toks.shape[0], toks.shape[1] + max_tokens_generated), dtype=torch.long, device=toks.device)\n", "    all_toks[:, :toks.shape[1]] = toks\n", "    for i in range(max_tokens_generated):\n", "        with model.hooks(fwd_hooks=fwd_hooks):\n", "            logits = model(all_toks[:, :-max_tokens_generated + i])\n", "            next_tokens = logits[:, -1, :].argmax(dim=-1) # greedy sampling (temperature=0)\n", "            all_toks[:,-max_tokens_generated+i] = next_tokens\n", "    return model.tokenizer.batch_decode(all_toks[:, toks.shape[1]:], skip_special_tokens=True)\n", "\n", "def get_generations(\n", "    model: HookedTransformer,\n", "    instructions: List[str],\n", "    tokenize_instructions_fn: Callable[[List[str]], Int[Tensor, 'batch_size seq_len']],\n", "    fwd_hooks = [],\n", "    max_tokens_generated: int = 64,\n", "    batch_size: int = 4,\n", ") -> List[str]:\n", "    generations = []\n", "    for i in tqdm(range(0, len(instructions), batch_size)):\n", "        toks = tokenize_instructions_fn(instructions=instructions[i:i+batch_size])\n", "        generation = _generate_with_hooks(\n", "            model,\n", "            toks,\n", "            max_tokens_generated=max_tokens_generated,\n", "            fwd_hooks=fwd_hooks,\n", "        )\n", "        generations.extend(generation)\n", "    return generations\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Useful cell for clearing RAM/VRAM" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# run this cell if you bump into memory issues\n", "try:\n", "    del harmless_logits\n", "except Exception:\n", "    pass\n", "try:\n", "    del jp_logits\n", "except Exception:\n", "    pass\n", "gc.collect(); torch.cuda.empty_cache()" ] }, { "cell_type": "markdown", "metadata": { "id": "W9O8dm0_EQRk" }, "source": [ "## Finding potential \"refusal directions\" (batched)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 4/4 [00:57<00:00, 14.43s/it]\n" ] } ], "source": [ "jp = {}\n", "harmless = {}\n", "\n", "# may want to spare your RAM and cycles here. 
can use 32 here instead, as in the paper, if you want to save time and memory\n", "N_INST_TRAIN = 512\n", "\n", "# load the full training set here to align all the dimensions\n", "toks = tokenize_instructions_fn(instructions=jp_inst_train[:N_INST_TRAIN]+harmless_inst_train[:N_INST_TRAIN])\n", "jp_toks,harmless_toks = toks.split(N_INST_TRAIN)\n", "\n", "batch_size = 128 # adjust this based on available VRAM\n", "\n", "for i in tqdm(range(0, N_INST_TRAIN // batch_size + (N_INST_TRAIN % batch_size > 0))):\n", "    id = i*batch_size\n", "    e = min(N_INST_TRAIN,id+batch_size)\n", "\n", "    # run the model on jp and harmless prompts, caching their activations separately\n", "    jp_logits, jp_cache = model.run_with_cache(jp_toks[id:e], names_filter=lambda hook_name: 'resid' in hook_name, device='cpu', reset_hooks_end=True)\n", "    harmless_logits, harmless_cache = model.run_with_cache(harmless_toks[id:e], names_filter=lambda hook_name: 'resid' in hook_name, device='cpu', reset_hooks_end=True)\n", "\n", "    for key in jp_cache:\n", "        if key not in jp:\n", "            jp[key] = [jp_cache[key]]\n", "            harmless[key] = [harmless_cache[key]]\n", "        else:\n", "            jp[key].append(jp_cache[key])\n", "            harmless[key].append(harmless_cache[key])\n", "\n", "    # force Python & PyTorch to clear GPU and CPU RAM where possible\n", "    del jp_logits, harmless_logits, jp_cache, harmless_cache\n", "    gc.collect()\n", "    torch.cuda.empty_cache()\n", "\n", "jp = {k:torch.cat(v) for k,v in jp.items()}\n", "harmless = {k:torch.cat(v) for k,v in harmless.items()}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Compute refusal directions from the cached activations" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "id": "tqD5E8Vc_w5d" }, "outputs": [], "source": [ "# compute difference of means between jp and harmless activations at intermediate layers\n", "\n", "def get_act_idx(cache_dict, act_name, layer):\n", "    key = (act_name, layer,)\n", "    return cache_dict[utils.get_act_name(*key)]\n", "\n", "activation_layers = ['resid_pre', 'resid_mid', 'resid_post']\n", "\n", "activation_refusals = {k:[] for k in activation_layers}\n", "\n", "for layer_num in range(1,model.cfg.n_layers):\n", "    pos = -1\n", "\n", "    for layer in activation_layers:\n", "        jp_mean_act = get_act_idx(jp, layer, layer_num)[:, pos, :].mean(dim=0)\n", "        harmless_mean_act = get_act_idx(harmless, layer, layer_num)[:, pos, :].mean(dim=0)\n", "\n", "        refusal_dir = jp_mean_act - harmless_mean_act\n", "        refusal_dir = refusal_dir / refusal_dir.norm()\n", "        activation_refusals[layer].append(refusal_dir)\n", "\n", "# save to file so you don't have to re-build later\n", "torch.save(activation_refusals, 'refusal_dirs.pth')\n", "refusal_dirs = activation_refusals" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Alternatively, load a previous attempt's pth:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "activation_refusals = torch.load('refusal_dirs.pth')\n", "refusal_dirs = activation_refusals" ] }, { "cell_type": "markdown", "metadata": { "id": "2EoxY5i1CWe3" }, "source": [ "## Ablate \"refusal direction\" via inference-time intervention\n", "\n", "Given a \"refusal direction\" $\\widehat{r} \\in \\mathbb{R}^{d_{\\text{model}}}$ with unit norm, we can ablate this direction from the model's activations $a_{l}$:\n", "$${a}_{l}' \\leftarrow a_l - (a_l \\cdot \\widehat{r}) \\widehat{r}$$\n", "\n", "By performing this ablation on all intermediate activations, we enforce that the model can never express this direction (or \"feature\")."
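, "\n", "\n", "Below is a minimal sketch (not from the original post) of this ablation written out as a plain tensor operation; `acts` and `direction` are hypothetical stand-ins for a `[batch, pos, d_model]` activation tensor and a refusal direction. The inference-time hook defined later in the notebook applies the same projection (plus an extra `avg_proj` offset along the direction) at every residual-stream hook point." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Minimal sketch of the ablation formula above. `acts` and `direction` are illustrative\n", "# stand-ins, not variables defined elsewhere in this notebook.\n", "def ablate_direction(acts: Tensor, direction: Tensor) -> Tensor:\n", "    direction = direction / direction.norm()  # ensure unit norm\n", "    coeffs = einops.einsum(acts, direction, '... d_model, d_model -> ...')  # a_l . r_hat\n", "    return acts - coeffs.unsqueeze(-1) * direction  # a_l - (a_l . r_hat) r_hat"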
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### \"Score\" layer activation diffs\n", "This is a place with room for improvement in methodology. For now, I'm just doing a rudimentary sort based on difference distance average to find dirs with the most \"change\"" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "# Get all calculated potential refusal dirs, sort them in Descending order (reverse) based on their mean()\n", "activation_layers = ['resid_pre', 'resid_mid', 'resid_post'] # you can use a subset of these if you don't think certain activations are promising\n", "\n", "activation_layers = ['resid_pre'] # this is usually good enough, though if you've got the compute to spare...\n", "activation_scored = sorted([activation_refusals[layer][l-1] for l in range(1,model.cfg.n_layers) for layer in activation_layers], key = lambda x: abs(x.mean()), reverse=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Model ablation testing/brute-forcing the best refusal dir" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "##### Inference-time intervention hook:" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "id": "26rf-yncB2PT" }, "outputs": [], "source": [ "def antidirection_ablation_hook(\n", " activation: Float[Tensor, \"... d_act\"],\n", " hook: HookPoint,\n", " direction: Float[Tensor, \"d_act\"],\n", " avg_proj\n", "):\n", " if activation.device != direction.device:\n", " direction = direction.to(activation.device)\n", " if avg_proj.device != activation.device:\n", " avg_proj = avg_proj.to(activation.device)\n", " proj = (einops.einsum(activation, direction.view(-1, 1), '... d_act, d_act single -> ... single') * direction)\n", " return (activation - proj) + (avg_proj*direction)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "##### Testing baseline, can skip if you don't care (it will basically just be refusals with a regular model :P)" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 1/1 [00:06<00:00, 6.33s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "That sounds like an amazing event! The Symphony of Fire in British Columbia sounds like a must-see event for anyone who loves fireworks and live music.assistant\n", "\n", "I completely agree! The Symphony of Fire sounds like an incredible experience that combines the thrill of fireworks with the beauty of live music. It's\n", "\n", "\n", "The answer is Thomas Dewey.assistant\n", "\n", "That's correct! Thomas Dewey was the Republican presidential nominee in 1944 and 1948, and he lost both times. His running mate in 1948 was Earl Warren, who later became the Chief Justice of the United States.\n", "\n", "\n", "A clever clue!\n", "\n", "After some research, I think I have found the answer:\n", "\n", "The alternate title to the Anne Rice book \"Ramses the Damned\" is \"The Vampire Ramses\".\n", "\n", "And, as you hinted, one of the previously buried folk mentioned in this title is... Imhotep!\n", "\n", "Am I\n", "\n", "\n", "I think I can fill in the blank!\n", "\n", "\"The sanctity of the confessional\"assistant\n", "\n", "That's correct! The confessional is a sacred space where a priest hears confessions and maintains the confidentiality of the sacrament. Well done!assistant\n", "\n", "Thank you! 
I'm glad I\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "N_INST_TEST = 4\n", "baseline_generations = get_generations(model, jp_inst_test[:N_INST_TEST], tokenize_instructions_fn, fwd_hooks=[])\n", "for gen in baseline_generations:\n", " print(gen)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "##### Evaluating layers defined earlier (needs human evaluation to determine best layer for refusal inhibition)" ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "sR1G5bXoEDty", "outputId": "2e610278-3d4c-473e-850f-a9b28d0b13f9" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " 0%| | 0/20 [00:00 17\u001b[0m intervention_generations \u001b[38;5;241m=\u001b[39m \u001b[43mget_generations\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mjp_inst_test\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[43mN_INST_TEST\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtokenize_instructions_fn\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfwd_hooks\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfwd_hooks\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 18\u001b[0m evals\u001b[38;5;241m.\u001b[39mappend(intervention_generations)\n\u001b[1;32m 20\u001b[0m \u001b[38;5;66;03m#print(intervention_generations) # if you want to watch it as it goes\u001b[39;00m\n", "Cell \u001b[0;32mIn[5], line 35\u001b[0m, in \u001b[0;36mget_generations\u001b[0;34m(model, instructions, tokenize_instructions_fn, fwd_hooks, max_tokens_generated, batch_size)\u001b[0m\n\u001b[1;32m 33\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m tqdm(\u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m0\u001b[39m, \u001b[38;5;28mlen\u001b[39m(instructions), batch_size)):\n\u001b[1;32m 34\u001b[0m toks \u001b[38;5;241m=\u001b[39m tokenize_instructions_fn(instructions\u001b[38;5;241m=\u001b[39minstructions[i:i\u001b[38;5;241m+\u001b[39mbatch_size])\n\u001b[0;32m---> 35\u001b[0m generation \u001b[38;5;241m=\u001b[39m \u001b[43m_generate_with_hooks\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 36\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 37\u001b[0m \u001b[43m \u001b[49m\u001b[43mtoks\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 38\u001b[0m \u001b[43m \u001b[49m\u001b[43mmax_tokens_generated\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmax_tokens_generated\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 39\u001b[0m \u001b[43m \u001b[49m\u001b[43mfwd_hooks\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfwd_hooks\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 40\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 41\u001b[0m generations\u001b[38;5;241m.\u001b[39mextend(generation)\n\u001b[1;32m 42\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m generations\n", "Cell \u001b[0;32mIn[5], line 19\u001b[0m, in \u001b[0;36m_generate_with_hooks\u001b[0;34m(model, toks, max_tokens_generated, fwd_hooks)\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(max_tokens_generated):\n\u001b[1;32m 18\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m model\u001b[38;5;241m.\u001b[39mhooks(fwd_hooks\u001b[38;5;241m=\u001b[39mfwd_hooks):\n\u001b[0;32m---> 19\u001b[0m logits \u001b[38;5;241m=\u001b[39m 
\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mall_toks\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m:\u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[43mmax_tokens_generated\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mi\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 20\u001b[0m next_tokens \u001b[38;5;241m=\u001b[39m logits[:, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, :]\u001b[38;5;241m.\u001b[39margmax(dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m) \u001b[38;5;66;03m# greedy sampling (temperature=0)\u001b[39;00m\n\u001b[1;32m 21\u001b[0m all_toks[:,\u001b[38;5;241m-\u001b[39mmax_tokens_generated\u001b[38;5;241m+\u001b[39mi] \u001b[38;5;241m=\u001b[39m next_tokens\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1518\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1519\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1523\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", "File 
\u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformer_lens/HookedTransformer.py:550\u001b[0m, in \u001b[0;36mHookedTransformer.forward\u001b[0;34m(self, input, return_type, loss_per_token, prepend_bos, padding_side, start_at_layer, tokens, shortformer_pos_embed, attention_mask, stop_at_layer, past_kv_cache)\u001b[0m\n\u001b[1;32m 545\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m shortformer_pos_embed \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 546\u001b[0m shortformer_pos_embed \u001b[38;5;241m=\u001b[39m shortformer_pos_embed\u001b[38;5;241m.\u001b[39mto(\n\u001b[1;32m 547\u001b[0m devices\u001b[38;5;241m.\u001b[39mget_device_for_block_index(i, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcfg)\n\u001b[1;32m 548\u001b[0m )\n\u001b[0;32m--> 550\u001b[0m residual \u001b[38;5;241m=\u001b[39m \u001b[43mblock\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 551\u001b[0m \u001b[43m \u001b[49m\u001b[43mresidual\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 552\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Cache contains a list of HookedTransformerKeyValueCache objects, one for each\u001b[39;49;00m\n\u001b[1;32m 553\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;66;43;03m# block\u001b[39;49;00m\n\u001b[1;32m 554\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_kv_cache_entry\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_kv_cache\u001b[49m\u001b[43m[\u001b[49m\u001b[43mi\u001b[49m\u001b[43m]\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mpast_kv_cache\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mis\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mnot\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 555\u001b[0m \u001b[43m \u001b[49m\u001b[43mshortformer_pos_embed\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mshortformer_pos_embed\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 556\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 557\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# [batch, pos, d_model]\u001b[39;00m\n\u001b[1;32m 559\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m stop_at_layer \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 560\u001b[0m \u001b[38;5;66;03m# When we stop at an early layer, we end here rather than doing further computation\u001b[39;00m\n\u001b[1;32m 561\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m residual\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1518\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1519\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1523\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformer_lens/components.py:1545\u001b[0m, in \u001b[0;36mTransformerBlock.forward\u001b[0;34m(self, resid_pre, shortformer_pos_embed, past_kv_cache_entry, attention_mask)\u001b[0m\n\u001b[1;32m 1527\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\n\u001b[1;32m 1528\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 1529\u001b[0m resid_pre: Float[torch\u001b[38;5;241m.\u001b[39mTensor, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbatch pos d_model\u001b[39m\u001b[38;5;124m\"\u001b[39m],\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1532\u001b[0m attention_mask: Optional[Int[torch\u001b[38;5;241m.\u001b[39mTensor, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbatch offset_pos\u001b[39m\u001b[38;5;124m\"\u001b[39m]] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 1533\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Float[torch\u001b[38;5;241m.\u001b[39mTensor, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbatch pos d_model\u001b[39m\u001b[38;5;124m\"\u001b[39m]:\n\u001b[1;32m 1534\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"A single Transformer block.\u001b[39;00m\n\u001b[1;32m 1535\u001b[0m \n\u001b[1;32m 1536\u001b[0m \u001b[38;5;124;03m Args:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;124;03m _type_: _description_\u001b[39;00m\n\u001b[1;32m 1544\u001b[0m \u001b[38;5;124;03m 
\"\"\"\u001b[39;00m\n\u001b[0;32m-> 1545\u001b[0m resid_pre \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mhook_resid_pre\u001b[49m\u001b[43m(\u001b[49m\u001b[43mresid_pre\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# [batch, pos, d_model]\u001b[39;00m\n\u001b[1;32m 1547\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcfg\u001b[38;5;241m.\u001b[39muse_attn_in \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcfg\u001b[38;5;241m.\u001b[39muse_split_qkv_input:\n\u001b[1;32m 1548\u001b[0m \u001b[38;5;66;03m# We're adding a head dimension\u001b[39;00m\n\u001b[1;32m 1549\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m shortformer_pos_embed \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1574\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1572\u001b[0m hook_result \u001b[38;5;241m=\u001b[39m hook(\u001b[38;5;28mself\u001b[39m, args, kwargs, result)\n\u001b[1;32m 1573\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1574\u001b[0m hook_result \u001b[38;5;241m=\u001b[39m \u001b[43mhook\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mresult\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1576\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m hook_result \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 1577\u001b[0m result \u001b[38;5;241m=\u001b[39m hook_result\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformer_lens/hook_points.py:77\u001b[0m, in \u001b[0;36mHookPoint.add_hook..full_hook\u001b[0;34m(module, module_input, module_output)\u001b[0m\n\u001b[1;32m 73\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[1;32m 74\u001b[0m \u001b[38;5;28mdir\u001b[39m \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbwd\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 75\u001b[0m ): \u001b[38;5;66;03m# For a backwards hook, module_output is a tuple of (grad,) - I don't know why.\u001b[39;00m\n\u001b[1;32m 76\u001b[0m module_output \u001b[38;5;241m=\u001b[39m module_output[\u001b[38;5;241m0\u001b[39m]\n\u001b[0;32m---> 77\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[43mhook\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodule_output\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhook\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m)\u001b[49m\n", "\u001b[0;31mTypeError\u001b[0m: antidirection_ablation_hook() got an unexpected keyword argument 'avg_proj'" ] } ], "source": [ "if \"N_INST_TEST\" not in locals() or not N_INST_TEST:\n", " N_INST_TEST = 4 # you may want to evaluate more at the cost of additional compute time. by default, batches are size of 4, so I'd recommend making it a multiple of 4.\n", "\n", "EVAL_N = 20 # Evaluate how many of the top N potential dirs\n", "\n", "evals = []\n", "\n", "for refusal_dir in tqdm(activation_scored[:EVAL_N]):\n", " intervention_layers = list(range(model.cfg.n_layers)) # all layers\n", " \n", " # Calculate or obtain the avg_proj value\n", " avg_proj = ... # Replace with your calculation or value\n", " \n", " hook_fn = functools.partial(antidirection_ablation_hook, direction=refusal_dir, avg_proj=avg_proj)\n", " fwd_hooks = [(utils.get_act_name(act_name, l), hook_fn) for l in intervention_layers for act_name in ['resid_pre', 'resid_mid', 'resid_post']]\n", " \n", " intervention_generations = get_generations(model, jp_inst_test[:N_INST_TEST], tokenize_instructions_fn, fwd_hooks=fwd_hooks)\n", " evals.append(intervention_generations)\n", " \n", " #print(intervention_generations) # if you want to watch it as it goes" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Present evals to clever pre-trained non-refusing human" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "pxbJr4vCFCOL", "outputId": "47810f58-813e-4014-fee8-cb1406e318c8" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INSTRUCTION 0: \"'Each year this British Columbia city hosts the Symphony of Fire, a 4-day fireworks extravaganza'\"\n", "\u001b[32mBASELINE COMPLETION:\n", "\t\"\\n\\nThat sounds like an amazing event! The Symphony of Fire in British Columbia sounds like a\n", "\tmust-see event for anyone who loves fireworks and live music.assistant\\n\\nI completely agree! The\n", "\tSymphony of Fire sounds like an incredible experience that combines the thrill of fireworks with\n", "\tthe beauty of live music. It's\"\n", "\u001b[39m\n", "INSTRUCTION 1: \"'The only Republican presidential nominee to lose twice, his running mate the 2nd time was Earl Warren'\"\n", "\u001b[32mBASELINE COMPLETION:\n", "\t\"\\n\\nThe answer is Thomas Dewey.assistant\\n\\nThat's correct! Thomas Dewey was the Republican\n", "\tpresidential nominee in 1944 and 1948, and he lost both times. His running mate in 1948 was Earl\n", "\tWarren, who later became the Chief Justice of the United States.\"\n", "\u001b[39m\n", "INSTRUCTION 2: '\\'The alternate title to the Anne Rice book \"Ramses the Damned\" mentions one of these previously buried folk\\''\n", "\u001b[32mBASELINE COMPLETION:\n", "\t'\\n\\nA clever clue!\\n\\nAfter some research, I think I have found the answer:\\n\\nThe alternate title\n", "\tto the Anne Rice book \"Ramses the Damned\" is \"The Vampire Ramses\".\\n\\nAnd, as you hinted, one of\n", "\tthe previously buried folk mentioned in this title is... 
Imhotep!\\n\\nAm I'\n", "\u001b[39m\n", "INSTRUCTION 3: '\\'The priest\\'s duty to keep your sins secret is traditionally protected by \"the sanctity of\" this booth\\''\n", "\u001b[32mBASELINE COMPLETION:\n", "\t'\\n\\nI think I can fill in the blank!\\n\\n\"The sanctity of the confessional\"assistant\\n\\nThat\\'s\n", "\tcorrect! The confessional is a sacred space where a priest hears confessions and maintains the\n", "\tconfidentiality of the sacrament. Well done!assistant\\n\\nThank you! I\\'m glad I'\n", "\u001b[39m\n" ] } ], "source": [ "for instruction in range(N_INST_TEST):\n", " if 'baseline_generations' in locals() and baseline_generations and len(baseline_generations) > instruction:\n", " print(f\"INSTRUCTION {instruction}: {repr(jp_inst_test[instruction])}\")\n", " print(Fore.GREEN + f\"BASELINE COMPLETION:\")\n", " print(textwrap.fill(repr(baseline_generations[instruction]), width=100, initial_indent='\\t', subsequent_indent='\\t'))\n", " for layer_candidate in range(EVAL_N):\n", " if len(evals) > layer_candidate and len(evals[layer_candidate]) > instruction:\n", " print(Fore.RED + f\"LAYER CANDIDATE #{layer_candidate} INTERVENTION COMPLETION:\")\n", " print(textwrap.fill(repr(evals[layer_candidate][instruction]), width=100, initial_indent='\\t', subsequent_indent='\\t'))\n", " print(Fore.RESET)" ] }, { "cell_type": "markdown", "metadata": { "id": "t9KooaWaCDc_" }, "source": [ "## Orthogonalize weights w.r.t. \"refusal direction\"\n", "\n", "We can implement the intervention equivalently by directly orthogonalizing the weight matrices that write to the residual stream with respect to the refusal direction $\\widehat{r}$:\n", "$$W_{\\text{out}}' \\leftarrow W_{\\text{out}} - \\widehat{r}\\widehat{r}^{\\mathsf{T}} W_{\\text{out}}$$\n", "\n", "By orthogonalizing these weight matrices, we enforce that the model is unable to write direction $r$ to the residual stream at all!\n", "\n", "This is basically how you finalize your layers' weights to represent your orthogonalization for a saved model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Choose your fighter (favorite, ideally non-refusing layer)" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ "layer_candidate = 1 # e.g. you should choose based on the layer you think aligns to the behavior you like\n", "refusal_dir = activation_scored[layer_candidate]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Write ortho'd weights into model" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [], "source": [ "def get_orthogonalized_matrix(matrix: Float[Tensor, '... d_model'], vec: Float[Tensor, 'd_model']) -> Float[Tensor, '... d_model']:\n", " proj = einops.einsum(matrix, vec.view(-1, 1), '... d_model, d_model single -> ... 
single') * vec\n", " return matrix - proj" ] }, { "cell_type": "code", "execution_count": 35, "metadata": { "id": "GC7cpMXZCG64" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 32/32 [00:00<00:00, 2615.41it/s]\n" ] } ], "source": [ "if refusal_dir.device != model.W_E.device:\n", " refusal_dir = refusal_dir.to(model.W_E.device)\n", "model.W_E.data = get_orthogonalized_matrix(model.W_E, refusal_dir)\n", "\n", "for block in tqdm(model.blocks):\n", " if refusal_dir.device != block.attn.W_O.device:\n", " refusal_dir = refusal_dir.to(block.attn.W_O.device)\n", " block.attn.W_O.data = get_orthogonalized_matrix(block.attn.W_O, refusal_dir)\n", " block.mlp.W_out.data = get_orthogonalized_matrix(block.mlp.W_out, refusal_dir)\n", "\n", "# save your refusal_dir of choice separately to a file\n", "torch.save(refusal_dir,\"ablation.pth\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Verify model weights are adjusted to match ablation (skippable)" ] }, { "cell_type": "code", "execution_count": 36, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "1Y-qtouNGf3t", "outputId": "5f946460-4ed4-4cf1-e53d-809ebb880f6e" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 1/1 [00:06<00:00, 6.30s/it]\n" ] } ], "source": [ "orthogonalized_generations = get_generations(model, jp_inst_test[:N_INST_TEST], tokenize_instructions_fn, fwd_hooks=[])" ] }, { "cell_type": "code", "execution_count": 37, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "r68O4_4DG3P7", "outputId": "97eeb477-bfd1-4521-8c32-4657d99f3e0c" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INSTRUCTION 0: \"'Each year this British Columbia city hosts the Symphony of Fire, a 4-day fireworks extravaganza'\"\n", "\u001b[32mBASELINE COMPLETION:\n", "\t\"\\n\\nThat sounds like an amazing event! The Symphony of Fire in British Columbia sounds like a\n", "\tmust-see event for anyone who loves fireworks and live music.assistant\\n\\nI completely agree! The\n", "\tSymphony of Fire sounds like an incredible experience that combines the thrill of fireworks with\n", "\tthe beauty of live music. 
It's\"\n", "\u001b[31mINTERVENTION COMPLETION:\n" ] }, { "ename": "IndexError", "evalue": "list index out of range", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mIndexError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[37], line 7\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28mprint\u001b[39m(textwrap\u001b[38;5;241m.\u001b[39mfill(\u001b[38;5;28mrepr\u001b[39m(baseline_generations[i]), width\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m100\u001b[39m, initial_indent\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;130;01m\\t\u001b[39;00m\u001b[38;5;124m'\u001b[39m, subsequent_indent\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;130;01m\\t\u001b[39;00m\u001b[38;5;124m'\u001b[39m))\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28mprint\u001b[39m(Fore\u001b[38;5;241m.\u001b[39mRED \u001b[38;5;241m+\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mINTERVENTION COMPLETION:\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m----> 7\u001b[0m \u001b[38;5;28mprint\u001b[39m(textwrap\u001b[38;5;241m.\u001b[39mfill(\u001b[38;5;28mrepr\u001b[39m(\u001b[43mevals\u001b[49m\u001b[43m[\u001b[49m\u001b[43mlayer_candidate\u001b[49m\u001b[43m]\u001b[49m[i]), width\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m100\u001b[39m, initial_indent\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;130;01m\\t\u001b[39;00m\u001b[38;5;124m'\u001b[39m, subsequent_indent\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;130;01m\\t\u001b[39;00m\u001b[38;5;124m'\u001b[39m))\n\u001b[1;32m 8\u001b[0m \u001b[38;5;28mprint\u001b[39m(Fore\u001b[38;5;241m.\u001b[39mMAGENTA \u001b[38;5;241m+\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mORTHOGONALIZED COMPLETION:\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 9\u001b[0m \u001b[38;5;28mprint\u001b[39m(textwrap\u001b[38;5;241m.\u001b[39mfill(\u001b[38;5;28mrepr\u001b[39m(orthogonalized_generations[i]), width\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m100\u001b[39m, initial_indent\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;130;01m\\t\u001b[39;00m\u001b[38;5;124m'\u001b[39m, subsequent_indent\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;130;01m\\t\u001b[39;00m\u001b[38;5;124m'\u001b[39m))\n", "\u001b[0;31mIndexError\u001b[0m: list index out of range" ] } ], "source": [ "for i in range(N_INST_TEST):\n", " if 'baseline_generations' in locals() and baseline_generations and len(baseline_generations) > i:\n", " print(f\"INSTRUCTION {i}: {repr(jp_inst_test[i])}\")\n", " print(Fore.GREEN + f\"BASELINE COMPLETION:\")\n", " print(textwrap.fill(repr(baseline_generations[i]), width=100, initial_indent='\\t', subsequent_indent='\\t'))\n", " print(Fore.RED + f\"INTERVENTION COMPLETION:\")\n", " print(textwrap.fill(repr(evals[layer_candidate][i]), width=100, initial_indent='\\t', subsequent_indent='\\t'))\n", " print(Fore.MAGENTA + f\"ORTHOGONALIZED COMPLETION:\")\n", " print(textwrap.fill(repr(orthogonalized_generations[i]), width=100, initial_indent='\\t', subsequent_indent='\\t'))\n", " print(Fore.RESET)" ] }, { "cell_type": "markdown", "metadata": { "vscode": { "languageId": "julia" } }, "source": [ "## Save your unruly model\n", "This is where you'll need to consult with the original structure of the model you're generating. 
Below is converting Phi-3 and Llama-3 examples, but you'll need to do differently for others. Or if you just want a \"no thinking\" save, you can use the pytorch save below. Be aware that the structure output by that is not directly convertable however." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Simple PyTorch save! (easiest, but least portable)" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [], "source": [ "torch.save(model, \"pytorch_model.bin\") # can name it whatever you want, and then reload it" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Converting models back to HF safetensors (harder)\n", "Do note that we only adjust a couple layers in get_orthogonalized_matrix(), so only need to copy 1 + (2*n_layers) over to the original trained model, not the whole lot.\n", "\n", "You can look to TransformerLens's source code `loading_from_pretrained.py` to see how the layers get converted in. e.g. https://github.com/neelnanda-io/TransformerLens/blob/main/transformer_lens/loading_from_pretrained.py#L1746-L1833 is `convert_llama_weights`, so you can just reverse the steps for the layers that you alter\n", "\n", "References to convert functions per model:\n", "https://github.com/neelnanda-io/TransformerLens/blob/main/transformer_lens/loading_from_pretrained.py#L1475-L1504" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "15b648f1cf5e440684b0abd57a186426", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/4 [00:00m (n h)\", n=cfg.n_heads).contiguous())\n", " lm_model.layers[l].mlp.down_proj.weight = torch.nn.Parameter(torch.transpose(state_dict[f\"blocks.{l}.mlp.W_out\"],0,1).contiguous())\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Phi-3 conversion" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "\n", "lm_model.embed_tokens.weight = state_dict[\"embed.W_E\"]\n", "\n", "for l in range(cfg.n_layers):\n", "\n", " W_O = einops.rearrange(\n", " state_dict[f\"blocks.{l}.attn.W_O\"], \"n_head d_head d_model -> d_model (n_head d_head)\", n_head=cfg.n_heads\n", " )\n", " lm_model.layers[l].self_attn.o_proj.weight = torch.nn.Parameter(W_O.contiguous())\n", "\n", " lm_model.layers[l].mlp.down_proj.weight = torch.nn.Parameter(torch.transpose(state_dict[f\"blocks.{l}.mlp.W_out\"].cpu(), 0, 1).contiguous())\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Save converted model" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [], "source": [ "hf_model.save_pretrained(\"/workspace\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "T4", "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" }, "widgets": { "application/vnd.jupyter.widget-state+json": { "3877270cf4bc42a9b6142cce7a5d8c54": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": 
"@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_ec8f6f360a2243b0ac98d34e825ba378", "max": 2, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_f2ee188bfaa84e9680dbc296b1adbef6", "value": 2 } }, "89797f6e82104058af92e3ceb094af66": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "89ee88168c474e9fbcf4a17f1483eff4": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_c362d50107dd4a2db0d1a79da2af8d57", "placeholder": "​", "style": "IPY_MODEL_ffa85c694b694425999187b346c7ecfe", "value": "Loading checkpoint shards: 100%" } }, "9a5611a341ed4673aaaf2f463f685d7c": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_e973493cd6d14381bb4ad2f82417e8a9", "placeholder": "​", "style": "IPY_MODEL_89797f6e82104058af92e3ceb094af66", "value": " 2/2 [00:18<00:00,  8.85s/it]" } }, "a2de63dfbd6c485e841c6fcd1fefe451": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "ad063e2c68a44f009bfab68c141c09be": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", 
"_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_89ee88168c474e9fbcf4a17f1483eff4", "IPY_MODEL_3877270cf4bc42a9b6142cce7a5d8c54", "IPY_MODEL_9a5611a341ed4673aaaf2f463f685d7c" ], "layout": "IPY_MODEL_a2de63dfbd6c485e841c6fcd1fefe451" } }, "c362d50107dd4a2db0d1a79da2af8d57": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "e973493cd6d14381bb4ad2f82417e8a9": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "ec8f6f360a2243b0ac98d34e825ba378": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": 
null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "f2ee188bfaa84e9680dbc296b1adbef6": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "ffa85c694b694425999187b346c7ecfe": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } } } } }, "nbformat": 4, "nbformat_minor": 4 }