{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "CustomBARTv4b-model-generate.ipynb", "provenance": [], "collapsed_sections": [], "machine_shape": "hm" }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "TPU" }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "ewer-Q-0w2xA" }, "source": [ "# Installation" ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "NpsF9ipLLl2s", "outputId": "10bf54aa-b89d-4e42-9777-bc97b00a5f32" }, "source": [ "!pip install git+https://github.com/huggingface/transformers/\n", "!pip install git+https://github.com/google/flax" ], "execution_count": 1, "outputs": [ { "output_type": "stream", "text": [ "Collecting git+https://github.com/huggingface/transformers/\n", " Cloning https://github.com/huggingface/transformers/ to /tmp/pip-req-build-oxejx1op\n", " Running command git clone -q https://github.com/huggingface/transformers/ /tmp/pip-req-build-oxejx1op\n", " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", " Preparing wheel metadata ... 
\u001b[?25l\u001b[?25hdone\n", "Requirement already satisfied (use --upgrade to upgrade): transformers==4.9.0.dev0 from git+https://github.com/huggingface/transformers/ in /usr/local/lib/python3.7/dist-packages\n", "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (1.19.5)\n", "Requirement already satisfied: packaging in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (20.9)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (5.4.1)\n", "Requirement already satisfied: sacremoses in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (0.0.45)\n", "Requirement already satisfied: importlib-metadata; python_version < \"3.8\" in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (4.6.0)\n", "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (4.41.1)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (3.0.12)\n", "Requirement already satisfied: huggingface-hub==0.0.12 in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (0.0.12)\n", "Requirement already satisfied: tokenizers<0.11,>=0.10.1 in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (0.10.3)\n", "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (2019.12.20)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (2.23.0)\n", "Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging->transformers==4.9.0.dev0) (2.4.7)\n", "Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers==4.9.0.dev0) (1.15.0)\n", 
"Requirement already satisfied: joblib in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers==4.9.0.dev0) (1.0.1)\n", "Requirement already satisfied: click in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers==4.9.0.dev0) (7.1.2)\n", "Requirement already satisfied: typing-extensions>=3.6.4; python_version < \"3.8\" in /usr/local/lib/python3.7/dist-packages (from importlib-metadata; python_version < \"3.8\"->transformers==4.9.0.dev0) (3.7.4.3)\n", "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata; python_version < \"3.8\"->transformers==4.9.0.dev0) (3.4.1)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.9.0.dev0) (2021.5.30)\n", "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.9.0.dev0) (3.0.4)\n", "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.9.0.dev0) (1.24.3)\n", "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.9.0.dev0) (2.10)\n", "Building wheels for collected packages: transformers\n", " Building wheel for transformers (PEP 517) ... 
\u001b[?25l\u001b[?25hdone\n", " Created wheel for transformers: filename=transformers-4.9.0.dev0-cp37-none-any.whl size=2582229 sha256=249c593273ccca3027c6427d2c6fd749a89f21d722d628d97eb438a2cf3185a8\n", " Stored in directory: /tmp/pip-ephem-wheel-cache-l2rqt1b7/wheels/61/69/33/974fccec4d0ab5feee9fe83bd93e680d269a805be9ede5ec60\n", "Successfully built transformers\n", "Collecting git+https://github.com/google/flax\n", " Cloning https://github.com/google/flax to /tmp/pip-req-build-rt9g1_wx\n", " Running command git clone -q https://github.com/google/flax /tmp/pip-req-build-rt9g1_wx\n", "Requirement already satisfied (use --upgrade to upgrade): flax==0.3.4 from git+https://github.com/google/flax in /usr/local/lib/python3.7/dist-packages\n", "Requirement already satisfied: numpy>=1.12 in /usr/local/lib/python3.7/dist-packages (from flax==0.3.4) (1.19.5)\n", "Requirement already satisfied: jax>=0.2.13 in /usr/local/lib/python3.7/dist-packages (from flax==0.3.4) (0.2.13)\n", "Requirement already satisfied: matplotlib in /usr/local/lib/python3.7/dist-packages (from flax==0.3.4) (3.2.2)\n", "Requirement already satisfied: msgpack in /usr/local/lib/python3.7/dist-packages (from flax==0.3.4) (1.0.2)\n", "Requirement already satisfied: optax in /usr/local/lib/python3.7/dist-packages (from flax==0.3.4) (0.0.9)\n", "Requirement already satisfied: opt-einsum in /usr/local/lib/python3.7/dist-packages (from jax>=0.2.13->flax==0.3.4) (3.3.0)\n", "Requirement already satisfied: absl-py in /usr/local/lib/python3.7/dist-packages (from jax>=0.2.13->flax==0.3.4) (0.12.0)\n", "Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax==0.3.4) (2.8.1)\n", "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax==0.3.4) (0.10.0)\n", "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax==0.3.4) 
(2.4.7)\n", "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax==0.3.4) (1.3.1)\n", "Requirement already satisfied: chex>=0.0.4 in /usr/local/lib/python3.7/dist-packages (from optax->flax==0.3.4) (0.0.8)\n", "Requirement already satisfied: jaxlib>=0.1.37 in /usr/local/lib/python3.7/dist-packages (from optax->flax==0.3.4) (0.1.66+cuda110)\n", "Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from absl-py->jax>=0.2.13->flax==0.3.4) (1.15.0)\n", "Requirement already satisfied: dm-tree>=0.1.5 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.4->optax->flax==0.3.4) (0.1.6)\n", "Requirement already satisfied: toolz>=0.9.0 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.4->optax->flax==0.3.4) (0.11.1)\n", "Requirement already satisfied: flatbuffers in /usr/local/lib/python3.7/dist-packages (from jaxlib>=0.1.37->optax->flax==0.3.4) (1.12)\n", "Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from jaxlib>=0.1.37->optax->flax==0.3.4) (1.4.1)\n", "Building wheels for collected packages: flax\n", " Building wheel for flax (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", " Created wheel for flax: filename=flax-0.3.4-cp37-none-any.whl size=184692 sha256=503b27995f372afe33631e71572d5edc1fffd4d2e0a4cd206d291ad6b0e4c299\n", " Stored in directory: /tmp/pip-ephem-wheel-cache-g1pzxnv6/wheels/3d/26/f4/0ea6051d7352289d9e4f8178348452b35a9a97bde6035405a5\n", "Successfully built flax\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "M1wVkrpjU6zO" }, "source": [ "%load_ext autoreload\n", "%autoreload 2" ], "execution_count": 2, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "t47CH1H_IOT8" }, "source": [ "# Custom BART Model" ] }, { "cell_type": "code", "metadata": { "id": "9jQnM6S2vCpn" }, "source": [ "# TODO: set those args in a config file\n", "OUTPUT_VOCAB_SIZE = 16384 + 1  # encoded image token space + 1 for bos\n", "OUTPUT_LENGTH = 256 + 1  # number of encoded tokens + 1 for bos\n", "BOS_TOKEN_ID = 16384\n", "BASE_MODEL = 'facebook/bart-large-cnn'" ], "execution_count": 3, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "_eEaJVxAKpV5" }, "source": [ "import jax\n", "import jax.numpy as jnp\n", "import flax.linen as nn\n", "\n", "# explicit imports instead of a wildcard: every name the notebook relies on is listed here\n", "# (jnp and BartConfig are used by later cells)\n", "from transformers import BartConfig, BartTokenizer, FlaxBartForConditionalGeneration\n", "from transformers.models.bart.modeling_flax_bart import (\n", "    FlaxBartDecoder,\n", "    FlaxBartEncoder,\n", "    FlaxBartForConditionalGenerationModule,\n", "    FlaxBartModule,\n", ")\n", "\n", "\n", "class CustomFlaxBartModule(FlaxBartModule):\n", "    \"\"\"BART encoder-decoder whose decoder has its own embedding table, vocabulary size and length.\"\"\"\n", "\n", "    def setup(self):\n", "        # we keep shared to easily load pre-trained weights\n", "        self.shared = nn.Embed(\n", "            self.config.vocab_size,\n", "            self.config.d_model,\n", "            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n", "            dtype=self.dtype,\n", "        )\n", "        # a separate embedding is used for the decoder\n", "        self.decoder_embed = nn.Embed(\n", "            OUTPUT_VOCAB_SIZE,\n", "            self.config.d_model,\n", "            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n", "            dtype=self.dtype,\n", "        )\n", "        self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)\n", "\n", "        # the decoder has a different 
config\n", "        # from_dict copies every field of the base config; passing the dict positionally\n", "        # would hand it to BartConfig's first parameter and leave all other fields at their defaults\n", "        decoder_config = BartConfig.from_dict(self.config.to_dict())\n", "        decoder_config.max_position_embeddings = OUTPUT_LENGTH\n", "        decoder_config.vocab_size = OUTPUT_VOCAB_SIZE\n", "        self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)\n", "\n", "\n", "class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):\n", "    \"\"\"Generation module whose LM head and logits bias are sized to OUTPUT_VOCAB_SIZE.\"\"\"\n", "\n", "    def setup(self):\n", "        self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)\n", "        self.lm_head = nn.Dense(\n", "            OUTPUT_VOCAB_SIZE,\n", "            use_bias=False,\n", "            dtype=self.dtype,\n", "            kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n", "        )\n", "        self.final_logits_bias = self.param(\"final_logits_bias\", self.bias_init, (1, OUTPUT_VOCAB_SIZE))\n", "\n", "\n", "class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):\n", "    \"\"\"Model wrapper that instantiates CustomFlaxBartForConditionalGenerationModule.\"\"\"\n", "\n", "    module_class = CustomFlaxBartForConditionalGenerationModule" ], "execution_count": 4, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "S7CP9Td9m2ge", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "5638ef68-9c40-46f7-90ba-a4d05b61360d" }, "source": [ "# load pre-trained model for encoder weights\n", "base_model = FlaxBartForConditionalGeneration.from_pretrained(BASE_MODEL)" ], "execution_count": 5, "outputs": [ { "output_type": "stream", "text": [ "WARNING:absl:No GPU/TPU found, falling back to CPU. 
(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" ], "name": "stderr" } ] }, { "cell_type": "code", "metadata": { "id": "6lmynR-poceH" }, "source": [ "# set up our new model config\n", "# NOTE(review): only the fields below are overridden; every other setting stored in the\n", "# pretrained config is inherited unchanged - confirm that is intended for this decoder\n", "config = BartConfig.from_pretrained(BASE_MODEL)\n", "config.tie_word_embeddings = False\n", "config.decoder_start_token_id = BOS_TOKEN_ID\n", "config.bos_token_id = BOS_TOKEN_ID  # should not be used\n", "config.pad_token_id = BOS_TOKEN_ID  # should not be used\n", "#config.eos_token_id = None  # prevents generation from stopping until we reach max_length" ], "execution_count": 6, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "_6-XKK40oEfP" }, "source": [ "# create our model and initialize it randomly\n", "model = CustomFlaxBartForConditionalGeneration(config)" ], "execution_count": 7, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "-r_hZestr-NR" }, "source": [ "# use pretrained weights\n", "model.params['model']['encoder'] = base_model.params['model']['encoder']\n", "model.params['model']['shared'] = base_model.params['model']['shared']" ], "execution_count": 8, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "5NEX8f62sVjx" }, "source": [ "# no need for base_model anymore\n", "del base_model" ], "execution_count": 9, "outputs": [] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Jz032w73nHEf", "outputId": "994d8e85-bff7-480b-8b69-f69dedc15c49" }, "source": [ "# we verify that the shape has not been modified\n", "model.params['final_logits_bias'].shape" ], "execution_count": 10, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(1, 16385)" ] }, "metadata": { "tags": [] }, "execution_count": 10 } ] }, { "cell_type": "markdown", "metadata": { "id": "zLl24Ez5t7x1" }, "source": [ "## Inference" ] }, { "cell_type": "code", "metadata": { "id": "XLLA2NK3uDQr" }, "source": [ "tokenizer = BartTokenizer.from_pretrained(BASE_MODEL)" ], "execution_count": 11, "outputs": [] }, { 
"cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Ntow53I_t81D", "outputId": "59289cdd-1429-4720-cc87-88810c4b99ac" }, "source": [ "text = \"My friends are cool but they eat too many carbs.\"\n", "# request truncation explicitly so that max_length is applied without a tokenizer warning\n", "inputs = tokenizer(text, max_length=1024, truncation=True, return_tensors='jax')\n", "encoder_outputs = model.encode(**inputs)" ], "execution_count": 12, "outputs": [] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "vcRNJnJ_uJOJ", "outputId": "025afd54-7908-4a9c-fb59-e40bd3458711" }, "source": [ "decoder_start_token_id = model.config.decoder_start_token_id\n", "decoder_start_token_id" ], "execution_count": 13, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "16384" ] }, "metadata": { "tags": [] }, "execution_count": 13 } ] }, { "cell_type": "code", "metadata": { "id": "6QWmEwL_uMld" }, "source": [ "# one decoder start token per input sequence, shape (batch, 1)\n", "decoder_input_ids = jnp.full((inputs.input_ids.shape[0], 1), decoder_start_token_id, dtype=\"i4\")\n", "outputs = model.decode(decoder_input_ids, encoder_outputs)" ], "execution_count": 14, "outputs": [] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "c_ys3yWBothF", "outputId": "40d4d584-e0a8-44cb-bbea-0ffa38d50a53" }, "source": [ "outputs" ], "execution_count": 15, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "FlaxCausalLMOutputWithCrossAttentions([('logits',\n", " DeviceArray([[[ 0.5263986 , -2.0947676 , -0.18830685, ..., 0.7599884 ,\n", " 0.6746795 , -1.0411576 
]]], dtype=float32))])" ] }, "metadata": { "tags": [] }, "execution_count": 15 } ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "O6s0wtB_uTC_", "outputId": "bc0e9e80-e346-4e99-d28e-3f658eda1f66" }, "source": [ "# shape of the logits returned by the decode step\n", "outputs.logits.shape" ], "execution_count": 16, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(1, 1, 16385)" ] }, "metadata": { "tags": [] }, "execution_count": 16 } ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ELzemGP3uBzy", "outputId": "dc12f98a-1ccf-450d-ba2a-9c29d7d14885" }, "source": [ "# index of the highest logit along the last axis\n", "jnp.argmax(outputs.logits, axis=-1)" ], "execution_count": 17, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "DeviceArray([[12459]], dtype=int32)" ] }, "metadata": { "tags": [] }, "execution_count": 17 } ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "fQjikkGEunpx", "outputId": "3dba0209-ad4e-4069-be38-6c599c677ef1" }, "source": [ "# special token ids currently stored on the model config\n", "(\n", "    model.config.bos_token_id,\n", "    model.config.eos_token_id,\n", "    model.config.pad_token_id,\n", ")" ], "execution_count": 18, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(16384, 2, 1)" ] }, "metadata": { "tags": [] }, "execution_count": 18 } ] }, { "cell_type": "code", "metadata": { "id": "P32mJJSbrU1F" }, "source": [ "# token ids of a second prompt, used as input to generate()\n", "input_ids_test = tokenizer('I enjoy walking with my cute dog', return_tensors='jax').input_ids" ], "execution_count": 19, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "C7cHbIHruELT" }, "source": [ "# run generation with a maximum length of 50 tokens\n", "greedy_output = model.generate(input_ids_test, max_length=50)" ], "execution_count": 20, "outputs": [] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "jYugh9cOuwc9", "outputId": "19c3a2ee-e7bc-4f1f-9c86-06bd7337b537" }, "source": [ "# the first field of the generation output holds the generated token ids\n", "greedy_output.sequences" ], "execution_count": 21, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ 
"DeviceArray([[16384, 0, 3570, 13405, 10186, 2392, 16362, 1869,\n", " 15772, 13546, 15772, 13546, 9348, 14791, 15772, 15772,\n", " 15772, 11272, 15772, 13546, 15772, 15772, 13546, 15772,\n", " 13546, 15772, 6642, 15772, 10776, 6431, 15772, 14567,\n", " 13406, 15772, 14567, 6235, 15772, 4909, 16160, 568,\n", " 4664, 6650, 8952, 9089, 15772, 5952, 7375, 10843,\n", " 8952, 2]], dtype=int32)" ] }, "metadata": { "tags": [] }, "execution_count": 21 } ] } ] }