{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ewer-Q-0w2xA"
   },
   "source": [
    "# Installation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "NpsF9ipLLl2s",
    "outputId": "10bf54aa-b89d-4e42-9777-bc97b00a5f32"
   },
   "outputs": [],
   "source": [
    "# %pip (unlike !pip) installs into the environment of the running kernel\n",
    "# TODO: pin both installs to specific commits so the notebook stays reproducible\n",
    "%pip install git+https://github.com/huggingface/transformers/\n",
    "%pip install git+https://github.com/google/flax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "M1wVkrpjU6zO"
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "t47CH1H_IOT8"
   },
   "source": [
    "# Custom BART Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "9jQnM6S2vCpn"
   },
   "outputs": [],
   "source": [
    "# TODO: set those args in a config file\n",
    "OUTPUT_VOCAB_SIZE = 16384 + 1  # encoded image token space + 1 for bos\n",
    "OUTPUT_LENGTH = 256 + 1  # number of encoded tokens + 1 for bos\n",
    "BOS_TOKEN_ID = 16384  # the extra (last) id of the output vocabulary\n",
    "BASE_MODEL = 'facebook/bart-large'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "_eEaJVxAKpV5"
   },
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import flax.linen as nn\n",
    "\n",
    "# explicit imports instead of `import *`: `jnp` and `BartConfig` are used later in the notebook\n",
    "from transformers import BartConfig, BartTokenizer, FlaxBartForConditionalGeneration\n",
    "from transformers.models.bart.modeling_flax_bart import (\n",
    "    FlaxBartDecoder,\n",
    "    FlaxBartEncoder,\n",
    "    FlaxBartForConditionalGenerationModule,\n",
    "    FlaxBartModule,\n",
    ")\n",
    "\n",
    "\n",
    "class CustomFlaxBartModule(FlaxBartModule):\n",
    "    \"\"\"BART seq2seq module whose decoder has its own embedding and config.\n",
    "\n",
    "    The encoder keeps the base model's `shared` text embedding (so pre-trained\n",
    "    weights can be copied into it), while the decoder embeds tokens from a separate\n",
    "    vocabulary of OUTPUT_VOCAB_SIZE ids with OUTPUT_LENGTH positions.\n",
    "    \"\"\"\n",
    "\n",
    "    def setup(self):\n",
    "        # we keep shared to easily load pre-trained weights\n",
    "        self.shared = nn.Embed(\n",
    "            self.config.vocab_size,\n",
    "            self.config.d_model,\n",
    "            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
    "            dtype=self.dtype,\n",
    "        )\n",
    "        # a separate embedding is used for the decoder\n",
    "        self.decoder_embed = nn.Embed(\n",
    "            OUTPUT_VOCAB_SIZE,\n",
    "            self.config.d_model,\n",
    "            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
    "            dtype=self.dtype,\n",
    "        )\n",
    "        self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)\n",
    "\n",
    "        # the decoder has a different config: start from a copy of every field of the\n",
    "        # base config. BartConfig takes keyword arguments -- passing the dict positionally\n",
    "        # would bind it to `vocab_size` and silently use default hyper-parameters instead.\n",
    "        decoder_config = BartConfig(**self.config.to_dict())\n",
    "        decoder_config.max_position_embeddings = OUTPUT_LENGTH\n",
    "        decoder_config.vocab_size = OUTPUT_VOCAB_SIZE\n",
    "        self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)\n",
    "\n",
    "\n",
    "class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):\n",
    "    \"\"\"Conditional-generation module with an LM head sized for OUTPUT_VOCAB_SIZE.\"\"\"\n",
    "\n",
    "    def setup(self):\n",
    "        self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)\n",
    "        self.lm_head = nn.Dense(\n",
    "            OUTPUT_VOCAB_SIZE,\n",
    "            use_bias=False,\n",
    "            dtype=self.dtype,\n",
    "            kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
    "        )\n",
    "        self.final_logits_bias = self.param(\"final_logits_bias\", self.bias_init, (1, OUTPUT_VOCAB_SIZE))\n",
    "\n",
    "\n",
    "class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):\n",
    "    \"\"\"FlaxBartForConditionalGeneration wrapper that uses the custom module above.\"\"\"\n",
    "\n",
    "    module_class = CustomFlaxBartForConditionalGenerationModule"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "S7CP9Td9m2ge",
    "outputId": "5638ef68-9c40-46f7-90ba-a4d05b61360d"
   },
   "outputs": [],
   "source": [
    "# load pre-trained model for encoder weights\n",
    "base_model = FlaxBartForConditionalGeneration.from_pretrained(BASE_MODEL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "6lmynR-poceH"
   },
   "outputs": [],
   "source": [
    "# set up our new model config\n",
    "config = BartConfig.from_pretrained(BASE_MODEL)\n",
    "config.tie_word_embeddings = False\n",
    "config.decoder_start_token_id = BOS_TOKEN_ID\n",
    "config.bos_token_id = BOS_TOKEN_ID  # should not be used\n",
    "# NOTE(review): `pos_token_id` is not a standard BartConfig attribute, so this line has no\n",
    "# effect; it may have been meant as `pad_token_id` (which keeps the value loaded from\n",
    "# BASE_MODEL) -- TODO confirm intent before changing generation/padding behaviour\n",
    "config.pos_token_id = BOS_TOKEN_ID  # should not be used\n",
    "#config.eos_token_id = None  # prevents generation from stopping until we reach max_length"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "_6-XKK40oEfP"
   },
   "outputs": [],
   "source": [
    "# create our model and initialize it randomly\n",
    "model = CustomFlaxBartForConditionalGeneration(config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "-r_hZestr-NR"
   },
   "outputs": [],
   "source": [
    "# use pretrained weights\n",
    "model.params['model']['encoder'] = base_model.params['model']['encoder']\n",
    "model.params['model']['shared'] = base_model.params['model']['shared']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "5NEX8f62sVjx"
   },
   "outputs": [],
   "source": [
    "# no need for base_model anymore\n",
    "del base_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "Jz032w73nHEf",
    "outputId": "994d8e85-bff7-480b-8b69-f69dedc15c49"
   },
   "outputs": [],
   "source": [
    "# we verify that the output bias is sized for the custom output vocabulary\n",
    "assert model.params['final_logits_bias'].shape == (1, OUTPUT_VOCAB_SIZE)\n",
    "model.params['final_logits_bias'].shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "zLl24Ez5t7x1"
   },
   "source": [
    "## Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "XLLA2NK3uDQr"
   },
   "outputs": [],
   "source": [
    "tokenizer = BartTokenizer.from_pretrained(BASE_MODEL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "Ntow53I_t81D",
    "outputId": "59289cdd-1429-4720-cc87-88810c4b99ac"
   },
   "outputs": [],
   "source": [
    "text = \"My friends are cool but they eat too many carbs.\"\n",
    "# truncation=True makes the max_length limit explicit (otherwise transformers warns)\n",
    "inputs = tokenizer(text, max_length=1024, truncation=True, return_tensors='jax')\n",
    "encoder_outputs = model.encode(**inputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "vcRNJnJ_uJOJ",
    "outputId": "025afd54-7908-4a9c-fb59-e40bd3458711"
   },
   "outputs": [],
   "source": [
    "decoder_start_token_id = model.config.decoder_start_token_id\n",
    "decoder_start_token_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "6QWmEwL_uMld"
   },
   "outputs": [],
   "source": [
    "decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype=\"i4\") * decoder_start_token_id\n",
    "outputs = model.decode(decoder_input_ids, encoder_outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "c_ys3yWBothF",
    "outputId": "40d4d584-e0a8-44cb-bbea-0ffa38d50a53"
   },
   "outputs": [],
   "source": [
    "outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "O6s0wtB_uTC_",
    "outputId": "bc0e9e80-e346-4e99-d28e-3f658eda1f66"
   },
   "outputs": [],
   "source": [
    "outputs.logits.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "ELzemGP3uBzy",
    "outputId": "dc12f98a-1ccf-450d-ba2a-9c29d7d14885"
   },
   "outputs": [],
   "source": [
    "outputs.logits.argmax(axis=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "fQjikkGEunpx",
    "outputId": "3dba0209-ad4e-4069-be38-6c599c677ef1"
   },
   "outputs": [],
   "source": [
    "model.config.bos_token_id, model.config.eos_token_id, model.config.pad_token_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "P32mJJSbrU1F"
   },
   "outputs": [],
   "source": [
    "input_ids_test = tokenizer.encode('I enjoy walking with my cute dog', return_tensors='jax')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "C7cHbIHruELT"
   },
   "outputs": [],
   "source": [
    "greedy_output = model.generate(input_ids_test, max_length=50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "jYugh9cOuwc9",
    "outputId": "19c3a2ee-e7bc-4f1f-9c86-06bd7337b537"
   },
   "outputs": [],
   "source": [
    "greedy_output[0]"
   ]
  }
 ],
 "metadata": {
  "accelerator": "TPU",
  "colab": {
   "collapsed_sections": [],
   "machine_shape": "hm",
   "name": "CustomBARTv4b-model-generate.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}