{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "view-in-github", "colab_type": "text" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "j9x-5xxtgJYA" }, "source": [ "# [Token classification (PyTorch)](https://huggingface.co/learn/nlp-course/en/chapter7/2?fw=pt)\n", "\n", "* The Dataset structure contains rows of dictionary where each dictionary correspond to a data point consisting of features and coresponding values\n", "* `Dataset` object also has a number of methods and attributes\n", "* `DataCollatorForTokenClassification` pads with `-100`\n", "* The traditional framework used to evaluate token classification prediction is [`seqeval`](https://huggingface.co/spaces/evaluate-metric/seqeval)" ] }, { "cell_type": "markdown", "source": [ "#### Main steps to prepare custom training loop\n", "1. Prepare the dataloaders\n", "2. Prepare the model -> model initialisation and passing in 1d2label and label2id\n", "3. Prepare the optimizer\n", "4. Pass all to accelerate.prepare()" ], "metadata": { "id": "4xzr2_FURfQP" } }, { "cell_type": "markdown", "metadata": { "id": "53wt9QqZgJYC" }, "source": [ "Install the Transformers, Datasets, and Evaluate libraries to run this notebook." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "EO3w1a41gJYC", "outputId": "3e1b07df-ceb4-4e4a-8aa4-dfc74fc81530" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Collecting datasets\n", " Downloading datasets-2.16.0-py3-none-any.whl (507 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m507.1/507.1 kB\u001b[0m \u001b[31m4.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting evaluate\n", " Downloading evaluate-0.4.1-py3-none-any.whl (84 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m84.1/84.1 kB\u001b[0m \u001b[31m8.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: transformers[sentencepiece] in /usr/local/lib/python3.10/dist-packages (4.35.2)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from datasets) (3.13.1)\n", "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (1.23.5)\n", "Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (10.0.1)\n", "Collecting pyarrow-hotfix (from datasets)\n", " Downloading pyarrow_hotfix-0.6-py3-none-any.whl (7.9 kB)\n", "Collecting dill<0.3.8,>=0.3.0 (from datasets)\n", " Downloading dill-0.3.7-py3-none-any.whl (115 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m115.3/115.3 kB\u001b[0m \u001b[31m6.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (1.5.3)\n", "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (2.31.0)\n", "Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (4.66.1)\n", "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets) (3.4.1)\n", "Collecting multiprocess (from datasets)\n", " Downloading multiprocess-0.70.15-py310-none-any.whl (134 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m5.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (2023.6.0)\n", "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.9.1)\n", "Requirement already satisfied: huggingface-hub>=0.19.4 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.19.4)\n", "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets) (23.2)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (6.0.1)\n", "Collecting responses<0.19 (from evaluate)\n", " Downloading responses-0.18.0-py3-none-any.whl (38 kB)\n", "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers[sentencepiece]) (2023.6.3)\n", "Requirement already satisfied: tokenizers<0.19,>=0.14 in /usr/local/lib/python3.10/dist-packages (from transformers[sentencepiece]) (0.15.0)\n", "Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers[sentencepiece]) (0.4.1)\n", "Collecting sentencepiece!=0.1.92,>=0.1.91 (from transformers[sentencepiece])\n", " Downloading sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m7.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: protobuf in /usr/local/lib/python3.10/dist-packages (from transformers[sentencepiece]) (3.20.3)\n", "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (23.1.0)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.0.4)\n", "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.9.4)\n", "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.4.1)\n", "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n", "Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.19.4->datasets) (4.5.0)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (3.6)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (2.0.7)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (2023.11.17)\n", "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n", "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2023.3.post1)\n", "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.1->pandas->datasets) (1.16.0)\n", "Installing collected packages: sentencepiece, pyarrow-hotfix, dill, responses, multiprocess, datasets, evaluate\n", "Successfully installed datasets-2.16.0 dill-0.3.7 evaluate-0.4.1 multiprocess-0.70.15 pyarrow-hotfix-0.6 responses-0.18.0 sentencepiece-0.1.99\n", "Collecting accelerate\n", " Downloading accelerate-0.25.0-py3-none-any.whl (265 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m265.7/265.7 kB\u001b[0m \u001b[31m4.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from accelerate) (1.23.5)\n", "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from accelerate) (23.2)\n", "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate) (5.9.5)\n", "Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from accelerate) (6.0.1)\n", "Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.10/dist-packages (from accelerate) (2.1.0+cu121)\n", "Requirement already satisfied: huggingface-hub in /usr/local/lib/python3.10/dist-packages (from accelerate) (0.19.4)\n", "Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from accelerate) (0.4.1)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.13.1)\n", "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (4.5.0)\n", "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (1.12)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.2.1)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.1.2)\n", "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (2023.6.0)\n", "Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (2.1.0)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from huggingface-hub->accelerate) (2.31.0)\n", "Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub->accelerate) (4.66.1)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.10.0->accelerate) (2.1.3)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub->accelerate) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub->accelerate) (3.6)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub->accelerate) (2.0.7)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub->accelerate) (2023.11.17)\n", "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.10.0->accelerate) (1.3.0)\n", "Installing collected packages: accelerate\n", "Successfully installed accelerate-0.25.0\n", "\u001b[31mERROR: torch_xla-1.9-cp37-cp37m-linux_x86_64.whl is not a supported wheel on this platform.\u001b[0m\u001b[31m\n", "Reading package lists... Done\n", "Building dependency tree... Done\n", "Reading state information... Done\n", "git-lfs is already the newest version (3.0.2-1ubuntu0.2).\n", "0 upgraded, 0 newly installed, 0 to remove and 24 not upgraded.\n" ] } ], "source": [ "!pip install datasets evaluate transformers[sentencepiece]\n", "!pip install accelerate\n", "# To run the training on TPU, you will need to uncomment the following line:\n", "!pip install cloud-tpu-client==0.10 torch==1.9.0 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl\n", "!apt install git-lfs" ] }, { "cell_type": "markdown", "metadata": { "id": "uyOjlt8igJYD" }, "source": [ "You will need to setup git, adapt your email and name in the following cell." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "GS9z71W1gJYD" }, "outputs": [], "source": [ "!git config --global user.email \"koakande@gmail.com\"\n", "!git config --global user.name \"koakande\"" ] }, { "cell_type": "markdown", "metadata": { "id": "Tk9JVa10gJYE" }, "source": [ "You will also need to be logged in to the Hugging Face Hub. Execute the following and enter your credentials." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 331, "referenced_widgets": [ "b09ce087c3134d64839edbc4a568f742", "43aa4e04178e4b5fa50837e4e0db4d55", "ebd1f7f7f4154bca9f0594a2f3351472", "768816c3f54c4e6cbca05c47b1ec34ad", "f18137c6c2f04783a50580de5e5caf3b", "c10791b3c8b3483d8a97007797392ab4", "b17ee1d8d8374223b2bb160758cf2355", "847822b699924086a19f2b9e73e613a2", "d068c6a78af848ddaf388fba271e28f7", "7fee36f9922c49c083d109fb49f48737", "d499f0acba1842e0892b1e0766e6c6c0", "b93985358ab046cd942e6249ca460539", "7d67a4d67b494ceeb6f7d2c7791b26ef", "368077ddf45840a89447a2f8ee5af94e", "2df808b55601428c96e693a6f0d70e52", "c154e37f61324f69b7f99dfaded33e96", "8830863aa8da4402b6ae48678c23a8c8" ] }, "id": "2_b6ALK3gJYE", "outputId": "3adce3a2-adb4-4a5b-9b8d-2927776678fd" }, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "VBox(children=(HTML(value='