{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "kqOqTZuKeJoa", "outputId": "9f63819c-9bc1-4c15-e9cd-9c1121edd2a6" }, "outputs": [], "source": [ "#!pip install \"jax[tpu]>=0.2.16\" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "SQ-lhEVFeY4d", "outputId": "7346c6b8-1848-4755-c114-94d6de50b50d" }, "outputs": [], "source": [ "#!git clone https://github.com/huggingface/transformers.git" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "9qSTMLvFfBVs", "outputId": "40659f61-86d4-4ae5-9262-501557737705" }, "outputs": [], "source": [ "#!pip install ./transformers" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "7Og7zRTrfm08" }, "outputs": [], "source": [ "#!pip install jaxlib>=0.2.9" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "nmQv7VMaf1L8" }, "outputs": [], "source": [ "#!pip install flax>=0.3.4" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "MT6jpop-f4dc" }, "outputs": [], "source": [ "#!pip install optax>=0.0.9" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# %%capture\n", "# !pip install jupyterlab_widgets\n", "# !pip install ipywidgets" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "id": "-F5NIqDmfDLb" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2021-07-08 10:11:37.310929: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n" ] } ], "source": [ "from transformers 
import (\n", " CONFIG_MAPPING,\n", " FLAX_MODEL_FOR_MASKED_LM_MAPPING,\n", " BatchEncoding,\n", " FlaxT5ForConditionalGeneration,\n", " T5ForConditionalGeneration,\n", " HfArgumentParser,\n", " PreTrainedTokenizerBase,\n", " T5Config,\n", " T5TokenizerFast,\n", " TrainingArguments,\n", " is_tensorboard_available,\n", " set_seed,\n", ")\n" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "id": "aInICxY6gREQ" }, "outputs": [], "source": [ "import flax\n", "import jax\n", "import jax.numpy as jnp\n", "import optax\n", "from flax import jax_utils, traverse_util\n", "from flax.training import train_state\n", "from flax.training.common_utils import get_metrics, onehot, shard\n", "from transformers import (\n", " CONFIG_MAPPING,\n", " FLAX_MODEL_FOR_MASKED_LM_MAPPING,\n", " BatchEncoding,\n", " FlaxT5ForConditionalGeneration,\n", " T5ForConditionalGeneration,\n", " HfArgumentParser,\n", " PreTrainedTokenizerBase,\n", " T5Config,\n", " T5TokenizerFast,\n", " TrainingArguments,\n", " is_tensorboard_available,\n", " set_seed,\n", ")\n", "from transformers.models.t5.modeling_flax_t5 import shift_tokens_right\n" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "id": "iEqVlHptfOCT" }, "outputs": [], "source": [ "tokenizer = T5TokenizerFast.from_pretrained(\"t5-small\")\n", "config = T5Config.from_pretrained(\"t5-small\")" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "LNETw3cWfjbr", "outputId": "95c0e750-c087-46dd-92fa-39f8ff0238f2" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:absl:Starting the local TPU driver.\n", "INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://\n", "INFO:absl:Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: \"cuda\". 
Available platform names are: TPU Interpreter Host\n" ] } ], "source": [ "model = FlaxT5ForConditionalGeneration(config)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "import numpy as np" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "id": "T5F3BEA2f6xE" }, "outputs": [], "source": [ "input_ids = np.asarray(208 * [512 * [1]], dtype=np.int32)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "def run_forward(input_ids, params):\n", "    \"\"\"Return the logits of `model` evaluated with the given `params`.\n", "\n", "    The same `input_ids` array is fed to both the encoder and the decoder.\n", "    \"\"\"\n", "    # Forward `params` explicitly; otherwise the model falls back to its own\n", "    # `model.params` and the argument passed through `jax.jit` is ignored.\n", "    return model(input_ids, decoder_input_ids=input_ids, params=params).logits" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "jitted_forward = jax.jit(run_forward)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "logits = jitted_forward(input_ids, model.params)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "accelerator": "TPU", "colab": { "name": "Untitled1.ipynb", "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" } }, "nbformat": 4, "nbformat_minor": 4 }