{ "cells": [ { "cell_type": "markdown", "id": "1c71aba7-c0f3-4378-9b63-55529e0994b4", "metadata": {}, "source": [ "# Data\n", "\n", "We use the following dataset for fine-tuning:\n", "\n", "- a [dataset](https://zenodo.org/record/7695390) from a [recent study](https://www.biorxiv.org/content/10.1101/2023.04.10.536208v1) with the titles and labels of PubMed articles.\n", "\n", "It covers 20 million articles but provides only their titles (no abstracts; these can be [retrieved](https://www.nlm.nih.gov/databases/download/pubmed_medline.html) separately by article PMID). Fine-tuning a model on this much data requires considerable time and compute (rough estimates are [given in the paper](https://www.biorxiv.org/content/10.1101/2023.04.10.536208v1)), so below we use a reduced version of the dataset and train on article titles only." ] }, { "cell_type": "markdown", "id": "e9874f4a-3898-4c89-a0f7-04eeabf2b389", "metadata": { "tags": [] }, "source": [ "# Models\n", "\n", "As the base model we use BERT pretrained on biomedical data (from PubMed).
\n", "\n", "- [BiomedNLP-PubMedBERT](https://huggingface.co/microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract)" ] }, { "cell_type": "markdown", "id": "991e48e7-897f-45a3-8a0b-539ea67b4eb5", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "id": "2f130f05-21ee-46f9-889f-488e8c676aba", "metadata": {}, "source": [ "# Imports" ] }, { "cell_type": "code", "execution_count": 1, "id": "757a0582-1b8c-4f1c-b26f-544688e391f4", "metadata": { "tags": [] }, "outputs": [], "source": [ "import torch\n", "import transformers\n", "import numpy as np\n", "import pandas as pd\n", "from tqdm import tqdm\n", "\n", "from datasets import Dataset, ClassLabel\n", "from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoModelForSequenceClassification\n", "from transformers import TrainingArguments, Trainer\n", "from transformers import pipeline\n", "import evaluate" ] }, { "cell_type": "markdown", "id": "daa2aa21-de67-44a9-a0ff-1a913e425ccc", "metadata": {}, "source": [ " " ] }, { "cell_type": "markdown", "id": "03847b87-d096-49a5-b6e2-023fa08b94c2", "metadata": {}, "source": [ "# Load data" ] }, { "cell_type": "markdown", "id": "b3e902ea-4e0f-4d76-b27b-59e472b2b556", "metadata": {}, "source": [ "Load the fine-tuning data. We need the article titles and their labels (this dataset contains no abstracts)."
] }, { "cell_type": "code", "execution_count": 2, "id": "1be8f69e-bd7d-4ca9-ba9f-044b8e7bc497", "metadata": { "tags": [] }, "outputs": [], "source": [ "df = pd.read_csv(\"pubmed_landscape_data.csv\")" ] }, { "cell_type": "code", "execution_count": 62, "id": "ae78e0e8-a600-4607-8c1e-82ecdae17e2d", "metadata": { "tags": [] }, "outputs": [], "source": [ "df = df[df.Labels != \"unlabeled\"]\n", "df = df[~df.Title.isnull()]" ] }, { "cell_type": "code", "execution_count": 63, "id": "7715556f-8709-40cf-aa8c-3fecbfa3c1f4", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(7123406, 10)\n" ] }, { "data": { "text/plain": [ " Title \\\n", "18 Determination of some in vitro growth requirem... \n", "19 Degradation of agar by a gram-negative bacterium. \n", "20 Choroid plexus isografts in rats. \n", "29 Preliminary report on a mass screening program... \n", "30 Hepatic changes in young infants with cystic f... \n", "\n", " Journal PMID Year \\\n", "18 Journal of general microbiology 1133574 1975.0 \n", "19 Journal of general microbiology 1133575 1975.0 \n", "20 Journal of neuropathology and experimental neu... 1133586 1975.0 \n", "29 The Journal of pediatrics 1133648 1975.0 \n", "30 The Journal of pediatrics 1133649 1975.0 \n", "\n", " x y Labels Colors \\\n", "18 -140.830 26.596 microbiology #B79762 \n", "19 -72.913 -4.436 microbiology #B79762 \n", "20 -46.561 96.421 neurology #009271 \n", "29 45.033 39.256 pediatric #004D43 \n", "30 118.380 61.870 pediatric #004D43 \n", "\n", " text label \n", "18 Determination of some in vitro growth requirem... microbiology \n", "19 Degradation of agar by a gram-negative bacterium. microbiology \n", "20 Choroid plexus isografts in rats. neurology \n", "29 Preliminary report on a mass screening program... pediatric \n", "30 Hepatic changes in young infants with cystic f...
pediatric " ] }, "execution_count": 63, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print(df.shape)\n", "df.head(5)" ] }, { "cell_type": "markdown", "id": "791edb3c-a96d-4042-b35d-c8097bbbef79", "metadata": {}, "source": [ " " ] }, { "cell_type": "code", "execution_count": 76, "id": "81bff36c-0844-49c8-a4e8-162bb1233a45", "metadata": { "tags": [] }, "outputs": [], "source": [ "df.columns = ['text', 'journal', 'pmid', 'year', 'x', 'y', 'label', 'color'] # rename columns (this dataset has no abstract column)" ] }, { "cell_type": "code", "execution_count": null, "id": "c187efce-212b-494b-9157-0e8ceb1a2f3c", "metadata": { "tags": [] }, "outputs": [], "source": [ "# Use a subset of the data (the first 1,000,000 rows) for faster training\n", "df = df.head(1_000_000)" ] }, { "cell_type": "markdown", "id": "68fd806d-ba31-4769-9d57-2762710a6fb7", "metadata": {}, "source": [ " " ] }, { "cell_type": "markdown", "id": "ce1de806-a4d2-4e58-a3a8-f3542392f22e", "metadata": {}, "source": [ "## Labels" ] }, { "cell_type": "markdown", "id": "b5183517-8b02-47bc-812a-415b5651e07d", "metadata": {}, "source": [ "We use the annotated article labels as classification targets:" ] }, { "cell_type": "code", "execution_count": 72, "id": "ba4e7197-23b6-4cb4-9b44-620c6b730eb7", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total: 38 labels such as anesthesiology, biochemistry, ..., virology\n" ] } ], "source": [ "categories = np.unique(df['label'])\n", "num_labels = len(categories)\n", "print(f\"Total: {num_labels} labels such as {categories[0]}, {categories[1]}, ..., {categories[-1]}\")" ] }, { "cell_type": "markdown", "id": "10b49edd-0929-47e7-bb77-bc71528eb726", "metadata": {}, "source": [ " " ] }, { "cell_type": "markdown", "id": "76d8ccb9-a993-4d82-9dd3-689380e92e55", "metadata": {}, "source": [ "# Model" ] }, { "cell_type": "code", "execution_count": 11, "id": "a0c154f7-d2fa-46a1-8b69-57174bf00632", "metadata": { "tags": [] }, "outputs": [], "source": [ "device = torch.device('cuda'
if torch.cuda.is_available() else 'cpu')" ] }, { "cell_type": "markdown", "id": "2bf6513d-664d-4b94-8b05-7e8df205e3ec", "metadata": {}, "source": [ "Tokenizer (article title -> tokens; this dataset has no abstracts):" ] }, { "cell_type": "code", "execution_count": 12, "id": "12fa49a7-2ac5-4f78-84fe-93305926692e", "metadata": { "tags": [] }, "outputs": [], "source": [ "tokenizer = AutoTokenizer.from_pretrained(\"microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract\")" ] }, { "cell_type": "markdown", "id": "0ea1b4e5-9067-4292-ba12-8f560bbf26fd", "metadata": {}, "source": [ "The model itself: `AutoModelForSequenceClassification` replaces the pretraining head with a newly initialized classification head:" ] }, { "cell_type": "code", "execution_count": 13, "id": "d6eb92bc-c293-47ad-b9cc-2a63e8f1de69", "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']\n", "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g.
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract and are newly initialized: ['classifier.bias', 'classifier.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] } ], "source": [ "model = AutoModelForSequenceClassification.from_pretrained(\"microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract\", num_labels=num_labels).to(device)" ] }, { "cell_type": "code", "execution_count": 14, "id": "f5c79846-e6fc-42c0-bb8d-949678f5e60a", "metadata": { "scrolled": true, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "BertForSequenceClassification(\n", " (bert): BertModel(\n", " (embeddings): BertEmbeddings(\n", " (word_embeddings): Embedding(30522, 768, padding_idx=0)\n", " (position_embeddings): Embedding(512, 768)\n", " (token_type_embeddings): Embedding(2, 768)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (encoder): BertEncoder(\n", " (layer): ModuleList(\n", " (0-11): 12 x BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, 
elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " (intermediate_act_fn): GELUActivation()\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " )\n", " (pooler): BertPooler(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (activation): Tanh()\n", " )\n", " )\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (classifier): Linear(in_features=768, out_features=38, bias=True)\n", ")\n" ] } ], "source": [ "print(model)" ] }, { "cell_type": "markdown", "id": "4ce5d616-c9d6-47e5-afa4-74a95727d2e5", "metadata": {}, "source": [ " " ] }, { "cell_type": "markdown", "id": "5ce6eefc-91ce-4486-9568-b686d04adcc7", "metadata": {}, "source": [ "# Training" ] }, { "cell_type": "markdown", "id": "71add72c-eafb-491a-8820-31ce7336524f", "metadata": {}, "source": [ "## Data Loaders" ] }, { "cell_type": "markdown", "id": "2a0b579c-998a-4d2e-bf0e-d4c7406d22da", "metadata": {}, "source": [ "When working with `transformers`, it is often more convenient to manage the data with the `datasets` library."
] }, { "cell_type": "markdown", "id": "47b0e14a-866b-49ac-8b95-49a91a0bcc22", "metadata": {}, "source": [ "Create Hugging Face [datasets](https://huggingface.co/docs/datasets/tabular_load#pandas-dataframes) from the pandas DataFrames:" ] }, { "cell_type": "code", "execution_count": 84, "id": "dc1a3f33-0ef9-43c9-ab5f-eb9ae304b897", "metadata": { "tags": [] }, "outputs": [], "source": [ "# Random ~90/10 train/test split: each row goes to train with probability 0.9\n", "np.random.seed(42)\n", "is_train = np.random.binomial(1, .9, size=len(df))\n", "train_indices = np.arange(len(df))[is_train.astype(bool)]\n", "test_indices = np.arange(len(df))[(1 - is_train).astype(bool)]" ] }, { "cell_type": "code", "execution_count": 85, "id": "d948f8a6-1a7a-4baa-88a0-418596a1f275", "metadata": { "tags": [] }, "outputs": [], "source": [ "train_df = df.loc[:,[\"text\", \"label\"]].iloc[train_indices]\n", "test_df = df.loc[:,[\"text\", \"label\"]].iloc[test_indices]\n", "\n", "train_ds = Dataset.from_pandas(train_df, split=\"train\")\n", "test_ds = Dataset.from_pandas(test_df, split=\"test\")" ] }, { "cell_type": "code", "execution_count": 86, "id": "50242a35-3067-41e5-8de8-f7e6a4fb6e9c", "metadata": { "tags": [] }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/63085 [00:00 main\n", "\n" ] } ], "source": [ "trainer.push_to_hub()" ] }, { "cell_type": "markdown", "id": "00c6f38c-2efa-45fc-9624-6df2f92b1cbd", "metadata": {}, "source": [ " " ] }, { "cell_type": "markdown", "id": "b1a1029f-543c-409e-9aaf-35bcefe49988", "metadata": {}, "source": [ "# Inference" ] }, { "cell_type": "markdown", "id": "e7b0cd5a-2e17-49f3-b2a9-5ae4e8511969", "metadata": {}, "source": [ "Now let's load the fine-tuned model back from the Hugging Face Hub:" ] }, { "cell_type": "code", "execution_count": 2, "id": "b7fe37b9-61a9-4796-af24-092f6722cd61", "metadata": { "tags": [] }, "outputs": [ { "data": { 
"application/vnd.jupyter.widget-view+json": { "model_id": "a6713aaa55ee41659ce0622caf61342c", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading pytorch_model.bin: 0%| | 0.00/438M [00:00= threshold:\n", " break\n", "\n", " preds = preds[:(i+1)]\n", " \n", " return preds" ] }, { "cell_type": "code", "execution_count": 5, "id": "4ff5fc57-b3a8-409f-a128-5cf8ed75ca01", "metadata": { "tags": [] }, "outputs": [], "source": [ "def format_predictions(preds) -> str:\n", "    \"\"\"\n", "    Prepare predictions and their scores for printing to the user\n", "    \"\"\"\n", "    out = \"\"\n", "    for i, item in enumerate(preds):\n", "        out += f\"{i+1}. {item['label']} (score {item['score']:.2f})\\n\"\n", "    return out" ] }, { "cell_type": "markdown", "id": "824a971a-de90-423b-919e-5d6deff29b27", "metadata": {}, "source": [ "Take this [article](https://www.nature.com/articles/515180a) as an example; we classify its title together with its opening paragraph:" ] }, { "cell_type": "code", "execution_count": 6, "id": "ebb07796-ef9c-41e7-ad6f-7ea236e0c25b", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1. psychiatry (score 0.97)\n", "\n" ] } ], "source": [ "print(\n", " format_predictions(\n", " top_pct(\n", " pipe(\"\"\"\n", "Mental health: A world of depression\n", "Depression is a major human blight. Globally, it is responsible for more ‘years lost’ to disability than any other condition. This is largely because so many people suffer from it — some 350 million, according to the World Health Organization — and the fact that it lasts for many years. (When ranked by disability and death combined, depression comes ninth behind prolific killers such as heart disease, stroke and HIV.) Yet depression is widely undiagnosed and untreated because of stigma, lack of effective therapies and inadequate mental-health resources.
Almost half of the world’s population lives in a country with only two psychiatrists per 100,000 people.\n", "\"\"\"\n", " )[0]\n", " )\n", " )\n", ")" ] }, { "cell_type": "markdown", "id": "459169e0-75e2-4003-8766-8f588fcb0a27", "metadata": {}, "source": [ " " ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.8" } }, "nbformat": 4, "nbformat_minor": 5 }