{ "cells": [ { "cell_type": "markdown", "id": "b3fc8862-0c2b-45f3-badf-e591c7b8f891", "metadata": {}, "source": [ "# Token Count Exploration\n", "It would be really useful for deployment to know our input/output token expectations. We know that our output is quite verbose relative to the input, since the explanations are long. With a model like `mistralai/Mistral-7B-Instruct-v0.3`, I'd expect our real output with explanations to be shorter. That's perfect, since our training data will give us a reliable upper bound, which is great for preventing truncation.\n", "\n", "Let's figure out how to split input and output tokens, and then we can build a histogram." ] }, { "cell_type": "markdown", "id": "3a501f2f-ba98-4c0f-aa30-f4768bd80dcb", "metadata": {}, "source": [ "## Config" ] }, { "cell_type": "code", "execution_count": 1, "id": "5d0bd22f-293e-4c15-9dfe-8070553f42b5", "metadata": { "tags": [] }, "outputs": [], "source": [ "INPUT_DATASET = 'derek-thomas/labeled-multiple-choice-explained-mistral-tokenized'\n", "BASE_MODEL = 'mistralai/Mistral-7B-Instruct-v0.3'" ] }, { "cell_type": "markdown", "id": "c1c3b00c-17bf-4b00-9ee7-d10c598c53e9", "metadata": {}, "source": [ "## Setup" ] }, { "cell_type": "code", "execution_count": 2, "id": "af2330f3-403c-401c-8028-46ae4971546e", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d675da3076694064ba0c69ed97f938f8", "version_major": 2, "version_minor": 0 }, "text/plain": [ "VBox(children=(HTML(value='
[INST] Answer the Question and include your Reasoning and the Final Answer in a json like: {\"Reasoning: \"...\", \"Final Answer\": \"x\"} where x is a letter that corresponds to the answer choice which is a letter between a and h.\\nQuestion: What can genetic material have?\\nAnswer Choices: (a) Resistance (b) Mutations (c) Clorophyll (d) Nucleotide (e) Symmetry (f) Allow growth (g) Contamination (h) Warmth[/INST] {\\'Reasoning\\': \\'a) Resistance: Genetic material can carry genes that provide resistance to certain diseases or environmental factors, but this is not a characteristic of genetic material itself. Therefore, this option is incorrect.\\\\n\\\\nc) Chlorophyll: Chlorophyll is a pigment found in plants that is responsible for photosynthesis. It is not a characteristic of genetic material. Therefore, this option is incorrect.\\\\n\\\\nd) Nucleotide: Nucleotides are the building blocks of DNA and RNA, which are types of genetic material. However, this option is too broad and does not fully answer the question. Therefore, this option is incorrect.\\\\n\\\\ne) Symmetry: Symmetry is a characteristic of physical objects and organisms, but it is not a characteristic of genetic material. Therefore, this option is incorrect.\\\\n\\\\nf) Allow growth: Genetic material provides the instructions for the growth and development of organisms, but it is not a characteristic of genetic material itself. Therefore, this option is incorrect.\\\\n\\\\ng) Contamination: Contamination is the presence of unwanted substances or impurities, and it is not a characteristic of genetic material. Therefore, this option is incorrect.\\\\n\\\\nh) Warmth: Warmth is a physical property of objects and is not related to genetic material. Therefore, this option is incorrect.\\\\n\\\\nIn conclusion, the only option that correctly describes a characteristic of genetic material is b) mutations. 
Genetic material can have mutations, which are changes in the DNA sequence that can lead to genetic variation and evolution.\\', \\'Final Answer\\': \\'b\\'}'" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df['conversation_RFA_sg_gpt3_5'].iloc[0]" ] }, { "cell_type": "code", "execution_count": 6, "id": "0dc985d7-32e3-413f-8640-55829da19838", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[4]" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokenizer.encode('[/INST]', add_special_tokens=False)" ] }, { "cell_type": "code", "execution_count": 8, "id": "bc9b3856-7652-483c-8dbc-2b9bdc85f9d7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[1, 3, 27075, 1040, 23246, 1072, 3792, 1342, 2066, 2180, 1056, 1072, 1040, 10990, 27075, 1065, 1032, 8379, 1505, 29515, 10598, 20569, 1056, 29515, 1113, 1869, 1316, 1113, 18268, 27075, 2032, 1113, 29512, 18163, 1738, 2086, 1117, 1032, 6266, 1137, 17303, 1066, 1040, 5140, 5550, 1458, 1117, 1032, 6266, 2212, 1032, 1072, 1063, 29491, 781, 25762, 29515, 2592, 1309, 20637, 4156, 1274, 29572, 781, 3588, 17749, 26173, 1982, 29515, 1093, 29476, 29499, 2760, 5400, 1093, 29494, 29499, 17737, 1465, 1093, 29485, 29499, 2134, 1039, 3894, 20298, 1093, 29483, 29499, 1186, 2253, 1059, 1090, 1315, 1093, 29474, 29499, 13124, 17409, 1093, 29490, 29499, 26780, 6825, 1093, 29489, 29499, 3767, 26682, 1093, 29484, 29499, 1162, 2553, 1130, 4, 12780, 20569, 1056, 2637, 1232, 29476, 29499, 2760, 5400, 29515, 7010, 11130, 4156, 1309, 7864, 24971, 1137, 3852, 13336, 1066, 3320, 19025, 1210, 13275, 9380, 29493, 1330, 1224, 1117, 1227, 1032, 18613, 1070, 20637, 4156, 4605, 29491, 9237, 29493, 1224, 4319, 1117, 17158, 6691, 29479, 29524, 15538, 29499, 1457, 6406, 3894, 20298, 29515, 1457, 6406, 3894, 20298, 1117, 1032, 19726, 1234, 2187, 1065, 10691, 1137, 1117, 8100, 1122, 9654, 29492, 1216, 22305, 29491, 1429, 1117, 1227, 
1032, 18613, 1070, 20637, 4156, 29491, 9237, 29493, 1224, 4319, 1117, 17158, 6691, 29479, 29524, 1060, 29499, 1186, 2253, 1059, 1090, 1315, 29515, 1186, 2253, 1059, 1090, 2694, 1228, 1040, 4435, 10246, 1070, 16775, 1072, 1167, 4152, 29493, 1458, 1228, 5282, 1070, 20637, 4156, 29491, 3761, 29493, 1224, 4319, 1117, 2136, 6609, 1072, 2003, 1227, 6662, 5140, 1040, 3764, 29491, 9237, 29493, 1224, 4319, 1117, 17158, 6691, 29479, 29524, 1253, 29499, 13124, 17409, 29515, 13124, 17409, 1117, 1032, 18613, 1070, 6045, 7465, 1072, 2938, 11589, 29493, 1330, 1146, 1117, 1227, 1032, 18613, 1070, 20637, 4156, 29491, 9237, 29493, 1224, 4319, 1117, 17158, 6691, 29479, 29524, 24412, 29499, 26780, 6825, 29515, 7010, 11130, 4156, 6080, 1040, 12150, 1122, 1040, 6825, 1072, 4867, 1070, 2938, 11589, 29493, 1330, 1146, 1117, 1227, 1032, 18613, 1070, 20637, 4156, 4605, 29491, 9237, 29493, 1224, 4319, 1117, 17158, 6691, 29479, 29524, 1585, 29499, 3767, 26682, 29515, 3767, 26682, 1117, 1040, 7471, 1070, 13460, 8034, 1851, 9500, 1210, 3592, 1092, 1986, 29493, 1072, 1146, 1117, 1227, 1032, 18613, 1070, 20637, 4156, 29491, 9237, 29493, 1224, 4319, 1117, 17158, 6691, 29479, 29524, 25779, 29499, 1162, 2553, 1130, 29515, 1162, 2553, 1130, 1117, 1032, 6045, 4089, 1070, 7465, 1072, 1117, 1227, 5970, 1066, 20637, 4156, 29491, 9237, 29493, 1224, 4319, 1117, 17158, 6691, 29479, 29524, 29479, 1425, 13654, 29493, 1040, 1633, 4319, 1137, 13510, 14734, 1032, 18613, 1070, 20637, 4156, 1117, 1055, 29499, 5316, 1465, 29491, 7010, 11130, 4156, 1309, 1274, 5316, 1465, 29493, 1458, 1228, 5203, 1065, 1040, 16775, 8536, 1137, 1309, 2504, 1066, 20637, 19191, 1072, 10963, 13775, 1232, 18268, 27075, 2637, 1232, 29494, 15259, 2]\n" ] }, { "data": { "text/plain": [ "[4]" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print(tokenizer.encode(df['conversation_RFA_sg_gpt3_5'].iloc[0], add_special_tokens=False))\n", "tokenizer.encode('[/INST]', add_special_tokens=False)" ] }, { 
"cell_type": "markdown", "id": "677e792e-a85f-448c-ab36-ed0aec84ca8e", "metadata": {}, "source": [ "Great, we can see that `[/INST]` is a single special token (id `4`), which is exactly what we want to split on. The tokens up to and including `[/INST]` are our input tokens, and the tokens after it are our output tokens.\n", "\n", "Let's count those for each row in `conversation_RFA_sg_gpt3_5` and `conversation_RFA_sg_mistral` and build a histogram of the results. The `conversation_RFA` columns should give us a good max, since they are just a reshuffle or superset of the other columns." ] }, { "cell_type": "code", "execution_count": 20, "id": "3c8cd920-4d58-4b1d-b172-098c35dcdfbf", "metadata": { "scrolled": true }, "outputs": [], "source": [ "import pandas as pd\n", "from datasets import load_dataset\n", "from transformers import AutoTokenizer\n", "\n", "# Load the dataset and convert it to a DataFrame\n", "dataset = load_dataset(INPUT_DATASET, split='test')\n", "df = dataset.to_pandas()\n", "\n", "# Same tokenizer as above, reloaded here so this cell stands alone\n", "tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)\n", "# Token id of the [/INST] delimiter (4 for this tokenizer, as shown above)\n", "INST_END_ID = tokenizer.encode('[/INST]', add_special_tokens=False)[0]\n", "\n", "# The conversations already begin with <s> (token 1 above), so don't let encode() add a second BOS\n", "df_token_gpt3_5 = df[['conversation_RFA_sg_gpt3_5']].copy()\n", "df_token_gpt3_5['tokens_gpt3_5'] = df['conversation_RFA_sg_gpt3_5'].apply(lambda x: tokenizer.encode(x, add_special_tokens=False))\n", "\n", "df_token_mistral = df[['conversation_RFA_sg_mistral']].copy()\n", "df_token_mistral['tokens_mistral'] = df['conversation_RFA_sg_mistral'].apply(lambda x: tokenizer.encode(x, add_special_tokens=False))\n", "\n", "def split_and_measure(lst):\n", "    # Input = tokens up to and including the first [/INST]; output = everything after it\n", "    if INST_END_ID in lst:\n", "        length_before = lst.index(INST_END_ID) + 1  # Include [/INST] in the input\n", "        length_after = len(lst) - length_before\n", "        return length_before, length_after\n", "    else:\n", "        return None, len(lst)  # No [/INST] found, so count everything as output\n", "\n", "df_token_gpt3_5[['input_tokens', 'output_tokens']] = df_token_gpt3_5['tokens_gpt3_5'].apply(split_and_measure).apply(pd.Series)\n", "df_token_mistral[['input_tokens', 'output_tokens']] = df_token_mistral['tokens_mistral'].apply(split_and_measure).apply(pd.Series)" ] }, { "cell_type": "code", "execution_count": 22, "id": "9b23b7a3-5448-4b2e-9253-5d1b66ef1e0a", "metadata": {}, "outputs": [
{ "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "\n", "# Plot the histograms\n", "plt.figure(figsize=(10, 6))\n", "\n", "# Histogram for Input Tokens\n", "plt.hist(df_token_gpt3_5['input_tokens'], bins=10, alpha=0.6, label='Input Tokens')\n", "\n", "# Histogram for Output Tokens\n", "plt.hist(df_token_gpt3_5['output_tokens'], bins=10, alpha=0.6, label='Output Tokens')\n", "\n", "# Add titles and labels\n", "plt.title(\"Token Summary: conversation_RFA_sg_gpt3_5\")\n", "plt.xlabel(\"Token Count\")\n", "plt.ylabel(\"Frequency\")\n", "plt.legend()\n", "\n", "# Show the plot\n", "plt.show()\n" ] }, { "cell_type": "code", "execution_count": 23, "id": "9d81d486-bafd-454b-9a44-934ec111ad4d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Our Max Input Tokens:\t162\n", "Our Max Output Tokens:\t572\n" ] } ], "source": [ "print(f\"Our Max Input Tokens:\\t{max(df_token_gpt3_5.input_tokens)}\\nOur Max Output Tokens:\\t{max(df_token_gpt3_5.output_tokens)}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "e6e235c3-75f4-48dd-b0cb-d7cc42426e69", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 24, "id": "7dea222b-a974-4ff6-9e3c-07de766b76c4", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "\n", "# Plot the histograms\n", "plt.figure(figsize=(10, 6))\n", "\n", "# Histogram for Input Tokens\n", "plt.hist(df_token_mistral['input_tokens'], bins=10, alpha=0.6, label='Input Tokens')\n", "\n", "# Histogram for Output Tokens\n", "plt.hist(df_token_mistral['output_tokens'], bins=10, alpha=0.6, label='Output Tokens')\n", "\n", "# Add titles and labels\n", "plt.title(\"Token Summary: conversation_RFA_sg_mistral\")\n", "plt.xlabel(\"Token Count\")\n", "plt.ylabel(\"Frequency\")\n", "plt.legend()\n", "\n", "# Show the plot\n", "plt.show()\n" ] }, { "cell_type": "code", "execution_count": 26, "id": "d6a78d92-2fc4-4354-8825-b17cba59eee4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Our Max Input Tokens:\t162\n", "Our Max Output Tokens:\t1148\n" ] } ], "source": [ "print(f\"Our Max Input Tokens:\\t{max(df_token_mistral.input_tokens)}\\nOur Max Output Tokens:\\t{max(df_token_mistral.output_tokens)}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "fdfc0581-1c57-436c-8c76-9bfeab278603", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.10" } }, "nbformat": 4, "nbformat_minor": 5 }