{ "cells": [ { "cell_type": "markdown", "id": "a8b6caed", "metadata": {}, "source": [ "# πŸ‡ͺπŸ‡Ί 🏷️ Eurovoc Model Training Notebook" ] }, { "cell_type": "code", "execution_count": 1, "id": "c4c73793", "metadata": {}, "outputs": [], "source": [ "import pickle \n", "import pandas as pd\n", "from transformers import AutoTokenizer, AutoModel\n", "\n", "from datasets import list_datasets, load_dataset\n", "\n", "from sklearn.preprocessing import MultiLabelBinarizer\n", "import torch\n", "\n", "import pytorch_lightning as pl\n", "from pytorch_lightning.callbacks import ModelCheckpoint" ] }, { "cell_type": "markdown", "id": "dc770f0b", "metadata": { "tags": [] }, "source": [ "---\n", "\n", "## 1. Data loading\n", "### Choose our dataset, extracted from ep registry or eurlex57k" ] }, { "cell_type": "code", "execution_count": 2, "id": "9fdc5328", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Found cached dataset json (/home/scampion/.cache/huggingface/datasets/EuropeanParliament___json/EuropeanParliament--cellar_eurovoc-3a27a019ebbf0296/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d5bf91bf9dc2416faefe96d680217da6", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/1 [00:00\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
titledateeurovoc_conceptsurllangformatstext
0Corrigendum to Commission Implementing Regulat...2023-07-20[China, Malaysia, anti-dumping duty, business ...http://publications.europa.eu/resource/cellar/...eng[fmx4, pdfa2a, xhtml]L_2023183EN. 01005801. xml 20. 7. 2023Β Β Β  EN O...
1Council Decision (CFSP) 2023/1501 of 20Β July 2...2023-07-20[EU restrictive measure, Russia, Ukraine, econ...http://publications.europa.eu/resource/cellar/...eng[fmx4, pdfa2a, xhtml]LI2023183EN. 01004801. xml 20. 7. 2023Β Β Β  EN O...
2Council Decision (CFSP) 2023/1502 of 20Β July 2...2023-07-20[Burma/Myanmar, EU restrictive measure, econom...http://publications.europa.eu/resource/cellar/...eng[fmx4, pdfa2a, xhtml]LI2023183EN. 01005201. xml 20. 7. 2023Β Β Β  EN O...
3The Committee of the Regions welcomes Croatian...2023-07-20[Croatia, EU regional policy, European Committ...http://publications.europa.eu/resource/cellar/...eng[pdf]EUROPEAN UNION Committee of the Regions The Co...
4Corrigendum to Commission Implementing Regulat...2023-07-20[India, TΓΌrkiye, anti-dumping duty, building m...http://publications.europa.eu/resource/cellar/...eng[fmx4, pdfa2a, xhtml]L_2023183EN. 01005901. xml 20. 7. 2023Β Β Β  EN O...
\n", "" ], "text/plain": [ " title date \\\n", "0 Corrigendum to Commission Implementing Regulat... 2023-07-20 \n", "1 Council Decision (CFSP) 2023/1501 of 20Β July 2... 2023-07-20 \n", "2 Council Decision (CFSP) 2023/1502 of 20Β July 2... 2023-07-20 \n", "3 The Committee of the Regions welcomes Croatian... 2023-07-20 \n", "4 Corrigendum to Commission Implementing Regulat... 2023-07-20 \n", "\n", " eurovoc_concepts \\\n", "0 [China, Malaysia, anti-dumping duty, business ... \n", "1 [EU restrictive measure, Russia, Ukraine, econ... \n", "2 [Burma/Myanmar, EU restrictive measure, econom... \n", "3 [Croatia, EU regional policy, European Committ... \n", "4 [India, TΓΌrkiye, anti-dumping duty, building m... \n", "\n", " url lang \\\n", "0 http://publications.europa.eu/resource/cellar/... eng \n", "1 http://publications.europa.eu/resource/cellar/... eng \n", "2 http://publications.europa.eu/resource/cellar/... eng \n", "3 http://publications.europa.eu/resource/cellar/... eng \n", "4 http://publications.europa.eu/resource/cellar/... eng \n", "\n", " formats text \n", "0 [fmx4, pdfa2a, xhtml] L_2023183EN. 01005801. xml 20. 7. 2023Β Β Β  EN O... \n", "1 [fmx4, pdfa2a, xhtml] LI2023183EN. 01004801. xml 20. 7. 2023Β Β Β  EN O... \n", "2 [fmx4, pdfa2a, xhtml] LI2023183EN. 01005201. xml 20. 7. 2023Β Β Β  EN O... \n", "3 [pdf] EUROPEAN UNION Committee of the Regions The Co... \n", "4 [fmx4, pdfa2a, xhtml] L_2023183EN. 01005901. xml 20. 7. 2023Β Β Β  EN O... 
" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train = dataset['train'].to_pandas()\n", "test = dataset['test'].to_pandas() if 'test' in dataset.keys() else None\n", "validation = dataset['validation'].to_pandas() if 'validation' in dataset.keys() else None\n", "\n", "# pd.concat silently drops the None placeholders of missing splits.\n", "# ignore_index=True gives one unique 0..n-1 index instead of repeating each split's own index.\n", "# NOTE: `all` shadows the Python builtin; the name is kept because the split cell below reads it.\n", "all = pd.concat([train, test, validation], ignore_index=True)\n", "all.head()" ] }, { "cell_type": "markdown", "id": "aeca89c2", "metadata": { "tags": [] }, "source": [ "### Create the MultiLabel Binarizer and save it in a file for prediction" ] }, { "cell_type": "code", "execution_count": 4, "id": "d6846099", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "('Number of classes', 6835)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mlb = MultiLabelBinarizer().fit(all['eurovoc_concepts'])\n", "\n", "# Persist the fitted binarizer for prediction time; the with-block guarantees\n", "# the pickle file is flushed and closed before anything else reads it.\n", "with open('mlb.pickle', 'wb') as mlb_file:\n", "    pickle.dump(mlb, mlb_file)\n", "\n", "\"Number of classes\", len(mlb.classes_)" ] }, { "cell_type": "markdown", "id": "1f27b865", "metadata": { "tags": [] }, "source": [ "---\n", "## 2. 
Split data into train / validation / test sets with iterative stratification" ] }, { "cell_type": "code", "execution_count": null, "id": "ba290237", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from skmultilearn.model_selection import iterative_train_test_split\n", "\n", "SPLIT_SEED = 42\n", "\n", "# Keep X 2-D with shape (n_documents, 1): one text column per row.\n", "X = np.expand_dims(all['text'].to_numpy(), axis=1)\n", "y = mlb.transform(all['eurovoc_concepts'])\n", "\n", "# No random_state is passed to the splitter, so seed NumPy's global RNG to make the\n", "# stratified split repeatable (assumes skmultilearn breaks ties with that RNG - TODO confirm).\n", "np.random.seed(SPLIT_SEED)\n", "x_tr, y_tr, x_test, y_test = iterative_train_test_split(X, y, test_size = 0.1)\n", "x_tr, y_tr, x_val, y_val = iterative_train_test_split(x_tr, y_tr, test_size = 0.1)\n", "len(x_tr), len(x_val), len(x_test)" ] }, { "cell_type": "code", "execution_count": null, "id": "98371ad3", "metadata": {}, "outputs": [], "source": [ "# Example: first 120 characters of one training document and its decoded labels\n", "i = 10\n", "x_tr[i][0][:120], mlb.inverse_transform(y_tr[i:i + 1])" ] }, { "cell_type": "markdown", "id": "7c959b6a", "metadata": {}, "source": [ "---\n", "## 3. Model definition and training" ] }, { "cell_type": "code", "execution_count": null, "id": "a177f1ce", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "f4061399", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at nlpaueb/legal-bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight']\n", "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "GPU available: True (cuda), used: True\n", "TPU available: False, using: 0 TPU cores\n", "IPU available: False, using: 0 IPUs\n", "HPU available: False, using: 0 HPUs\n", "You are using a CUDA device ('NVIDIA GeForce RTX 3080') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2]\n", "\n", " | Name | Type | Params\n", "------------------------------------------\n", "0 | bert | BertModel | 109 M \n", "1 | dropout | Dropout | 0 \n", "2 | classifier1 | Linear | 5.1 M \n", "3 | criterion | BCELoss | 0 \n", "------------------------------------------\n", "114 M Trainable params\n", "0 Non-trainable params\n", "114 M Total params\n", "458.304 Total estimated model params size (MB)\n", "IOPub message rate exceeded.\n", "The Jupyter server will temporarily stop sending output\n", "to the client in order to avoid crashing it.\n", "To change this limit, set the config variable\n", "`--ServerApp.iopub_msg_rate_limit`.\n", "\n", "Current values:\n", "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n", "ServerApp.rate_limit_window=3.0 (secs)\n", "\n", "IOPub message rate exceeded.\n", "The Jupyter server will temporarily stop sending output\n", "to the client in order to avoid crashing it.\n", "To change this limit, set the config variable\n", "`--ServerApp.iopub_msg_rate_limit`.\n", "\n", "Current values:\n", "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n", 
"ServerApp.rate_limit_window=3.0 (secs)\n", "\n", "IOPub message rate exceeded.\n", "The Jupyter server will temporarily stop sending output\n", "to the client in order to avoid crashing it.\n", "To change this limit, set the config variable\n", "`--ServerApp.iopub_msg_rate_limit`.\n", "\n", "Current values:\n", "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n", "ServerApp.rate_limit_window=3.0 (secs)\n", "\n", "IOPub message rate exceeded.\n", "The Jupyter server will temporarily stop sending output\n", "to the client in order to avoid crashing it.\n", "To change this limit, set the config variable\n", "`--ServerApp.iopub_msg_rate_limit`.\n", "\n", "Current values:\n", "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n", "ServerApp.rate_limit_window=3.0 (secs)\n", "\n", "IOPub message rate exceeded.\n", "The Jupyter server will temporarily stop sending output\n", "to the client in order to avoid crashing it.\n", "To change this limit, set the config variable\n", "`--ServerApp.iopub_msg_rate_limit`.\n", "\n", "Current values:\n", "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n", "ServerApp.rate_limit_window=3.0 (secs)\n", "\n", "IOPub message rate exceeded.\n", "The Jupyter server will temporarily stop sending output\n", "to the client in order to avoid crashing it.\n", "To change this limit, set the config variable\n", "`--ServerApp.iopub_msg_rate_limit`.\n", "\n", "Current values:\n", "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n", "ServerApp.rate_limit_window=3.0 (secs)\n", "\n", "IOPub message rate exceeded.\n", "The Jupyter server will temporarily stop sending output\n", "to the client in order to avoid crashing it.\n", "To change this limit, set the config variable\n", "`--ServerApp.iopub_msg_rate_limit`.\n", "\n", "Current values:\n", "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n", "ServerApp.rate_limit_window=3.0 (secs)\n", "\n", "IOPub message rate exceeded.\n", "The Jupyter server will temporarily stop sending output\n", "to the client 
in order to avoid crashing it.\n", "To change this limit, set the config variable\n", "`--ServerApp.iopub_msg_rate_limit`.\n", "\n", "Current values:\n", "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n", "ServerApp.rate_limit_window=3.0 (secs)\n", "\n", "IOPub message rate exceeded.\n", "The Jupyter server will temporarily stop sending output\n", "to the client in order to avoid crashing it.\n", "To change this limit, set the config variable\n", "`--ServerApp.iopub_msg_rate_limit`.\n", "\n", "Current values:\n", "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n", "ServerApp.rate_limit_window=3.0 (secs)\n", "\n", "IOPub message rate exceeded.\n", "The Jupyter server will temporarily stop sending output\n", "to the client in order to avoid crashing it.\n", "To change this limit, set the config variable\n", "`--ServerApp.iopub_msg_rate_limit`.\n", "\n", "Current values:\n", "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n", "ServerApp.rate_limit_window=3.0 (secs)\n", "\n", "IOPub message rate exceeded.\n", "The Jupyter server will temporarily stop sending output\n", "to the client in order to avoid crashing it.\n", "To change this limit, set the config variable\n", "`--ServerApp.iopub_msg_rate_limit`.\n", "\n", "Current values:\n", "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n", "ServerApp.rate_limit_window=3.0 (secs)\n", "\n", "IOPub message rate exceeded.\n", "The Jupyter server will temporarily stop sending output\n", "to the client in order to avoid crashing it.\n", "To change this limit, set the config variable\n", "`--ServerApp.iopub_msg_rate_limit`.\n", "\n", "Current values:\n", "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n", "ServerApp.rate_limit_window=3.0 (secs)\n", "\n", "IOPub message rate exceeded.\n", "The Jupyter server will temporarily stop sending output\n", "to the client in order to avoid crashing it.\n", "To change this limit, set the config variable\n", "`--ServerApp.iopub_msg_rate_limit`.\n", "\n", "Current values:\n", 
"ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n", "ServerApp.rate_limit_window=3.0 (secs)\n", "\n", "IOPub message rate exceeded.\n", "The Jupyter server will temporarily stop sending output\n", "to the client in order to avoid crashing it.\n", "To change this limit, set the config variable\n", "`--ServerApp.iopub_msg_rate_limit`.\n", "\n", "Current values:\n", "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n", "ServerApp.rate_limit_window=3.0 (secs)\n", "\n", "IOPub message rate exceeded.\n", "The Jupyter server will temporarily stop sending output\n", "to the client in order to avoid crashing it.\n", "To change this limit, set the config variable\n", "`--ServerApp.iopub_msg_rate_limit`.\n", "\n", "Current values:\n", "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n", "ServerApp.rate_limit_window=3.0 (secs)\n", "\n", "IOPub message rate exceeded.\n", "The Jupyter server will temporarily stop sending output\n", "to the client in order to avoid crashing it.\n", "To change this limit, set the config variable\n", "`--ServerApp.iopub_msg_rate_limit`.\n", "\n", "Current values:\n", "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n", "ServerApp.rate_limit_window=3.0 (secs)\n", "\n", "IOPub message rate exceeded.\n", "The Jupyter server will temporarily stop sending output\n", "to the client in order to avoid crashing it.\n", "To change this limit, set the config variable\n", "`--ServerApp.iopub_msg_rate_limit`.\n", "\n", "Current values:\n", "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n", "ServerApp.rate_limit_window=3.0 (secs)\n", "\n", "IOPub message rate exceeded.\n", "The Jupyter server will temporarily stop sending output\n", "to the client in order to avoid crashing it.\n", "To change this limit, set the config variable\n", "`--ServerApp.iopub_msg_rate_limit`.\n", "\n", "Current values:\n", "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n", "ServerApp.rate_limit_window=3.0 (secs)\n", "\n", "IOPub message rate exceeded.\n", "The Jupyter server 
will temporarily stop sending output\n", "to the client in order to avoid crashing it.\n", "To change this limit, set the config variable\n", "`--ServerApp.iopub_msg_rate_limit`.\n", "\n", "Current values:\n", "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n", "ServerApp.rate_limit_window=3.0 (secs)\n", "\n", "IOPub message rate exceeded.\n", "The Jupyter server will temporarily stop sending output\n", "to the client in order to avoid crashing it.\n", "To change this limit, set the config variable\n", "`--ServerApp.iopub_msg_rate_limit`.\n", "\n", "Current values:\n", "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n", "ServerApp.rate_limit_window=3.0 (secs)\n", "\n", "IOPub message rate exceeded.\n", "The Jupyter server will temporarily stop sending output\n", "to the client in order to avoid crashing it.\n", "To change this limit, set the config variable\n", "`--ServerApp.iopub_msg_rate_limit`.\n", "\n", "Current values:\n", "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n", "ServerApp.rate_limit_window=3.0 (secs)\n", "\n", "IOPub message rate exceeded.\n", "The Jupyter server will temporarily stop sending output\n", "to the client in order to avoid crashing it.\n", "To change this limit, set the config variable\n", "`--ServerApp.iopub_msg_rate_limit`.\n", "\n", "Current values:\n", "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n", "ServerApp.rate_limit_window=3.0 (secs)\n", "\n", "IOPub message rate exceeded.\n", "The Jupyter server will temporarily stop sending output\n", "to the client in order to avoid crashing it.\n", "To change this limit, set the config variable\n", "`--ServerApp.iopub_msg_rate_limit`.\n", "\n", "Current values:\n", "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n", "ServerApp.rate_limit_window=3.0 (secs)\n", "\n", "IOPub message rate exceeded.\n", "The Jupyter server will temporarily stop sending output\n", "to the client in order to avoid crashing it.\n", "To change this limit, set the config variable\n", 
"`--ServerApp.iopub_msg_rate_limit`.\n", "\n", "Current values:\n", "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n", "ServerApp.rate_limit_window=3.0 (secs)\n", "\n", "IOPub message rate exceeded.\n", "The Jupyter server will temporarily stop sending output\n", "to the client in order to avoid crashing it.\n", "To change this limit, set the config variable\n", "`--ServerApp.iopub_msg_rate_limit`.\n", "\n", "Current values:\n", "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n", "ServerApp.rate_limit_window=3.0 (secs)\n", "\n", "IOPub message rate exceeded.\n", "The Jupyter server will temporarily stop sending output\n", "to the client in order to avoid crashing it.\n", "To change this limit, set the config variable\n", "`--ServerApp.iopub_msg_rate_limit`.\n", "\n", "Current values:\n", "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n", "ServerApp.rate_limit_window=3.0 (secs)\n", "\n" ] } ], "source": [ "%%capture output\n", "%load_ext autoreload\n", "%autoreload 2\n", "\n", "from eurovoc import EurovocTagger, EurovocDataset, EurovocDataModule\n", "\n", "\n", "BERT_MODEL_NAME = \"nlpaueb/legal-bert-base-uncased\"\n", "N_EPOCHS = 30\n", "BATCH_SIZE = 10\n", "MAX_LEN = 512\n", "LR = 5e-05\n", "SEED = 42\n", "\n", "# Seed python.random, NumPy and torch so weight initialisation, dropout and shuffling are repeatable.\n", "pl.seed_everything(SEED)\n", "\n", "# Instantiate and set up the data_module\n", "dataloader = EurovocDataModule(BERT_MODEL_NAME, x_tr, y_tr, x_val, y_val, BATCH_SIZE, MAX_LEN)\n", "dataloader.setup()\n", "\n", "model = EurovocTagger(BERT_MODEL_NAME, len(mlb.classes_), lr=LR)\n", "\n", "# Recorded losses on this task are around 3e-3, so two decimals named every\n", "# checkpoint val_loss=0.00; four decimals keep checkpoint file names distinguishable.\n", "checkpoint_callback = ModelCheckpoint(\n", "    monitor='val_loss',\n", "    filename='EurovocTagger-{epoch:02d}-{val_loss:.4f}',\n", "    mode='min',\n", ")\n", "\n", "trainer = pl.Trainer(\n", "    max_epochs=N_EPOCHS,\n", "    accelerator=\"gpu\",\n", "    devices=1,\n", "    callbacks=[checkpoint_callback],\n", "    # strategy=\"ddp_notebook\",\n", ")\n", "trainer.fit(model, dataloader)" ] }, { "cell_type": "code", "execution_count": 13, "id": "19084e69", "metadata": {}, "outputs": [], "source": [ "trainer.save_checkpoint(\"eurovoc_cellar.ckpt\")" ] }, { 
"cell_type": "code", "execution_count": null, "id": "d8289db5", "metadata": {}, "outputs": [], "source": [ "output()" ] }, { "cell_type": "code", "execution_count": 14, "id": "7c250c40", "metadata": {}, "outputs": [], "source": [ "np.save('x_test', x_test)\n", "np.save('y_test', y_test)" ] }, { "cell_type": "code", "execution_count": 15, "id": "418a7fd0", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/scampion/training/venv/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py:148: UserWarning: `.test(ckpt_path=None)` was called without a model. The best model of the previous `fit` call will be used. You can pass `.test(ckpt_path='best')` to use the best model or `.test(ckpt_path='last')` to use the last model. If you pass a value, this warning will be silenced.\n", " rank_zero_warn(\n", "Restoring states from the checkpoint path at /home/scampion/training/lightning_logs/version_9/checkpoints/EurovocTagger-epoch=06-val_loss=0.00.ckpt\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2]\n", "Loaded model weights from the checkpoint at /home/scampion/training/lightning_logs/version_9/checkpoints/EurovocTagger-epoch=06-val_loss=0.00.ckpt\n", "/home/scampion/training/venv/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:432: PossibleUserWarning: The dataloader, test_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 32 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", " rank_zero_warn(\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "52fdd2fcc27744c4955dc449cc126100", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Testing: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃        Test metric        ┃       DataLoader 0        ┃\n",
       "┑━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "β”‚         test_loss         β”‚   0.0031269278842955828   β”‚\n",
       "β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜\n",
       "
\n" ], "text/plain": [ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", "┑━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", "β”‚\u001b[36m \u001b[0m\u001b[36m test_loss \u001b[0m\u001b[36m \u001b[0mβ”‚\u001b[35m \u001b[0m\u001b[35m 0.0031269278842955828 \u001b[0m\u001b[35m \u001b[0mβ”‚\n", "β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "[{'test_loss': 0.0031269278842955828}]" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.test(dataloaders=dataloader)" ] }, { "cell_type": "markdown", "id": "66b871ec", "metadata": {}, "source": [ "# Evaluation" ] }, { "cell_type": "code", "execution_count": 16, "id": "ba317c3e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'/home/scampion/training/lightning_logs/version_9/checkpoints/EurovocTagger-epoch=06-val_loss=0.00.ckpt'" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "best_model_path = trainer.checkpoint_callback.best_model_path\n", "best_model_path" ] }, { "cell_type": "code", "execution_count": 17, "id": "fe9751a1", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at nlpaueb/legal-bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight']\n", "- 
This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 23243/23243 [16:20<00:00, 23.72it/s] \n" ] } ], "source": [ "from tqdm import tqdm\n", "from transformers import AutoTokenizer\n", "\n", "trained_model = EurovocTagger.load_from_checkpoint(best_model_path,\n", " bert_model_name=BERT_MODEL_NAME,\n", " n_classes=len(mlb.classes_))\n", "trained_model.eval()\n", "trained_model.freeze()\n", "\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "trained_model = trained_model.to(device)\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL_NAME)\n", "\n", "val_dataset = EurovocDataset(x_test, y_test, tokenizer, max_token_len=MAX_LEN)\n", "predictions = []\n", "labels = []\n", "\n", "for item in tqdm(val_dataset):\n", " _, prediction = trained_model(\n", " item[\"input_ids\"].unsqueeze(dim=0).to(device), \n", " item[\"attention_mask\"].unsqueeze(dim=0).to(device)\n", " )\n", " predictions.append(prediction.flatten())\n", " labels.append(item[\"labels\"].int())\n", "\n", "predictions = torch.stack(predictions).detach().cpu()\n", "labels = torch.stack(labels).detach().cpu()" ] }, { "cell_type": "markdown", "id": "67477f7f", "metadata": {}, "source": [ "### F1 Score" ] }, { "cell_type": "code", "execution_count": 18, "id": "f0265f6e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.01 tensor(0.2188)\n", "0.06 tensor(0.3929)\n", "0.11 tensor(0.4353)\n", "0.16 tensor(0.4462)\n", "0.21 tensor(0.4437)\n", "0.26 tensor(0.4364)\n", "0.31 tensor(0.4249)\n", "0.36 
tensor(0.4106)\n", "0.41 tensor(0.3947)\n", "0.46 tensor(0.3780)\n", "0.51 tensor(0.3597)\n", "0.56 tensor(0.3404)\n", "0.61 tensor(0.3209)\n", "0.66 tensor(0.3007)\n" ] } ], "source": [ "from torchmetrics import F1Score\n", "\n", "# Sweep the decision threshold over 0.01, 0.06, ..., 0.66 and print the weighted F1 at each step.\n", "for threshold in (step / 100.0 for step in range(1, 70, 5)):\n", "    f1 = F1Score(task=\"multilabel\", num_labels=len(mlb.classes_), average='weighted', threshold=threshold)\n", "    print(threshold, f1(predictions, labels))" ] }, { "cell_type": "markdown", "id": "0945ad49", "metadata": {}, "source": [ "### NDCG Score" ] }, { "cell_type": "code", "execution_count": null, "id": "e4e3291f", "metadata": {}, "outputs": [], "source": [ "from sklearn.metrics import ndcg_score\n", "\n", "\n", "def calculate_average_ndcg(predictions, labels, top_k=5):\n", "    \"\"\"Return the mean of per-document NDCG@top_k scores.\n", "\n", "    Row i of `predictions` (model scores) is ranked against row i of `labels`\n", "    (true relevance) with sklearn's `ndcg_score`, one document at a time, and\n", "    the per-document scores are averaged with `np.mean`.\n", "    \"\"\"\n", "    per_document_scores = [\n", "        # ndcg_score expects 2-D (n_samples, n_labels) inputs, hence reshape(1, -1).\n", "        ndcg_score(\n", "            labels[i].cpu().numpy().reshape(1, -1),\n", "            score_row.cpu().numpy().reshape(1, -1),\n", "            k=top_k,\n", "        )\n", "        for i, score_row in enumerate(predictions)\n", "    ]\n", "    return np.mean(per_document_scores)\n", "\n", "\n", "for k in [3, 5, 10]:\n", "    average = calculate_average_ndcg(predictions, labels, top_k=k)\n", "    print(f\"NDCG@{k}: {round(average, 4)}\")" ] } ], "metadata": { "kernelspec": { "display_name": "eurovoc-env", "language": "python", "name": "eurovoc-env" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 5 }