{ "cells": [ { "cell_type": "markdown", "id": "1c71aba7-c0f3-4378-9b63-55529e0994b4", "metadata": {}, "source": [ "# Data\n", "\n", "We use the following dataset for fine-tuning:\n", "\n", "- [arXiv papers](https://www.kaggle.com/datasets/neelshah18/arxivdataset)\n", "\n", "arXiv also includes papers on computational biology, genomics, and related fields.\n", "\n", "An alternative is the [dataset](https://zenodo.org/record/7695390) from a [recent study](https://www.biorxiv.org/content/10.1101/2023.04.10.536208v1.full.pdf) that contains titles and labels of PubMed papers. It covers 20 million papers but provides only their titles (no abstracts).\n", "\n", "In this notebook we use the arXiv data and its tags." ] }, { "cell_type": "markdown", "id": "e9874f4a-3898-4c89-a0f7-04eeabf2b389", "metadata": { "tags": [] }, "source": [ "# Models\n", "\n", "As the base model we use BERT pretrained on biomedical text from PubMed.\n", "\n", "- [BiomedNLP-PubMedBERT](https://huggingface.co/microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract)" ] }, { "cell_type": "markdown", "id": "991e48e7-897f-45a3-8a0b-539ea67b4eb5", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "id": "2f130f05-21ee-46f9-889f-488e8c676aba", "metadata": {}, "source": [ "# Imports" ] }, { "cell_type": "code", "execution_count": 1, "id": "757a0582-1b8c-4f1c-b26f-544688e391f4", "metadata": { "tags": [] }, "outputs": [], "source": [ "import torch\n", "import transformers\n", "import numpy as np\n", "import pandas as pd\n", "from tqdm import tqdm\n", "\n", "from datasets import Dataset, ClassLabel\n", "from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoModelForSequenceClassification\n", "from transformers import TrainingArguments, Trainer\n", "from transformers import pipeline\n", "import evaluate" ] }, { "cell_type": "markdown", "id": "03847b87-d096-49a5-b6e2-023fa08b94c2", "metadata": {}, "source": [ "# Load data" ] }, { "cell_type": 
"markdown", "id": "b3e902ea-4e0f-4d76-b27b-59e472b2b556", "metadata": {}, "source": [ "Load the fine-tuning data. We need the paper titles, their abstracts (the `summary` field), and their tags." ] }, { "cell_type": "code", "execution_count": 2, "id": "1be8f69e-bd7d-4ca9-ba9f-044b8e7bc497", "metadata": { "tags": [] }, "outputs": [], "source": [ "df = pd.read_json(\"arxivData.json\")" ] }, { "cell_type": "markdown", "id": "791edb3c-a96d-4042-b35d-c8097bbbef79", "metadata": {}, "source": [ " " ] }, { "cell_type": "markdown", "id": "d5b6158a-728e-4ada-bcdc-a4a49328f002", "metadata": {}, "source": [ "Concatenate each title and abstract (separated by a newline) and store the result in a new `text` column:" ] }, { "cell_type": "code", "execution_count": 3, "id": "c8709a7b-becf-4f19-8b4f-8773cd5c60f1", "metadata": { "tags": [] }, "outputs": [], "source": [ "df['text'] = df['title'] + \"\\n\" + df['summary']" ] }, { "cell_type": "code", "execution_count": 4, "id": "ed0ed687-6439-494a-a5a8-c572bc2e4059", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ " author day id \\\n", "0 [{'name': 'Ahmed Osman'}, {'name': 'Wojciech S... 1 1802.00209v1 \n", "1 [{'name': 'Ji Young Lee'}, {'name': 'Franck De... 12 1603.03827v1 \n", "\n", " link month \\\n", "0 [{'rel': 'alternate', 'href': 'http://arxiv.or... 2 \n", "1 [{'rel': 'alternate', 'href': 'http://arxiv.or... 3 \n", "\n", " summary \\\n", "0 We propose an architecture for VQA which utili... \n", "1 Recent approaches based on artificial neural n... \n", "\n", " tag \\\n", "0 [{'term': 'cs.AI', 'scheme': 'http://arxiv.org... \n", "1 [{'term': 'cs.CL', 'scheme': 'http://arxiv.org... \n", "\n", " title year \\\n", "0 Dual Recurrent Attention Units for Visual Ques... 2018 \n", "1 Sequential Short-Text Classification with Recu... 2016 \n", "\n", " text \n", "0 Dual Recurrent Attention Units for Visual Ques... \n", "1 Sequential Short-Text Classification with Recu... " ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.head(2)" ] }, { "cell_type": "markdown", "id": "ce1de806-a4d2-4e58-a3a8-f3542392f22e", "metadata": {}, "source": [ "## Labels" ] }, { "cell_type": "markdown", "id": "b5183517-8b02-47bc-812a-415b5651e07d", "metadata": {}, "source": [ "We use the arXiv categories as labels, e.g. `astro-ph` for astrophysics papers or `cs.CV` for computer vision (a computer science category)."
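] }, { "cell_type": "markdown", "id": "4e8b2c71-9d3a-4f6e-b1c5-7a0d2e9f3b14", "metadata": {}, "source": [ "A minimal sketch of how a raw `tag` value becomes a label, and of the category <-> integer id mapping a classification head works with (the same kind of information a `transformers` model config keeps in `id2label` / `label2id`). The two tag strings are hypothetical, modeled on the truncated `tag` values shown above; the `'scheme'` URL and the `'label'` field are assumptions, and only `'term'` is used. As in the next cell, the first term is taken as the paper's category. `ast.literal_eval` parses Python literals without executing arbitrary code, which makes it a safer choice than `eval` for a column like this." ] }, { "cell_type": "code", "execution_count": null, "id": "c2a7f0d9-5b3e-4c18-8e6a-1f4d9b0c7a52", "metadata": { "tags": [] }, "outputs": [], "source": [ "import ast\n", "\n", "# Hypothetical raw `tag` values: strings holding a Python-style list of dicts\n", "# (the 'scheme' and 'label' fields are assumed; only 'term' is used)\n", "raw_tags = [\n", "    \"[{'term': 'cs.AI', 'scheme': 'http://arxiv.org/schemas/atom', 'label': None}]\",\n", "    \"[{'term': 'cs.CL', 'scheme': 'http://arxiv.org/schemas/atom', 'label': None}, {'term': 'cs.LG', 'scheme': 'http://arxiv.org/schemas/atom', 'label': None}]\",\n", "]\n", "\n", "# Parse each string safely and keep the first term as the category\n", "labels = [ast.literal_eval(t)[0]['term'].strip() for t in raw_tags]\n", "\n", "# Sorted unique categories -> contiguous integer ids 0 .. n-1, plus the inverse mapping\n", "label2id = {c: i for i, c in enumerate(sorted(set(labels)))}\n", "id2label = {i: c for c, i in label2id.items()}\n", "\n", "print(labels)  # one category per paper\n", "print(label2id)\n", "print(id2label)"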
] }, { "cell_type": "code", "execution_count": 5, "id": "ba4e7197-23b6-4cb4-9b44-620c6b730eb7", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total: 126 labels such as adap-org, astro-ph, ..., stat.OT\n" ] } ], "source": [ "import ast\n", "\n", "# `tag` holds a string repr of a list of dicts; the first entry's 'term' is used as the label\n", "df['category'] = [ast.literal_eval(i)[0]['term'].strip() for i in df['tag']]\n", "categories = np.unique(df['category'])\n", "num_labels = len(categories)\n", "print(f\"Total: {num_labels} labels such as {categories[0]}, {categories[1]}, ..., {categories[-1]}\")" ] }, { "cell_type": "code", "execution_count": 6, "id": "1508a6d9-856d-4ecf-a0f3-895d3ffbe99b", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ " category category_index\n", "0 adap-org 0\n", "1 astro-ph 1\n", "2 astro-ph.CO 2\n", "3 astro-ph.EP 3\n", "4 astro-ph.GA 4" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pd.DataFrame({\n", " \"category\": categories,\n", " \"category_index\": np.arange(num_labels),\n", "}).head()" ] }, { "cell_type": "code", "execution_count": 7, "id": "5c082c3a-7b0e-4320-b62d-f75a6c9f2398", "metadata": { "tags": [] }, "outputs": [], "source": [ "df = pd.DataFrame({\n", " \"category\": categories,\n", " \"category_index\": np.arange(num_labels),\n", "}).set_index(\"category\").join(df.set_index(\"category\"), how=\"right\", sort=False).reset_index()" ] }, { "cell_type": "markdown", "id": "76d8ccb9-a993-4d82-9dd3-689380e92e55", "metadata": {}, "source": [ "# Model" ] }, { "cell_type": "code", "execution_count": 8, "id": "a0c154f7-d2fa-46a1-8b69-57174bf00632", "metadata": { "tags": [] }, "outputs": [], "source": [ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "print(device)" ] }, { "cell_type": "markdown", "id": "2bf6513d-664d-4b94-8b05-7e8df205e3ec", "metadata": {}, "source": [ "Tokenizer (title + abstract -> tokens):" ] }, { "cell_type": "code", "execution_count": 9, "id": "12fa49a7-2ac5-4f78-84fe-93305926692e", "metadata": { "tags": [] }, "outputs": [], "source": [ "tokenizer = AutoTokenizer.from_pretrained(\"microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract\")" ] }, { "cell_type": "markdown", "id": "0ea1b4e5-9067-4292-ba12-8f560bbf26fd", "metadata": {}, "source": [ "The model itself: `AutoModelForSequenceClassification` drops the pretraining heads and adds a newly initialized classification head with `num_labels` outputs:" ] }, { "cell_type": "code", "execution_count": 10, "id": "d6eb92bc-c293-47ad-b9cc-2a63e8f1de69", "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract were not used when 
initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight']\n", "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract and are newly initialized: ['classifier.bias', 'classifier.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] } ], "source": [ "model = AutoModelForSequenceClassification.from_pretrained(\"microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract\", num_labels=num_labels).to(device)" ] }, { "cell_type": "code", "execution_count": 11, "id": "f5c79846-e6fc-42c0-bb8d-949678f5e60a", "metadata": { "scrolled": true, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "BertForSequenceClassification(\n", " (bert): BertModel(\n", " (embeddings): BertEmbeddings(\n", " (word_embeddings): Embedding(30522, 768, padding_idx=0)\n", " (position_embeddings): Embedding(512, 768)\n", " (token_type_embeddings): Embedding(2, 768)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (encoder): 
BertEncoder(\n", " (layer): ModuleList(\n", " (0-11): 12 x BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " (intermediate_act_fn): GELUActivation()\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " )\n", " (pooler): BertPooler(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (activation): Tanh()\n", " )\n", " )\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (classifier): Linear(in_features=768, out_features=126, bias=True)\n", ")\n" ] } ], "source": [ "print(model)" ] }, { "cell_type": "markdown", "id": "5ce6eefc-91ce-4486-9568-b686d04adcc7", "metadata": {}, "source": [ "# Training" ] }, { "cell_type": "markdown", "id": "71add72c-eafb-491a-8820-31ce7336524f", "metadata": {}, "source": [ "## Data Loaders" ] }, { "cell_type": "markdown", "id": "2a0b579c-998a-4d2e-bf0e-d4c7406d22da", "metadata": {}, "source": [ "When working with `transformers`, it is often more convenient to handle the data with the `datasets` library."
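] }, { "cell_type": "markdown", "id": "9d1f6a3b-2c84-4e7a-a5b0-6e3c8f2d1b97", "metadata": {}, "source": [ "A minimal sketch of a few basic `datasets` operations on a tiny, made-up DataFrame (the texts and categories are hypothetical). `Dataset.from_pandas` converts the frame; `preserve_index=False` keeps a non-default pandas index (such as the one left by `.iloc[...]`) from being stored as an extra `__index_level_0__` column. `class_encode_column` turns a string column into a `ClassLabel` feature, a batched `map` adds a column computed from the text, and `train_test_split` makes a seeded split. This only illustrates the API; the notebook itself builds its train/test split from manually sampled indices." ] }, { "cell_type": "code", "execution_count": null, "id": "5a7e2d40-8b19-4f63-9c2e-d4b1a0f6e835", "metadata": { "tags": [] }, "outputs": [], "source": [ "import pandas as pd\n", "from datasets import Dataset\n", "\n", "# Hypothetical toy data with the same columns as df[[\"text\", \"category\"]]\n", "toy_df = pd.DataFrame({\n", "    \"text\": [\"Title A\\nAbstract A\", \"Title B\\nAbstract B\", \"Title C\\nAbstract C\", \"Title D\\nAbstract D\"],\n", "    \"category\": [\"cs.AI\", \"cs.CL\", \"cs.AI\", \"cs.CL\"],\n", "})\n", "\n", "toy_ds = Dataset.from_pandas(toy_df, preserve_index=False)\n", "\n", "# String categories -> ClassLabel feature with integer values\n", "toy_ds = toy_ds.class_encode_column(\"category\")\n", "print(toy_ds.features[\"category\"])\n", "\n", "# Batched map: the function receives a dict of lists and returns new columns\n", "toy_ds = toy_ds.map(lambda batch: {\"n_words\": [len(t.split()) for t in batch[\"text\"]]}, batched=True)\n", "\n", "# Seeded 75/25 split into a DatasetDict with 'train' and 'test'\n", "splits = toy_ds.train_test_split(test_size=0.25, seed=42)\n", "print(splits)"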
] }, { "cell_type": "markdown", "id": "47b0e14a-866b-49ac-8b95-49a91a0bcc22", "metadata": {}, "source": [ "Split the rows into train and test parts and create Hugging Face [datasets](https://huggingface.co/docs/datasets/tabular_load#pandas-dataframes) from them:" ] }, { "cell_type": "code", "execution_count": 13, "id": "dc1a3f33-0ef9-43c9-ab5f-eb9ae304b897", "metadata": { "tags": [] }, "outputs": [], "source": [ "np.random.seed(42)\n", "train_indices = np.sort(np.random.choice(np.arange(len(df)), size=37_000, replace=False))\n", "# All remaining rows form the test set (sorted, like the train indices)\n", "test_indices = np.setdiff1d(np.arange(len(df)), train_indices)" ] }, { "cell_type": "code", "execution_count": 14, "id": "d948f8a6-1a7a-4baa-88a0-418596a1f275", "metadata": { "tags": [] }, "outputs": [], "source": [ "train_df = df.loc[:,[\"text\", \"category\"]].iloc[train_indices]\n", "test_df = df.loc[:,[\"text\", \"category\"]].iloc[test_indices]\n", "\n", "train_ds = Dataset.from_pandas(train_df, split=\"train\")\n", "test_ds = Dataset.from_pandas(test_df, split=\"test\")" ] }, { "cell_type": "code", "execution_count": 15, "id": "50242a35-3067-41e5-8de8-f7e6a4fb6e9c", "metadata": { "tags": [] }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/37000 [00:00 main\n", "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", "To disable this warning, you can either:\n", "\t- Avoid using `tokenizers` before the fork if possible\n", "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. 
Disabling parallelism to avoid deadlocks...\n", "To disable this warning, you can either:\n", "\t- Avoid using `tokenizers` before the fork if possible\n", "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", "To disable this warning, you can either:\n", "\t- Avoid using `tokenizers` before the fork if possible\n", "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n" ] } ], "source": [ "trainer.push_to_hub()" ] }, { "cell_type": "markdown", "id": "5093aee3-106e-43e9-a9c7-413d059ebb27", "metadata": {}, "source": [ " " ] }, { "cell_type": "markdown", "id": "b1a1029f-543c-409e-9aaf-35bcefe49988", "metadata": {}, "source": [ "# Inference" ] }, { "cell_type": "markdown", "id": "e7b0cd5a-2e17-49f3-b2a9-5ae4e8511969", "metadata": {}, "source": [ "Now load the fine-tuned model back from the HF Hub:" ] }, { "cell_type": "code", "execution_count": 2, "id": "b7fe37b9-61a9-4796-af24-092f6722cd61", "metadata": { "tags": [] }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "36afc9d465f54c80ab01698f5a687388", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading (…)okenizer_config.json: 0%| | 0.00/394 [00:00= threshold:\n", " break\n", "\n", " preds = preds[:(i+1)]\n", " \n", " return preds" ] }, { "cell_type": "code", "execution_count": 5, "id": "ed3545b6-e043-4dfb-aeb2-7559eac37f7c", "metadata": { "tags": [] }, "outputs": [], "source": [ "def format_predictions(preds) -> str:\n", " \"\"\"\n", " Prepare predictions and their scores for printing to the user\n", " \"\"\"\n", " out = \"\"\n", " for i, item in enumerate(preds):\n", " out += f\"{i+1}. 
{item['label']} (score {item['score']:.2f})\\n\"\n", " return out" ] }, { "cell_type": "code", "execution_count": 9, "id": "870d593a-a298-4d55-87b0-cb2813cc1fad", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1. cs.LG (score 0.88)\n", "2. cs.AI (score 0.07)\n", "3. cs.NE (score 0.03)\n", "\n" ] } ], "source": [ "print(\n", " format_predictions(\n", " top_pct(\n", " pipe(\"Attention Is All You Need\\nThe dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration.\")[0]\n", " )\n", " )\n", ")" ] }, { "cell_type": "markdown", "id": "408f015e-be23-46a6-9e91-503fdccecf11", "metadata": {}, "source": [ " " ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.8" } }, "nbformat": 4, "nbformat_minor": 5 }