{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "99FBiGH7bsfn"
   },
   "source": [
    "# Compiling \u0026 Visualizing Tracr Models\n",
    "\n",
    "This notebook demonstrates how to compile a tracr model and provides some tools to visualize the model's residual stream or layer outputs for a given input sequence."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "qm-PM1PEawCx"
   },
   "outputs": [],
   "source": [
    "#@title Imports\n",
    "import jax\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# The default of float16 can lead to discrepancies between outputs of\n",
    "# the compiled model and the RASP program.\n",
    "jax.config.update('jax_default_matmul_precision', 'float32')\n",
    "\n",
    "from tracr.compiler import compiling\n",
    "from tracr.compiler import lib\n",
    "from tracr.rasp import rasp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "HtOAc_yWawFR"
   },
   "outputs": [],
   "source": [
    "#@title Plotting functions\n",
    "def tidy_label(label, value_width=5):\n",
    "  \"\"\"Formats a 'name:value' label with the value right-aligned.\n",
    "\n",
    "  Args:\n",
    "    label: Label string, optionally of the form 'name:value'.\n",
    "    value_width: Minimum width the value part is right-aligned to.\n",
    "\n",
    "  Returns:\n",
    "    The name, a ':' and the value padded on the left to `value_width`\n",
    "    characters. A label without ':' gets an empty value, i.e. only padding.\n",
    "  \"\"\"\n",
    "  # Split on the first ':' only, so that a value which itself contains ':'\n",
    "  # does not break the two-way unpacking.\n",
    "  name, _, value = label.partition(':')\n",
    "  return name + f\":{value:\u003e{value_width}}\"\n",
    "\n",
    "\n",
    "def add_residual_ticks(model, value_width=5, x=False, y=True):\n",
    "  \"\"\"Labels the ticks of the current axes with the model's residual labels.\n",
    "\n",
    "  Args:\n",
    "    model: Object exposing `residual_labels`, a sequence of label strings.\n",
    "    value_width: Width passed to `tidy_label` to align the label values.\n",
    "    x: Whether to label the x-axis (labels are rotated by 90 degrees).\n",
    "    y: Whether to label the y-axis.\n",
    "  \"\"\"\n",
    "  # Ticks are placed at half-integer positions, one per residual label.\n",
    "  positions = np.arange(len(model.residual_labels)) + 0.5\n",
    "  labels = [tidy_label(l, value_width=value_width)\n",
    "            for l in model.residual_labels]\n",
    "  if y:\n",
    "    plt.yticks(\n",
    "        positions,\n",
    "        labels,\n",
    "        family='monospace',\n",
    "        fontsize=20,\n",
    "    )\n",
    "  if x:\n",
    "    plt.xticks(\n",
    "        positions,\n",
    "        labels,\n",
    "        family='monospace',\n",
    "        rotation=90,\n",
    "        fontsize=20,\n",
    "    )\n",
    "\n",
    "\n",
    "def plot_computation_trace(model,\n",
    "                           input_labels,\n",
    "                           residuals_or_outputs,\n",
    "                           add_input_layer=False,\n",
    "                           figsize=(12, 9)):\n",
    "  \"\"\"Plots one heatmap per layer, side by side, with a shared y-axis.\n",
    "\n",
    "  Args:\n",
    "    model: Model whose `residual_labels` are used as y-axis tick labels.\n",
    "    input_labels: Input tokens, used as x-axis tick labels of every panel.\n",
    "    residuals_or_outputs: Sequence of per-layer arrays. For each array, the\n",
    "      first entry along its leading axis is plotted, transposed.\n",
    "    add_input_layer: If True, the first panel is titled 'Input' and layer\n",
    "      numbering starts at the second panel.\n",
    "    figsize: Figure size passed to `plt.subplots`.\n",
    "  \"\"\"\n",
    "  # squeeze=False always yields a 2D array of axes, so a single panel\n",
    "  # (ncols == 1) can be iterated just like several panels.\n",
    "  _, axes = plt.subplots(\n",
    "      nrows=1,\n",
    "      ncols=len(residuals_or_outputs),\n",
    "      figsize=figsize,\n",
    "      sharey=True,\n",
    "      squeeze=False)\n",
    "  value_width = max(map(len, map(str, input_labels))) + 1\n",
    "\n",
    "  for i, (layer, ax) in enumerate(zip(residuals_or_outputs, axes[0])):\n",
    "    plt.sca(ax)\n",
    "    # Colour scale is fixed to [0, 1] so all panels are comparable.\n",
    "    plt.pcolormesh(layer[0].T, vmin=0, vmax=1)\n",
    "    if i == 0:\n",
    "      # The y-axis is shared, so only the first panel needs labels.\n",
    "      add_residual_ticks(model, value_width=value_width)\n",
    "    plt.xticks(\n",
    "        np.arange(len(input_labels))+0.5,\n",
    "        input_labels,\n",
    "        rotation=90,\n",
    "        fontsize=20,\n",
    "    )\n",
    "    if add_input_layer and i == 0:\n",
    "      title = 'Input'\n",
    "    else:\n",
    "      layer_no = i - 1 if add_input_layer else i\n",
    "      # Panels are titled as alternating Attn / MLP layers, starting with Attn.\n",
    "      layer_type = 'Attn' if layer_no % 2 == 0 else 'MLP'\n",
    "      title = f'{layer_type} {layer_no // 2 + 1}'\n",
    "    plt.title(title, fontsize=20)\n",
    "\n",
    "\n",
    "def plot_residuals_and_input(model, inputs, figsize=(12, 9)):\n",
    "  \"\"\"Applies model to inputs, and plots the residual stream at each layer.\n",
    "\n",
    "  The input embeddings are prepended as an extra first panel titled 'Input'.\n",
    "\n",
    "  Args:\n",
    "    model: Assembled model providing `apply` and `residual_labels`.\n",
    "    inputs: Input token sequence passed to `model.apply`.\n",
    "    figsize: Figure size passed on to `plot_computation_trace`.\n",
    "  \"\"\"\n",
    "  # Use the `model` argument, not a notebook-level global, so the function\n",
    "  # plots the model it was given.\n",
    "  model_out = model.apply(inputs)\n",
    "  residuals = np.concatenate([model_out.input_embeddings[None, ...],\n",
    "                              model_out.residuals], axis=0)\n",
    "  plot_computation_trace(\n",
    "      model=model,\n",
    "      input_labels=inputs,\n",
    "      residuals_or_outputs=residuals,\n",
    "      add_input_layer=True,\n",
    "      figsize=figsize)\n",
    "\n",
    "\n",
    "def plot_layer_outputs(model, inputs, figsize=(12, 9)):\n",
    "  \"\"\"Applies model to inputs, and plots the outputs of each layer.\n",
    "\n",
    "  Args:\n",
    "    model: Assembled model providing `apply` and `residual_labels`.\n",
    "    inputs: Input token sequence passed to `model.apply`.\n",
    "    figsize: Figure size passed on to `plot_computation_trace`.\n",
    "  \"\"\"\n",
    "  # Use the `model` argument, not a notebook-level global, so the function\n",
    "  # plots the model it was given.\n",
    "  model_out = model.apply(inputs)\n",
    "  plot_computation_trace(\n",
    "      model=model,\n",
    "      input_labels=inputs,\n",
    "      residuals_or_outputs=model_out.layer_outputs,\n",
    "      add_input_layer=False,\n",
    "      figsize=figsize)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "8hV0nv_ISmhM"
   },
   "outputs": [],
   "source": [
    "#@title Define RASP programs\n",
    "def get_program(program_name, max_seq_len):\n",
    "  \"\"\"Returns RASP program and corresponding token vocabulary.\n",
    "\n",
    "  Args:\n",
    "    program_name: One of 'length', 'frac_prevs', 'dyck-2', 'dyck-3', 'sort',\n",
    "      'sort_unique', 'hist', 'sort_freq' or 'pair_balance'.\n",
    "    max_seq_len: Maximum sequence length; only passed on to the 'sort' and\n",
    "      'sort_freq' programs.\n",
    "\n",
    "  Returns:\n",
    "    A tuple `(program, vocab)` where `vocab` is a set of input tokens.\n",
    "\n",
    "  Raises:\n",
    "    NotImplementedError: If `program_name` is not one of the names above.\n",
    "  \"\"\"\n",
    "  if program_name == \"length\":\n",
    "    vocab = {\"a\", \"b\", \"c\", \"d\"}\n",
    "    program = lib.make_length()\n",
    "  elif program_name == \"frac_prevs\":\n",
    "    vocab = {\"a\", \"b\", \"c\", \"x\"}\n",
    "    program = lib.make_frac_prevs((rasp.tokens == \"x\").named(\"is_x\"))\n",
    "  elif program_name == \"dyck-2\":\n",
    "    vocab = {\"(\", \")\", \"{\", \"}\"}\n",
    "    program = lib.make_shuffle_dyck(pairs=[\"()\", \"{}\"])\n",
    "  elif program_name == \"dyck-3\":\n",
    "    vocab = {\"(\", \")\", \"{\", \"}\", \"[\", \"]\"}\n",
    "    program = lib.make_shuffle_dyck(pairs=[\"()\", \"{}\", \"[]\"])\n",
    "  elif program_name == \"sort\":\n",
    "    vocab = {1, 2, 3, 4, 5}\n",
    "    program = lib.make_sort(\n",
    "        rasp.tokens, rasp.tokens, max_seq_len=max_seq_len, min_key=1)\n",
    "  elif program_name == \"sort_unique\":\n",
    "    vocab = {1, 2, 3, 4, 5}\n",
    "    program = lib.make_sort_unique(rasp.tokens, rasp.tokens)\n",
    "  elif program_name == \"hist\":\n",
    "    vocab = {\"a\", \"b\", \"c\", \"d\"}\n",
    "    program = lib.make_hist()\n",
    "  elif program_name == \"sort_freq\":\n",
    "    vocab = {\"a\", \"b\", \"c\", \"d\"}\n",
    "    program = lib.make_sort_freq(max_seq_len=max_seq_len)\n",
    "  elif program_name == \"pair_balance\":\n",
    "    vocab = {\"(\", \")\"}\n",
    "    program = lib.make_pair_balance(\n",
    "        sop=rasp.tokens, open_token=\"(\", close_token=\")\")\n",
    "  else:\n",
    "    raise NotImplementedError(f\"Program {program_name} not implemented.\")\n",
    "  return program, vocab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "L_m_ufaua9ri"
   },
   "outputs": [],
   "source": [
    "#@title Assemble model\n",
    "program_name = \"sort_unique\"  #@param [\"length\", \"frac_prevs\", \"dyck-2\", \"dyck-3\", \"sort\", \"sort_unique\", \"hist\", \"sort_freq\", \"pair_balance\"]\n",
    "max_seq_len = 5  #@param {label: \"Test\", type: \"integer\"}\n",
    "\n",
    "program, vocab = get_program(program_name=program_name,\n",
    "                             max_seq_len=max_seq_len)\n",
    "\n",
    "print(\"Compiling...\")\n",
    "print(f\"   Program: {program_name}\")\n",
    "print(f\"   Input vocabulary: {vocab}\")\n",
    "print(f\"   Context size: {max_seq_len}\")\n",
    "\n",
    "assembled_model = compiling.compile_rasp_to_model(\n",
    "    program=program,\n",
    "    vocab=vocab,\n",
    "    max_seq_len=max_seq_len,\n",
    "    causal=False,\n",
    "    compiler_bos=\"bos\",\n",
    "    compiler_pad=\"pad\",\n",
    "    mlp_exactness=100)\n",
    "\n",
    "print(\"Done.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "wtwiE-JiXF3F"
   },
   "outputs": [],
   "source": [
    "#@title Forward pass\n",
    "assembled_model.apply([\"bos\", 3, 4, 1]).decoded"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "RkEkVcEHa2gf"
   },
   "outputs": [],
   "source": [
    "#@title Plot residual stream\n",
    "plot_residuals_and_input(\n",
    "    model=assembled_model,\n",
    "    inputs=[\"bos\", 3, 4, 1],\n",
    "    figsize=(10, 9)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "8c4LakWHa4ey"
   },
   "outputs": [],
   "source": [
    "#@title Plot layer outputs\n",
    "plot_layer_outputs(\n",
    "    model=assembled_model,\n",
    "    inputs=[\"bos\", 3, 4, 1],\n",
    "    figsize=(8, 9)\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "private_outputs": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}