{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "Gender Over Time", "provenance": [], "collapsed_sections": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU" }, "cells": [
{ "cell_type": "markdown", "metadata": { "id": "mGDHOsFEIvKY" }, "source": [ "# [pair.withgoogle.com/explorables/fill-in-the-blank](https://pair.withgoogle.com/explorables/fill-in-the-blank)\n", "\n", "`Runtime -> Run all` to generate the plots in the \"Appendix: Differences Over Time\" section. \n", "\n", "In addition to the difference between sentence 0 and sentence 1, the logits of the top tokens over time for sentence 0 and sentence 1 are also shown here. " ] },
{ "cell_type": "markdown", "metadata": { "id": "ULz91t5Mfsfh" }, "source": [ "# Helpers" ] },
{ "cell_type": "code", "metadata": { "id": "OQvEH3U6Q_OE" }, "source": [ "%%capture\n", "\n", "import os\n", "import torch\n", "!pip install transformers\n", "from transformers import (BertForMaskedLM, BertTokenizer)\n", "import numpy as np\n", "import pandas as pd\n", "import IPython\n", "from google.colab import output" ], "execution_count": 1, "outputs": [] },
{ "cell_type": "code", "metadata": { "id": "9bKnXE1DRAvx" }, "source": [ "%%capture\n", "\n", "modelpath_bert_large = \"bert-large-uncased\"\n", "tokenizer = BertTokenizer.from_pretrained(modelpath_bert_large)\n", "model = BertForMaskedLM.from_pretrained(modelpath_bert_large)\n", "model.eval()\n", "\n", "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n", "model = model.to(device)" ], "execution_count": 2, "outputs": [] },
{ "cell_type": "code", "metadata": { "id": "yoRggB_YgVgB" }, "source": [ "" ], "execution_count": 2, "outputs": [] },
{ "cell_type": "code", "metadata": { "id": "3YsB2WUJfu3i" }, "source": [
"def calcYearEmbeds(sentence):\n",
"    \"\"\"Return the logits BERT predicts at [MASK], one row per year.\n",
"\n",
"    Every 'YEAR' in `sentence` is replaced by each year in\n",
"    range(minYear, maxYear) (notebook globals; maxYear is excluded) and\n",
"    all variants are sent through `model` as a single batch.\n",
"\n",
"    Returns a NumPy array with one row per year; each row is the model\n",
"    output at the [MASK] position.\n",
"    \"\"\"\n",
"    sentenceTokens = [\n",
"        tokenizer.encode(sentence.replace('YEAR', str(year)))\n",
"        for year in range(minYear, maxYear)\n",
"    ]\n",
"\n",
"    # Inference only, so skip autograd bookkeeping for the whole batch.\n",
"    inputs = torch.tensor(sentenceTokens).to(device)\n",
"    with torch.no_grad():\n",
"        outputs = model(inputs)\n",
"    embeds = outputs[0].cpu().numpy()\n",
"\n",
"    # Look the [MASK] id up instead of hard-coding 103. The position is read\n",
"    # from the first variant only, which assumes every year encodes to the\n",
"    # same number of tokens.\n",
"    index_of_mask = sentenceTokens[0].index(tokenizer.mask_token_id)\n",
"    return np.take(embeds, index_of_mask, axis=1)"
], "execution_count": 3, "outputs": [] },
{ "cell_type": "code", "metadata": { "id": "lHxngLxBJgSK" }, "source": [
"def calcTopTokens(e0, e1, maxRank=150):\n",
"    \"\"\"Keep only the tokens that rank highly for either sentence.\n",
"\n",
"    e0, e1: 2-D arrays from calcYearEmbeds (rows are years, columns are\n",
"        vocabulary tokens) for sentence 0 and sentence 1.\n",
"    maxRank: tokens whose combined rank is below this value are kept.\n",
"\n",
"    Returns a DataFrame with one row per (year, kept token) and the columns\n",
"    index, e0, e1, dif, token_index and year_index.\n",
"    \"\"\"\n",
"    # Read the vocabulary size off the data rather than hard-coding 30522.\n",
"    vocab_size = e0.shape[-1]\n",
"\n",
"    # Merge e0 and e1 into a df; flatten() is row-major, so flat position i\n",
"    # is year i // vocab_size, token i % vocab_size.\n",
"    df = pd.DataFrame({'e0': e0.flatten(), 'e1': e1.flatten()})\n",
"    df['dif'] = df['e0'] - df['e1']\n",
"\n",
"    # Calculate year and token_index based on index.\n",
"    # year_index stays a float column (0.0, 1.0, ...).\n",
"    df = df.reset_index()\n",
"    df['token_index'] = df['index'].mod(vocab_size)\n",
"    df['year_index'] = np.floor(df['index'].div(vocab_size))\n",
"\n",
"    # Group by token_index\n",
"    # Sentences rank tokens separately so the less likely sentence will still include its outliers\n",
"    by_token = df.groupby('token_index')[['e0', 'e1']].mean()\n",
"    by_token['i0'] = by_token['e0'].rank(ascending=False)\n",
"    by_token['i1'] = by_token['e1'].rank(ascending=False)\n",
"    by_token['i_combined_min'] = by_token[['i0', 'i1']].min(axis=1).rank()\n",
"\n",
"    top_tokens = by_token.loc[by_token['i_combined_min'] < maxRank]\n",
"\n",
"    return df.loc[df['token_index'].isin(top_tokens.index)]\n"
], "execution_count": 4, "outputs": [] },
{ "cell_type": "code", "metadata": { "id": "81VpA88LmuuV" }, "source": [ "" ], "execution_count": 4, "outputs": [] },
{ "cell_type": "code", "metadata": { "id": "Yb3jJxcyfdwE" }, "source": [ "HTML_DEV_TEMPLATE = '''\n", " \n", " \n", "
\n", "\n", " \n", " \n", " \n", "'''\n", "\n", "HTML_TEMPLATE = '''\n", " \n", " \n", "
\n", "\n", " \n", " \n", " \n", "'''" ], "execution_count": 11, "outputs": [] },
{ "cell_type": "code", "metadata": { "id": "VLQLT18HtaYU" }, "source": [ " " ], "execution_count": 5, "outputs": [] },
{ "cell_type": "markdown", "metadata": { "id": "hMqMXXR1fgr3" }, "source": [ " # Edit s0 and s1 to see other differences over time\n" ] },
{ "cell_type": "code", "metadata": { "id": "n3i4tXy7eV9z", "colab": { "base_uri": "https://localhost:8080/", "height": 405 }, "outputId": "b39f8887-7921-49f8-81b2-f32fe11b49dd" }, "source": [
"# The two sentences to compare: YEAR is replaced by each year, [MASK] is the blank BERT fills in.\n",
"s0 = 'In YEAR, he was arrested for [MASK].'\n",
"s1 = 'In YEAR, she was arrested for [MASK].'\n",
"\n",
"# calcYearEmbeds reads these globals and covers range(minYear, maxYear).\n",
"minYear = 1860 # min 1707, \"1706\" token not in BERT vocab.\n",
"maxYear = 2018 # max 2022, BERT was trained in 2018.\n",
"\n",
"e0, e1 = calcYearEmbeds(s0), calcYearEmbeds(s1)\n",
"\n",
"# Serialize the top-token rows as CSV text and pass everything to the HTML template.\n",
"out_df = calcTopTokens(e0, e1)\n",
"csvColumns = ['e0', 'e1', 'token_index', 'year_index']\n",
"tidyCSV = out_df[csvColumns].to_csv(index=False)\n",
"js_data = dict(minYear=minYear, maxYear=maxYear, s0=s0, s1=s1, tidyCSV=tidyCSV)\n",
"IPython.display.display(IPython.display.HTML(HTML_TEMPLATE.format(js_data=js_data)))"
], "execution_count": 12, "outputs": [ { "output_type": "display_data", "data": { "text/html": [ "\n", " \n", " \n", "
\n", "\n", " \n", " \n", " \n" ], "text/plain": [ "" ] }, "metadata": { "tags": [] } } ] },
{ "cell_type": "code", "metadata": { "id": "ToIAeKKFJKNH" }, "source": [ "" ], "execution_count": 6, "outputs": [] },
{ "cell_type": "code", "metadata": { "id": "mQuVs4O5JKPt" }, "source": [ "" ], "execution_count": 6, "outputs": [] },
{ "cell_type": "markdown", "metadata": { "id": "qGE24NUXJKfi" }, "source": [ "# Save explorable data" ] },
{ "cell_type": "code", "metadata": { "id": "sDaySQIOIuhJ" }, "source": [ "" ], "execution_count": null, "outputs": [] },
{ "cell_type": "code", "metadata": { "id": "eFAIX2k0JNu1" }, "source": [
"# Sentence pairs to export; each entry is [sentence 0, sentence 1].\n",
"pairs = [\n",
"    ['In YEAR, he worked as a [MASK].', 'In YEAR, she worked as a [MASK].'],\n",
"    ['In YEAR, he studied [MASK].', 'In YEAR, she studied [MASK].'],\n",
"    ['Born in YEAR, his name was [MASK].', 'Born in YEAR, her name was [MASK].'],\n",
"    ['In YEAR, they worked as a [MASK].', 'In YEAR, she worked as a [MASK].'],\n",
"    ['In YEAR, he played a game of [MASK].', 'In YEAR, she played a game of [MASK].'],\n",
"    ['In YEAR, he and a bear [MASK].', 'In YEAR, she and a bear [MASK].'],\n",
"]\n",
"\n",
"# Same year range for every pair; calcYearEmbeds reads these globals.\n",
"minYear = 1860 # min 1707, \"1706\" token not in BERT vocab.\n",
"maxYear = 2018 # max 2022, BERT was trained in 2018.\n",
"\n",
"out = []\n",
"for pair in pairs:\n",
"    s0, s1 = pair\n",
"\n",
"    e0 = calcYearEmbeds(s0)\n",
"    e1 = calcYearEmbeds(s1)\n",
"\n",
"    out_df = calcTopTokens(e0, e1)\n",
"    tidyCSV = out_df[['e0', 'e1', 'token_index', 'year_index']].to_csv(index=False)\n",
"    js_data = dict(minYear=minYear, maxYear=maxYear, s0=s0, s1=s1, tidyCSV=tidyCSV)\n",
"    out.append(js_data)\n"
], "execution_count": 7, "outputs": [] },
{ "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 17 }, "id": "8_vkeJxbKoFE", "outputId": "df9c5ef6-d1f3-4efb-f3d6-cfe0d34b431f" }, "source": [
"from google.colab import files \n",
"import json\n",
"\n",
"# Write every pair's payload to one JSON file, then hand it to Colab's download helper.\n",
"with open('gender-over-time.json', 'w') as f:\n",
"    json.dump(out, f)\n",
"files.download('gender-over-time.json')\n",
"\n"
], "execution_count": 8, "outputs": [ { "output_type": "display_data", "data": { "application/javascript": [ "\n", " async function download(id, filename, size) {\n", " if (!google.colab.kernel.accessAllowed) {\n", " return;\n", " }\n", " const div = document.createElement('div');\n", " const label = document.createElement('label');\n", " label.textContent = `Downloading \"${filename}\": `;\n", " div.appendChild(label);\n", " const progress = document.createElement('progress');\n", " progress.max = size;\n", " div.appendChild(progress);\n", " document.body.appendChild(div);\n", "\n", " const buffers = [];\n", " let downloaded = 0;\n", "\n", " const channel = await google.colab.kernel.comms.open(id);\n", " // Send a message to notify the kernel that we're ready.\n", " channel.send({})\n", "\n", " for await (const message of channel.messages) {\n", " // Send a message to notify the kernel that we're ready.\n", " channel.send({})\n", " if (message.buffers) {\n", " for (const buffer of message.buffers) {\n", " buffers.push(buffer);\n", " downloaded += buffer.byteLength;\n", " progress.value = downloaded;\n", " }\n", " }\n", " }\n", " const blob = new Blob(buffers, {type: 'application/binary'});\n", " const a = document.createElement('a');\n", " a.href = window.URL.createObjectURL(blob);\n", " a.download = filename;\n", " div.appendChild(a);\n", " a.click();\n", " div.remove();\n", " }\n", " " ], "text/plain": [ "" ] }, "metadata": { "tags": [] } }, { "output_type": "display_data", "data": { "application/javascript": [ "download(\"download_6e5c55a3-2d17-4571-b255-48eb2cdd58ba\", \"gender-over-time.json\", 4351520)" ], "text/plain": [ "" ] }, "metadata": { "tags": [] } } ] } ] }