{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "j9x-5xxtgJYA"
},
"source": [
"# [Token classification (PyTorch)](https://huggingface.co/learn/nlp-course/en/chapter7/2?fw=pt)\n",
"\n",
"* The Dataset structure contains rows of dictionary where each dictionary correspond to a data point consisting of features and coresponding values\n",
"* `Dataset` object also has a number of methods and attributes\n",
"* `DataCollatorForTokenClassification` pads with `-100`\n",
"* The traditional framework used to evaluate token classification prediction is [`seqeval`](https://huggingface.co/spaces/evaluate-metric/seqeval)"
]
},
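{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `-100` label padding noted above can be illustrated with a minimal pure-Python sketch. This is not the library's implementation: `pad_labels` and `IGNORE_INDEX` are hypothetical names, and the sketch only assumes right-padding to the longest sequence in the batch.\n",
"\n",
"```python\n",
"IGNORE_INDEX = -100  # hypothetical constant; same value as the label padding noted above\n",
"\n",
"def pad_labels(batch_labels, pad_value=IGNORE_INDEX):\n",
"    # Pad every label sequence on the right to the longest length in the batch.\n",
"    max_len = max(len(labels) for labels in batch_labels)\n",
"    return [labels + [pad_value] * (max_len - len(labels)) for labels in batch_labels]\n",
"\n",
"batch = [[0, 3, 0], [0, 1, 2, 0, 0]]  # toy label ids for two sentences\n",
"padded = pad_labels(batch)\n",
"print(padded)\n",
"```\n",
"\n",
"Shorter sequences gain trailing `-100` entries, while the longest sequence in the batch is left unchanged."
]
},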
{
"cell_type": "markdown",
"source": [
"#### Main steps to prepare custom training loop\n",
"1. Prepare the dataloaders\n",
"2. Prepare the model -> model initialisation and passing in 1d2label and label2id\n",
"3. Prepare the optimizer\n",
"4. Pass all to accelerate.prepare()"
],
"metadata": {
"id": "4xzr2_FURfQP"
}
},
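{
"cell_type": "markdown",
"metadata": {},
"source": [
"Step 2 relies on two label mappings. A minimal sketch of how they might be built, assuming the nine CoNLL-2003 NER tag names (the label list is an assumption; these notes do not state which dataset is used):\n",
"\n",
"```python\n",
"# Assumed label list (CoNLL-2003 NER tags); replace it with your dataset's own label names.\n",
"label_names = ['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']\n",
"\n",
"# Map integer ids to label strings, then invert the mapping.\n",
"id2label = {i: label for i, label in enumerate(label_names)}\n",
"label2id = {label: i for i, label in id2label.items()}\n",
"\n",
"print(id2label[1], label2id['B-ORG'])\n",
"```\n",
"\n",
"Both dictionaries would then typically be passed as the `id2label` and `label2id` arguments when the token classification model is loaded."
]
},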
{
"cell_type": "markdown",
"metadata": {
"id": "53wt9QqZgJYC"
},
"source": [
"Install the Transformers, Datasets, and Evaluate libraries to run this notebook."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "EO3w1a41gJYC",
"outputId": "3e1b07df-ceb4-4e4a-8aa4-dfc74fc81530"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Collecting datasets\n",
" Downloading datasets-2.16.0-py3-none-any.whl (507 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m507.1/507.1 kB\u001b[0m \u001b[31m4.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting evaluate\n",
" Downloading evaluate-0.4.1-py3-none-any.whl (84 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m84.1/84.1 kB\u001b[0m \u001b[31m8.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: transformers[sentencepiece] in /usr/local/lib/python3.10/dist-packages (4.35.2)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from datasets) (3.13.1)\n",
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (1.23.5)\n",
"Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (10.0.1)\n",
"Collecting pyarrow-hotfix (from datasets)\n",
" Downloading pyarrow_hotfix-0.6-py3-none-any.whl (7.9 kB)\n",
"Collecting dill<0.3.8,>=0.3.0 (from datasets)\n",
" Downloading dill-0.3.7-py3-none-any.whl (115 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m115.3/115.3 kB\u001b[0m \u001b[31m6.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (1.5.3)\n",
"Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (2.31.0)\n",
"Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (4.66.1)\n",
"Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets) (3.4.1)\n",
"Collecting multiprocess (from datasets)\n",
" Downloading multiprocess-0.70.15-py310-none-any.whl (134 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m5.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (2023.6.0)\n",
"Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.9.1)\n",
"Requirement already satisfied: huggingface-hub>=0.19.4 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.19.4)\n",
"Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets) (23.2)\n",
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (6.0.1)\n",
"Collecting responses<0.19 (from evaluate)\n",
" Downloading responses-0.18.0-py3-none-any.whl (38 kB)\n",
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers[sentencepiece]) (2023.6.3)\n",
"Requirement already satisfied: tokenizers<0.19,>=0.14 in /usr/local/lib/python3.10/dist-packages (from transformers[sentencepiece]) (0.15.0)\n",
"Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers[sentencepiece]) (0.4.1)\n",
"Collecting sentencepiece!=0.1.92,>=0.1.91 (from transformers[sentencepiece])\n",
" Downloading sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m7.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: protobuf in /usr/local/lib/python3.10/dist-packages (from transformers[sentencepiece]) (3.20.3)\n",
"Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (23.1.0)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.0.4)\n",
"Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.9.4)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.4.1)\n",
"Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n",
"Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.19.4->datasets) (4.5.0)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (3.3.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (3.6)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (2.0.7)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (2023.11.17)\n",
"Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n",
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2023.3.post1)\n",
"Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.1->pandas->datasets) (1.16.0)\n",
"Installing collected packages: sentencepiece, pyarrow-hotfix, dill, responses, multiprocess, datasets, evaluate\n",
"Successfully installed datasets-2.16.0 dill-0.3.7 evaluate-0.4.1 multiprocess-0.70.15 pyarrow-hotfix-0.6 responses-0.18.0 sentencepiece-0.1.99\n",
"Collecting accelerate\n",
" Downloading accelerate-0.25.0-py3-none-any.whl (265 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m265.7/265.7 kB\u001b[0m \u001b[31m4.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from accelerate) (1.23.5)\n",
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from accelerate) (23.2)\n",
"Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate) (5.9.5)\n",
"Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from accelerate) (6.0.1)\n",
"Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.10/dist-packages (from accelerate) (2.1.0+cu121)\n",
"Requirement already satisfied: huggingface-hub in /usr/local/lib/python3.10/dist-packages (from accelerate) (0.19.4)\n",
"Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from accelerate) (0.4.1)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.13.1)\n",
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (4.5.0)\n",
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (1.12)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.2.1)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.1.2)\n",
"Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (2023.6.0)\n",
"Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (2.1.0)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from huggingface-hub->accelerate) (2.31.0)\n",
"Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub->accelerate) (4.66.1)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.10.0->accelerate) (2.1.3)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub->accelerate) (3.3.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub->accelerate) (3.6)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub->accelerate) (2.0.7)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub->accelerate) (2023.11.17)\n",
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.10.0->accelerate) (1.3.0)\n",
"Installing collected packages: accelerate\n",
"Successfully installed accelerate-0.25.0\n",
"\u001b[31mERROR: torch_xla-1.9-cp37-cp37m-linux_x86_64.whl is not a supported wheel on this platform.\u001b[0m\u001b[31m\n",
"Reading package lists... Done\n",
"Building dependency tree... Done\n",
"Reading state information... Done\n",
"git-lfs is already the newest version (3.0.2-1ubuntu0.2).\n",
"0 upgraded, 0 newly installed, 0 to remove and 24 not upgraded.\n"
]
}
],
"source": [
"!pip install datasets evaluate transformers[sentencepiece]\n",
"!pip install accelerate\n",
"# To run the training on TPU, you will need to uncomment the following line:\n",
"!pip install cloud-tpu-client==0.10 torch==1.9.0 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl\n",
"!apt install git-lfs"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uyOjlt8igJYD"
},
"source": [
"You will need to setup git, adapt your email and name in the following cell."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "GS9z71W1gJYD"
},
"outputs": [],
"source": [
"!git config --global user.email \"koakande@gmail.com\"\n",
"!git config --global user.name \"koakande\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Tk9JVa10gJYE"
},
"source": [
"You will also need to be logged in to the Hugging Face Hub. Execute the following and enter your credentials."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 331,
"referenced_widgets": [
"b09ce087c3134d64839edbc4a568f742",
"43aa4e04178e4b5fa50837e4e0db4d55",
"ebd1f7f7f4154bca9f0594a2f3351472",
"768816c3f54c4e6cbca05c47b1ec34ad",
"f18137c6c2f04783a50580de5e5caf3b",
"c10791b3c8b3483d8a97007797392ab4",
"b17ee1d8d8374223b2bb160758cf2355",
"847822b699924086a19f2b9e73e613a2",
"d068c6a78af848ddaf388fba271e28f7",
"7fee36f9922c49c083d109fb49f48737",
"d499f0acba1842e0892b1e0766e6c6c0",
"b93985358ab046cd942e6249ca460539",
"7d67a4d67b494ceeb6f7d2c7791b26ef",
"368077ddf45840a89447a2f8ee5af94e",
"2df808b55601428c96e693a6f0d70e52",
"c154e37f61324f69b7f99dfaded33e96",
"8830863aa8da4402b6ae48678c23a8c8"
]
},
"id": "2_b6ALK3gJYE",
"outputId": "3adce3a2-adb4-4a5b-9b8d-2927776678fd"
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"VBox(children=(HTML(value='