{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "nVjPisIdGIyJ" }, "source": [ "# Fine-tuning a model with the Trainer API or Keras" ] }, { "cell_type": "markdown", "metadata": { "id": "OFJLz44OGIyM" }, "source": [ "Install the Transformers, Datasets, and Evaluate libraries to run this notebook." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "ZcjVJmbfGIyS", "outputId": "01e8505d-b1c7-4e6a-fade-c478b4c2641e", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Collecting datasets\n", " Downloading datasets-2.14.4-py3-none-any.whl (519 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m519.3/519.3 kB\u001b[0m \u001b[31m7.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting evaluate\n", " Downloading evaluate-0.4.0-py3-none-any.whl (81 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m81.4/81.4 kB\u001b[0m \u001b[31m6.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting transformers[sentencepiece]\n", " Downloading transformers-4.32.1-py3-none-any.whl (7.5 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.5/7.5 MB\u001b[0m \u001b[31m29.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (1.23.5)\n", "Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (9.0.0)\n", "Collecting dill<0.3.8,>=0.3.0 (from datasets)\n", " Downloading dill-0.3.7-py3-none-any.whl (115 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m115.3/115.3 kB\u001b[0m \u001b[31m15.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (1.5.3)\n", "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (2.31.0)\n", "Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (4.66.1)\n", "Collecting xxhash (from datasets)\n", " Downloading xxhash-3.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m24.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting multiprocess (from datasets)\n", " Downloading multiprocess-0.70.15-py310-none-any.whl (134 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m18.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: fsspec[http]>=2021.11.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (2023.6.0)\n", "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.8.5)\n", "Collecting huggingface-hub<1.0.0,>=0.14.0 (from datasets)\n", " Downloading huggingface_hub-0.16.4-py3-none-any.whl (268 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m268.8/268.8 kB\u001b[0m \u001b[31m31.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets) (23.1)\n", "Requirement already satisfied: pyyaml>=5.1 
in /usr/local/lib/python3.10/dist-packages (from datasets) (6.0.1)\n", "Collecting responses<0.19 (from evaluate)\n", " Downloading responses-0.18.0-py3-none-any.whl (38 kB)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers[sentencepiece]) (3.12.2)\n", "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers[sentencepiece]) (2023.6.3)\n", "Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers[sentencepiece])\n", " Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.8/7.8 MB\u001b[0m \u001b[31m72.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting safetensors>=0.3.1 (from transformers[sentencepiece])\n", " Downloading safetensors-0.3.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m71.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting sentencepiece!=0.1.92,>=0.1.91 (from transformers[sentencepiece])\n", " Downloading sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m72.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: protobuf in /usr/local/lib/python3.10/dist-packages (from transformers[sentencepiece]) (3.20.3)\n", "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (23.1.0)\n", "Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (3.2.0)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.0.4)\n", "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n", "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.9.2)\n", "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.4.0)\n", "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0.0,>=0.14.0->datasets) (4.7.1)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (3.4)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (2.0.4)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (2023.7.22)\n", "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n", "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2023.3)\n", "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from 
python-dateutil>=2.8.1->pandas->datasets) (1.16.0)\n", "Installing collected packages: tokenizers, sentencepiece, safetensors, xxhash, dill, responses, multiprocess, huggingface-hub, transformers, datasets, evaluate\n", "Successfully installed datasets-2.14.4 dill-0.3.7 evaluate-0.4.0 huggingface-hub-0.16.4 multiprocess-0.70.15 responses-0.18.0 safetensors-0.3.3 sentencepiece-0.1.99 tokenizers-0.13.3 transformers-4.32.1 xxhash-3.3.0\n" ] } ], "source": [ "!pip install datasets evaluate transformers[sentencepiece]" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "F-oF8ZZfGIyU", "outputId": "e7b67cec-f205-4356-dc8b-a0d58ecd243c", "colab": { "base_uri": "https://localhost:8080/", "height": 706 } },
"17ca717d89384d16acdd83e8384b1820", "51b3818b51844b58b72f0821cb46a7fa", "ee3348afca4f42b58f441a36fc965242", "d0c88e9111f740aeb2ad589e3f921683", "7dd5c363c1bc4aeaa3d50d6724c7bad2", "df00b8a089c34c5983bde8a8b0cc3c96", "a25c5e38e84c4773891e066083884394", "23812f4e94a544ca93f347ad37e5440e", "0ea575ec00244798aca51a9459474fce", "94ed80e1a8724ed3b65a046ff4920fb6", "f03bc122e2724c8fa3bf3eb72cd88c4d", "892825466b454e779d17ec4b747cafba", "6e483ee2311746978fd71526ea6a5906", "cdbffba61c074bf680d1aa1010bec13b", "70a8b6ed55cc44cca8e7ea325267cefc", "6ef7286d7dbe462aaae1b23453705dd0", "056565fb1d644c01a5976de7b64a34d6", "d2a775433c134bf1bc64ef128febf55e", "04bc573aa84441dca328c6cd34f5be34", "938fd609aa9e4c9b8117ffcab1f3568c", "52e1d33bc5cb4f6e964e88967f8df72a", "079c840eab134f16b5639236de1aa4e6", "74e0c276cd69431095f5329853043cb5", "f0f1688327384b91a431b82335514460", "0e162ae466b2409691ddae62c32d93ff", "cd160e0ce94a48ee8266ba2531f966d6", "49ab2180680e401ba350e5f218820cda", "efc402d898fa420fa13c48a1ac73e958", "3979ec25baa3470d8b4997babe8bd77f", "1b186a7537c3429e8e1d559807e15560", "52efeda7368347f29be9dc500c88d82d", "beac9814c3b4448fb2a51d4568544844", "00cf0e039d1448d9935623bfcecfcce9", "56aa6bc93efe45b1bd6e6bbb1e449754", "efd1403715c443ae9045b43a036f9b1c", "93123868e43348afb23c1dcc4aaa693f", "86c4e082cfef4c6b839823eddb586d3d", "21f2271e361d46fda9368ab11019f9e0", "2b00c47b85794664a380e9ec3b46046c", "7f51a22ac82d408c9de6747025c69321", "5e29c81c50944d898b44d2b2a78410cd", "6fdb0da64b7e4da4a6bc0f68d64cc59b", "13887843690047eaa56083fd11cad808", "776c188018dd4399a6c815c3139fcd52", "bae18a3b63c14beca3ebd51042bda076", "70ebfd0e75544888b5696dce92a2d716", "f3756ba091934ff78c9d697621b108e6", "a01534da4b994680be8bd00c80ebfe57", "f854b8e5fb794e8da3299b6b68415da0", "90af3d04c2ac4f0882d491061ddebfef", "fa4ca84b4ff94fb586cd7032ae390e39", "02d667ff4c1041ae8069f21af31390d7", "6803181d5558410ba82ff14ae62ef1d3", "c27e8d22075841cfa8c2aeb700430026", "e9c8df7000fe444eae700896c0a57b23", "39aeee4df4ef42ea86d44cda292ceeec", "f37d34fef8c248b89f9c853d79601561", "8eead9e8b08f429aa1d62c74f789333a", "3f8804ed3ddb415e817722283a79d3e3", "b885858b5faf46fcaba5ca50252239b4", "a2bad55392ae4caea91833352be694ab", "32817222e644404f90de4a19668c0ced", "2bc1cec0bb2340788c8f80899e2bfe3a", "74ea8032079944d0a21e8414b67cddad", "c8e03543093f4a259b8d93a218fa2744", "e9ead4bf2fe24b8091d465fde312d323", "ace4454075f4479e94738e84992ad4e8", "26b908494ef246d8a26a7e9bc57aa06e", "15a86fa989af4de586ebf8f43383f795", "aad0fb84e1a4421cab2057d8fe189248", "77072af8c0b643ffb4166b6ab303ee1a", "992eef239510435392d416fee4d0bbb3", "f5c65d98a9a2486b8e7a5b0a3b677dc1", "33f55eff967f4d2cb0e771650d1f51e8", "c275f91445b6406c99da50114773fc28", "d3555fcab7e149df9db0af58ee64a583", "a7bb7c9ebaf2462e9ae0f4d14e00f8d3", "4398211a4d31401e8c2f6abc7cfa522d", "2e8c9a4a048f4c02b0366f4f235378cb", "4bb8c37292ac4729811fabf3ca62c9bd", "84dce524f2864a6782c6ee92a1528995", "9dc56b8c26bc4f65acce154679db2a74", "5a2453f796394a4f86f13f7acef44d3b", "8225e5b185c3403f947cc77d7277f7e8", "b49fcb9bc7cd48b781e96e902f7af98f", "24ec2ed18d9d4fe48a77b5c5a1eb11bf", "60ed44aba03240dabbaaf463bbb4c1b3", "9ae352713c244e1cb5201bd6997149c9", "38a131cc348f468a8b191f5d18b35c1f", "f94c91170cc74398a67a786bc8980b7a", "31d17163b347422ebcb99dcea6f75aa9", "272ec0ae6e734918aa914f242dd91382", "1d514c7a7c00423f9e6662ae9c9dc56c", "864869a9c7bc4a5cab7642c7988c29e5", "a7eb2345d0f84084ab5cf07173d8bdb7", "69a4d12d067f468ba3de6a69ae21b541", "226b68c70ffc4b16a4fe74d52ddddd0a", "d1a21c38558b47859323fbad7f3d94b1", 
"4591fb948f804f36b4abd4ea0599d63c", "0c66ef490a4c4077848d13304292ec77", "cda78e413afc486a8df80841498a3f84", "701396af90684374921ad044bfaadd1d", "ea4cf697ea4c4ed6903aed9c015bea0c", "8054b11f0770474795d2d58176bc1f04", "207d7b5bbdfe4673872cafb32d772cb6", "2d007a4d76f74015a588effab99fd63c", "25f30a25016e4a6db4cb1be8d4208820", "a47557eb9726448f8e759460461b64ab", "f991329200244bceb8e44537948c759a" ] } }, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "Downloading builder script: 0%| | 0.00/28.8k [00:00 (tf.Tensor, tf.Tensor) \n", " : columns='a', labels='labels' -> (tf.Tensor, tf.Tensor) \n", "New behaviour: columns=['a'],labels=['labels'] -> ({'a': tf.Tensor}, {'labels': tf.Tensor}) \n", " : columns='a', labels='labels' -> (tf.Tensor, tf.Tensor) \n", " warnings.warn(\n", "You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n" ] } ], "source": [ "from datasets import load_dataset\n", "from transformers import AutoTokenizer, DataCollatorWithPadding\n", "import numpy as np\n", "\n", "raw_datasets = load_dataset(\"glue\", \"mrpc\")\n", "checkpoint = \"bert-base-uncased\"\n", "tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n", "\n", "\n", "def tokenize_function(example):\n", " return tokenizer(example[\"sentence1\"], example[\"sentence2\"], truncation=True)\n", "\n", "\n", "tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)\n", "\n", "data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors=\"tf\")\n", "\n", "tf_train_dataset = tokenized_datasets[\"train\"].to_tf_dataset(\n", " columns=[\"attention_mask\", \"input_ids\", \"token_type_ids\"],\n", " label_cols=[\"labels\"],\n", " shuffle=True,\n", " collate_fn=data_collator,\n", " batch_size=8,\n", ")\n", "\n", "tf_validation_dataset = tokenized_datasets[\"validation\"].to_tf_dataset(\n", " columns=[\"attention_mask\", \"input_ids\", \"token_type_ids\"],\n", " label_cols=[\"labels\"],\n", " shuffle=False,\n", " collate_fn=data_collator,\n", " batch_size=8,\n", ")" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "3grvOO6GGIyY", "outputId": "27a8fe07-7e30-4438-8bb3-7398b77d8b49", "colab": { "base_uri": "https://localhost:8080/", "height": 140, "referenced_widgets": [ "21500f3ec6884f6e810b0e3773ddb381", "a2e410c2e89945fa81a46eb1c08364f3", "7dc2e77fa6684812ac7542248be9737d", "ef8c36bc263442d3938a95980b115345", "044a326c1e56443594572cd22b923f8c", "9d9767d684ad483999d22948e343f697", "9de9c8c1af9b4ca0acde7ff771685ca6", "bdf4880bff0b429dbf993c7eac566dba", "61b9dd31982a45d09a1377ca95e012ce", "eb148900c8dc4a5aacca889cbe1c2a9b", "578d4e69144b466095321228b4ece6ab" ] } }, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "Downloading model.safetensors: 0%| | 0.00/440M [00:00" ] }, "metadata": {}, "execution_count": 5 } ], "source": [ "from tensorflow.keras.losses import SparseCategoricalCrossentropy\n", "\n", "model.compile(\n", " optimizer=\"adam\",\n", " loss=SparseCategoricalCrossentropy(from_logits=True),\n", " metrics=[\"accuracy\"],\n", ")\n", "model.fit(\n", " tf_train_dataset,\n", " validation_data=tf_validation_dataset,\n", ")" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "vt7L9YHmGIyb" }, "outputs": [], "source": [ "from tensorflow.keras.optimizers.schedules import PolynomialDecay\n", "\n", "batch_size = 8\n", "num_epochs = 3\n", "# The number of 
training steps is the number of samples in the dataset, divided by the batch size then multiplied\n", "# by the total number of epochs. Note that the tf_train_dataset here is a batched tf.data.Dataset,\n", "# not the original Hugging Face Dataset, so its len() is already num_samples // batch_size.\n", "num_train_steps = len(tf_train_dataset) * num_epochs\n", "lr_scheduler = PolynomialDecay(\n", " initial_learning_rate=5e-5, end_learning_rate=0.0, decay_steps=num_train_steps\n", ")\n", "from tensorflow.keras.optimizers import Adam\n", "\n", "opt = Adam(learning_rate=lr_scheduler)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "id": "DuAvsWL-GIyc", "outputId": "6ffa8e92-3421-4ba1-bf72-d4d02a91d383", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "All PyTorch model weights were used when initializing TFBertForSequenceClassification.\n", "\n", "Some weights or buffers of the TF 2.0 model TFBertForSequenceClassification were not initialized from the PyTorch model and are newly initialized: ['classifier.weight', 'classifier.bias']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] } ], "source": [ "import tensorflow as tf\n", "\n", "model = TFAutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)\n", "loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n", "model.compile(optimizer=opt, loss=loss, metrics=[\"accuracy\"])" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "id": "QvM4Fzp1GIyd", "outputId": "adabe2d7-a64c-4310-f1bb-d6974384ea35", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Epoch 1/3\n", "459/459 [==============================] - 147s 202ms/step - loss: 0.6188 - accuracy: 0.6824 - val_loss: 0.5958 - val_accuracy: 0.7059\n", "Epoch 2/3\n", "459/459 [==============================] - 74s 161ms/step - loss: 0.4757 - accuracy: 0.7759 - val_loss: 0.3560 - val_accuracy: 0.8456\n", "Epoch 3/3\n", "459/459 [==============================] - 71s 154ms/step - loss: 0.2150 - accuracy: 0.9248 - val_loss: 0.3890 - val_accuracy: 0.8652\n" ] }, { "output_type": "execute_result", "data": { "text/plain": [ "" ] }, "metadata": {}, "execution_count": 8 } ], "source": [ "model.fit(tf_train_dataset, validation_data=tf_validation_dataset, epochs=3)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "id": "hffkTuU1GIye", "outputId": "23017617-46a9-4af4-dea9-83d3c242b993", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "51/51 [==============================] - 7s 55ms/step\n" ] } ], "source": [ "preds = model.predict(tf_validation_dataset)[\"logits\"]" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "id": "8evGL8NVGIyf", "outputId": "a7b360bc-a12e-4e16-902e-89e75de273a9", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "(408, 2) (408,)\n" ] } ], "source": [ "class_preds = np.argmax(preds, axis=1)\n", "print(preds.shape, class_preds.shape)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "id": "VTJ6smIdGIyg", "outputId": "a325bace-b579-412e-d9cc-7bd06c887ae1", "colab": { "base_uri": "https://localhost:8080/", "height": 67, "referenced_widgets": [ "9d83e21ecd2a482c8ec8f756f24a22e9", "aad6ca73c1094feca26b32d9168e8f88", 
"55806a488dfb4dafb9506d35b6aa7c40", "3256ca8526c64f05bccba0e786796e4e", "c96861fb8e6141bea479065c9fb63471", "e03a95bf11224856adfd8a03321afa59", "a40dfd001a86465997ed0e289202db00", "564d62d5aace4a1abe922c3bee29aed4", "4c09ac283583457c912cbaa31f0d943e", "e62aff5462c9417d996c8967c20cd490", "0a36364a646d4350b33b5f761ff5e8bc" ] } }, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "Downloading builder script: 0%| | 0.00/5.75k [00:00