oracat committed
Commit 651fb4b
1 Parent(s): 6f64518

Upload finetuning-pubmed.ipynb

Files changed (1)
  1. finetuning-pubmed.ipynb +1301 -0
finetuning-pubmed.ipynb ADDED
@@ -0,0 +1,1301 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "1c71aba7-c0f3-4378-9b63-55529e0994b4",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Data\n",
9
+ "\n",
10
+ "We use the following dataset for fine-tuning:\n",
11
+ "\n",
12
+ "- a [dataset](https://zenodo.org/record/7695390) from a [recent study](https://www.biorxiv.org/content/10.1101/2023.04.10.536208v1) with the titles and labels of PubMed articles. \n",
13
+ "\n",
14
+ "It contains 20 million articles, but only the titles are provided (no abstracts; those can additionally be [retrieved](https://www.nlm.nih.gov/databases/download/pubmed_medline.html) by article PMID). Fine-tuning a model on that much data takes a fair amount of time and compute (approximate costs are [given in the paper](https://www.biorxiv.org/content/10.1101/2023.04.10.536208v1)), so below we use a reduced dataset and train on article titles only."
15
+ ]
16
+ },
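+ {
+ "cell_type": "markdown",
+ "id": "fetch-abstract-aside-md",
+ "metadata": {},
+ "source": [
+ "As an aside, if abstracts are needed later: a minimal sketch of fetching one abstract by PMID through the standard NCBI E-utilities `efetch` endpoint (illustrative only, not part of this notebook's pipeline):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "fetch-abstract-aside-code",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import requests\n",
+ "\n",
+ "# Illustrative sketch: fetch one abstract as plain text by PMID\n",
+ "# via NCBI E-utilities (the PMID is taken from the data sample above).\n",
+ "resp = requests.get(\n",
+ "    \"https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi\",\n",
+ "    params={\"db\": \"pubmed\", \"id\": \"1133574\", \"rettype\": \"abstract\", \"retmode\": \"text\"},\n",
+ ")\n",
+ "print(resp.text[:500])"
+ ]
+ },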
17
+ {
18
+ "cell_type": "markdown",
19
+ "id": "e9874f4a-3898-4c89-a0f7-04eeabf2b389",
20
+ "metadata": {
21
+ "tags": []
22
+ },
23
+ "source": [
24
+ "# Models\n",
25
+ "\n",
26
+ "As the base model, we use BERT trained on biomedical data (from PubMed):\n",
27
+ "\n",
28
+ "- [BiomedNLP-PubMedBERT](https://huggingface.co/microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract)"
29
+ ]
30
+ },
31
+ {
32
+ "cell_type": "markdown",
33
+ "id": "991e48e7-897f-45a3-8a0b-539ea67b4eb5",
34
+ "metadata": {},
35
+ "source": [
36
+ "---"
37
+ ]
38
+ },
39
+ {
40
+ "cell_type": "markdown",
41
+ "id": "2f130f05-21ee-46f9-889f-488e8c676aba",
42
+ "metadata": {},
43
+ "source": [
44
+ "# Imports"
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "code",
49
+ "execution_count": 1,
50
+ "id": "757a0582-1b8c-4f1c-b26f-544688e391f4",
51
+ "metadata": {
52
+ "tags": []
53
+ },
54
+ "outputs": [],
55
+ "source": [
56
+ "import torch\n",
57
+ "import transformers\n",
58
+ "import numpy as np\n",
59
+ "import pandas as pd\n",
60
+ "from tqdm import tqdm\n",
61
+ "\n",
63
+ "from datasets import Dataset, ClassLabel\n",
64
+ "from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoModelForSequenceClassification\n",
65
+ "from transformers import TrainingArguments, Trainer\n",
66
+ "from transformers import pipeline\n",
67
+ "import evaluate"
68
+ ]
69
+ },
70
+ {
71
+ "cell_type": "markdown",
72
+ "id": "daa2aa21-de67-44a9-a0ff-1a913e425ccc",
73
+ "metadata": {},
74
+ "source": [
75
+ " "
76
+ ]
77
+ },
78
+ {
79
+ "cell_type": "markdown",
80
+ "id": "03847b87-d096-49a5-b6e2-023fa08b94c2",
81
+ "metadata": {},
82
+ "source": [
83
+ "# Load data"
84
+ ]
85
+ },
86
+ {
87
+ "cell_type": "markdown",
88
+ "id": "b3e902ea-4e0f-4d76-b27b-59e472b2b556",
89
+ "metadata": {},
90
+ "source": [
91
+ "Let's load the data for fine-tuning; specifically, we need the article titles and their labels (there are no abstracts in this data)."
92
+ ]
93
+ },
94
+ {
95
+ "cell_type": "code",
96
+ "execution_count": 2,
97
+ "id": "1be8f69e-bd7d-4ca9-ba9f-044b8e7bc497",
98
+ "metadata": {
99
+ "tags": []
100
+ },
101
+ "outputs": [],
102
+ "source": [
103
+ "df = pd.read_csv(\"pubmed_landscape_data.csv\")"
104
+ ]
105
+ },
106
+ {
107
+ "cell_type": "code",
108
+ "execution_count": 62,
109
+ "id": "ae78e0e8-a600-4607-8c1e-82ecdae17e2d",
110
+ "metadata": {
111
+ "tags": []
112
+ },
113
+ "outputs": [],
114
+ "source": [
115
+ "df = df[df.Labels != \"unlabeled\"]\n",
116
+ "df = df[~df.Title.isnull()]"
117
+ ]
118
+ },
119
+ {
120
+ "cell_type": "code",
121
+ "execution_count": 63,
122
+ "id": "7715556f-8709-40cf-aa8c-3fecbfa3c1f4",
123
+ "metadata": {
124
+ "tags": []
125
+ },
126
+ "outputs": [
127
+ {
128
+ "name": "stdout",
129
+ "output_type": "stream",
130
+ "text": [
131
+ "(7123406, 10)\n"
132
+ ]
133
+ },
134
+ {
135
+ "data": {
136
+ "text/html": [
137
+ "<div>\n",
138
+ "<style scoped>\n",
139
+ " .dataframe tbody tr th:only-of-type {\n",
140
+ " vertical-align: middle;\n",
141
+ " }\n",
142
+ "\n",
143
+ " .dataframe tbody tr th {\n",
144
+ " vertical-align: top;\n",
145
+ " }\n",
146
+ "\n",
147
+ " .dataframe thead th {\n",
148
+ " text-align: right;\n",
149
+ " }\n",
150
+ "</style>\n",
151
+ "<table border=\"1\" class=\"dataframe\">\n",
152
+ " <thead>\n",
153
+ " <tr style=\"text-align: right;\">\n",
154
+ " <th></th>\n",
155
+ " <th>Title</th>\n",
156
+ " <th>Journal</th>\n",
157
+ " <th>PMID</th>\n",
158
+ " <th>Year</th>\n",
159
+ " <th>x</th>\n",
160
+ " <th>y</th>\n",
161
+ " <th>Labels</th>\n",
162
+ " <th>Colors</th>\n",
163
+ " <th>text</th>\n",
164
+ " <th>label</th>\n",
165
+ " </tr>\n",
166
+ " </thead>\n",
167
+ " <tbody>\n",
168
+ " <tr>\n",
169
+ " <th>18</th>\n",
170
+ " <td>Determination of some in vitro growth requirem...</td>\n",
171
+ " <td>Journal of general microbiology</td>\n",
172
+ " <td>1133574</td>\n",
173
+ " <td>1975.0</td>\n",
174
+ " <td>-140.830</td>\n",
175
+ " <td>26.596</td>\n",
176
+ " <td>microbiology</td>\n",
177
+ " <td>#B79762</td>\n",
178
+ " <td>Determination of some in vitro growth requirem...</td>\n",
179
+ " <td>microbiology</td>\n",
180
+ " </tr>\n",
181
+ " <tr>\n",
182
+ " <th>19</th>\n",
183
+ " <td>Degradation of agar by a gram-negative bacterium.</td>\n",
184
+ " <td>Journal of general microbiology</td>\n",
185
+ " <td>1133575</td>\n",
186
+ " <td>1975.0</td>\n",
187
+ " <td>-72.913</td>\n",
188
+ " <td>-4.436</td>\n",
189
+ " <td>microbiology</td>\n",
190
+ " <td>#B79762</td>\n",
191
+ " <td>Degradation of agar by a gram-negative bacterium.</td>\n",
192
+ " <td>microbiology</td>\n",
193
+ " </tr>\n",
194
+ " <tr>\n",
195
+ " <th>20</th>\n",
196
+ " <td>Choroid plexus isografts in rats.</td>\n",
197
+ " <td>Journal of neuropathology and experimental neu...</td>\n",
198
+ " <td>1133586</td>\n",
199
+ " <td>1975.0</td>\n",
200
+ " <td>-46.561</td>\n",
201
+ " <td>96.421</td>\n",
202
+ " <td>neurology</td>\n",
203
+ " <td>#009271</td>\n",
204
+ " <td>Choroid plexus isografts in rats.</td>\n",
205
+ " <td>neurology</td>\n",
206
+ " </tr>\n",
207
+ " <tr>\n",
208
+ " <th>29</th>\n",
209
+ " <td>Preliminary report on a mass screening program...</td>\n",
210
+ " <td>The Journal of pediatrics</td>\n",
211
+ " <td>1133648</td>\n",
212
+ " <td>1975.0</td>\n",
213
+ " <td>45.033</td>\n",
214
+ " <td>39.256</td>\n",
215
+ " <td>pediatric</td>\n",
216
+ " <td>#004D43</td>\n",
217
+ " <td>Preliminary report on a mass screening program...</td>\n",
218
+ " <td>pediatric</td>\n",
219
+ " </tr>\n",
220
+ " <tr>\n",
221
+ " <th>30</th>\n",
222
+ " <td>Hepatic changes in young infants with cystic f...</td>\n",
223
+ " <td>The Journal of pediatrics</td>\n",
224
+ " <td>1133649</td>\n",
225
+ " <td>1975.0</td>\n",
226
+ " <td>118.380</td>\n",
227
+ " <td>61.870</td>\n",
228
+ " <td>pediatric</td>\n",
229
+ " <td>#004D43</td>\n",
230
+ " <td>Hepatic changes in young infants with cystic f...</td>\n",
231
+ " <td>pediatric</td>\n",
232
+ " </tr>\n",
233
+ " </tbody>\n",
234
+ "</table>\n",
235
+ "</div>"
236
+ ],
237
+ "text/plain": [
238
+ " Title \\\n",
239
+ "18 Determination of some in vitro growth requirem... \n",
240
+ "19 Degradation of agar by a gram-negative bacterium. \n",
241
+ "20 Choroid plexus isografts in rats. \n",
242
+ "29 Preliminary report on a mass screening program... \n",
243
+ "30 Hepatic changes in young infants with cystic f... \n",
244
+ "\n",
245
+ " Journal PMID Year \\\n",
246
+ "18 Journal of general microbiology 1133574 1975.0 \n",
247
+ "19 Journal of general microbiology 1133575 1975.0 \n",
248
+ "20 Journal of neuropathology and experimental neu... 1133586 1975.0 \n",
249
+ "29 The Journal of pediatrics 1133648 1975.0 \n",
250
+ "30 The Journal of pediatrics 1133649 1975.0 \n",
251
+ "\n",
252
+ " x y Labels Colors \\\n",
253
+ "18 -140.830 26.596 microbiology #B79762 \n",
254
+ "19 -72.913 -4.436 microbiology #B79762 \n",
255
+ "20 -46.561 96.421 neurology #009271 \n",
256
+ "29 45.033 39.256 pediatric #004D43 \n",
257
+ "30 118.380 61.870 pediatric #004D43 \n",
258
+ "\n",
259
+ " text label \n",
260
+ "18 Determination of some in vitro growth requirem... microbiology \n",
261
+ "19 Degradation of agar by a gram-negative bacterium. microbiology \n",
262
+ "20 Choroid plexus isografts in rats. neurology \n",
263
+ "29 Preliminary report on a mass screening program... pediatric \n",
264
+ "30 Hepatic changes in young infants with cystic f... pediatric "
265
+ ]
266
+ },
267
+ "execution_count": 63,
268
+ "metadata": {},
269
+ "output_type": "execute_result"
270
+ }
271
+ ],
272
+ "source": [
273
+ "print(df.shape)\n",
274
+ "df.head(5)"
275
+ ]
276
+ },
277
+ {
278
+ "cell_type": "markdown",
279
+ "id": "791edb3c-a96d-4042-b35d-c8097bbbef79",
280
+ "metadata": {},
281
+ "source": [
282
+ " "
283
+ ]
284
+ },
285
+ {
286
+ "cell_type": "code",
287
+ "execution_count": 76,
288
+ "id": "81bff36c-0844-49c8-a4e8-162bb1233a45",
289
+ "metadata": {
290
+ "tags": []
291
+ },
292
+ "outputs": [],
293
+ "source": [
294
+ "df.columns = ['text', 'journal', 'pmid', 'year', 'x', 'y', 'label', 'color'] # rename the raw CSV columns; note there is no abstract in this dataset"
295
+ ]
296
+ },
297
+ {
298
+ "cell_type": "code",
299
+ "execution_count": null,
300
+ "id": "c187efce-212b-494b-9157-0e8ceb1a2f3c",
301
+ "metadata": {
302
+ "tags": []
303
+ },
304
+ "outputs": [],
305
+ "source": [
306
+ "# Use subset of the data for faster training\n",
307
+ "df = df.head(1_000_000)"
308
+ ]
309
+ },
310
+ {
311
+ "cell_type": "markdown",
312
+ "id": "68fd806d-ba31-4769-9d57-2762710a6fb7",
313
+ "metadata": {},
314
+ "source": [
315
+ " "
316
+ ]
317
+ },
318
+ {
319
+ "cell_type": "markdown",
320
+ "id": "ce1de806-a4d2-4e58-a3a8-f3542392f22e",
321
+ "metadata": {},
322
+ "source": [
323
+ "## Labels"
324
+ ]
325
+ },
326
+ {
327
+ "cell_type": "markdown",
328
+ "id": "b5183517-8b02-47bc-812a-415b5651e07d",
329
+ "metadata": {},
330
+ "source": [
331
+ "We will use the annotated article labels:"
332
+ ]
333
+ },
334
+ {
335
+ "cell_type": "code",
336
+ "execution_count": 72,
337
+ "id": "ba4e7197-23b6-4cb4-9b44-620c6b730eb7",
338
+ "metadata": {
339
+ "tags": []
340
+ },
341
+ "outputs": [
342
+ {
343
+ "name": "stdout",
344
+ "output_type": "stream",
345
+ "text": [
346
+ "Total: 38 labels such as anesthesiology, biochemistry, ..., virology\n"
347
+ ]
348
+ }
349
+ ],
350
+ "source": [
351
+ "categories = np.unique(df['label'])\n",
352
+ "num_labels = len(categories)\n",
353
+ "print(f\"Total: {num_labels} labels such as {categories[0]}, {categories[1]}, ..., {categories[-1]}\")"
354
+ ]
355
+ },
356
+ {
357
+ "cell_type": "markdown",
358
+ "id": "10b49edd-0929-47e7-bb77-bc71528eb726",
359
+ "metadata": {},
360
+ "source": [
361
+ " "
362
+ ]
363
+ },
364
+ {
365
+ "cell_type": "markdown",
366
+ "id": "76d8ccb9-a993-4d82-9dd3-689380e92e55",
367
+ "metadata": {},
368
+ "source": [
369
+ "# Model"
370
+ ]
371
+ },
372
+ {
373
+ "cell_type": "code",
374
+ "execution_count": 11,
375
+ "id": "a0c154f7-d2fa-46a1-8b69-57174bf00632",
376
+ "metadata": {
377
+ "tags": []
378
+ },
379
+ "outputs": [],
380
+ "source": [
381
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
382
+ ]
383
+ },
384
+ {
385
+ "cell_type": "markdown",
386
+ "id": "2bf6513d-664d-4b94-8b05-7e8df205e3ec",
387
+ "metadata": {},
388
+ "source": [
389
+ "The tokenizer (title -> tokens; recall there are no abstracts in this data):"
390
+ ]
391
+ },
392
+ {
393
+ "cell_type": "code",
394
+ "execution_count": 12,
395
+ "id": "12fa49a7-2ac5-4f78-84fe-93305926692e",
396
+ "metadata": {
397
+ "tags": []
398
+ },
399
+ "outputs": [],
400
+ "source": [
401
+ "tokenizer = AutoTokenizer.from_pretrained(\"microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract\")"
402
+ ]
403
+ },
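+ {
+ "cell_type": "markdown",
+ "id": "tokenizer-sanity-check-md",
+ "metadata": {},
+ "source": [
+ "A quick sanity check (purely illustrative): encoding a single title from the data shows the token ids and subword tokens the model will consume."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "tokenizer-sanity-check-code",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Encode one title and inspect the first few tokens ([CLS] comes first)\n",
+ "enc = tokenizer(\"Degradation of agar by a gram-negative bacterium.\")\n",
+ "print(enc[\"input_ids\"][:10])\n",
+ "print(tokenizer.convert_ids_to_tokens(enc[\"input_ids\"])[:10])"
+ ]
+ },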
404
+ {
405
+ "cell_type": "markdown",
406
+ "id": "0ea1b4e5-9067-4292-ba12-8f560bbf26fd",
407
+ "metadata": {},
408
+ "source": [
409
+ "The model itself, in which `AutoModelForSequenceClassification` replaces the pretraining head with a classification head:"
410
+ ]
411
+ },
412
+ {
413
+ "cell_type": "code",
414
+ "execution_count": 13,
415
+ "id": "d6eb92bc-c293-47ad-b9cc-2a63e8f1de69",
416
+ "metadata": {
417
+ "tags": []
418
+ },
419
+ "outputs": [
420
+ {
421
+ "name": "stderr",
422
+ "output_type": "stream",
423
+ "text": [
424
+ "Some weights of the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']\n",
425
+ "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
426
+ "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
427
+ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
428
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
429
+ ]
430
+ }
431
+ ],
432
+ "source": [
433
+ "model = AutoModelForSequenceClassification.from_pretrained(\"microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract\", num_labels=num_labels).to(device)"
434
+ ]
435
+ },
436
+ {
437
+ "cell_type": "code",
438
+ "execution_count": 14,
439
+ "id": "f5c79846-e6fc-42c0-bb8d-949678f5e60a",
440
+ "metadata": {
441
+ "scrolled": true,
442
+ "tags": []
443
+ },
444
+ "outputs": [
445
+ {
446
+ "name": "stdout",
447
+ "output_type": "stream",
448
+ "text": [
449
+ "BertForSequenceClassification(\n",
450
+ " (bert): BertModel(\n",
451
+ " (embeddings): BertEmbeddings(\n",
452
+ " (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
453
+ " (position_embeddings): Embedding(512, 768)\n",
454
+ " (token_type_embeddings): Embedding(2, 768)\n",
455
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
456
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
457
+ " )\n",
458
+ " (encoder): BertEncoder(\n",
459
+ " (layer): ModuleList(\n",
460
+ " (0-11): 12 x BertLayer(\n",
461
+ " (attention): BertAttention(\n",
462
+ " (self): BertSelfAttention(\n",
463
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
464
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
465
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
466
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
467
+ " )\n",
468
+ " (output): BertSelfOutput(\n",
469
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
470
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
471
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
472
+ " )\n",
473
+ " )\n",
474
+ " (intermediate): BertIntermediate(\n",
475
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
476
+ " (intermediate_act_fn): GELUActivation()\n",
477
+ " )\n",
478
+ " (output): BertOutput(\n",
479
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
480
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
481
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
482
+ " )\n",
483
+ " )\n",
484
+ " )\n",
485
+ " )\n",
486
+ " (pooler): BertPooler(\n",
487
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
488
+ " (activation): Tanh()\n",
489
+ " )\n",
490
+ " )\n",
491
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
492
+ " (classifier): Linear(in_features=768, out_features=38, bias=True)\n",
493
+ ")\n"
494
+ ]
495
+ }
496
+ ],
497
+ "source": [
498
+ "print(model)"
499
+ ]
500
+ },
501
+ {
502
+ "cell_type": "markdown",
503
+ "id": "4ce5d616-c9d6-47e5-afa4-74a95727d2e5",
504
+ "metadata": {},
505
+ "source": [
506
+ " "
507
+ ]
508
+ },
509
+ {
510
+ "cell_type": "markdown",
511
+ "id": "5ce6eefc-91ce-4486-9568-b686d04adcc7",
512
+ "metadata": {},
513
+ "source": [
514
+ "# Training"
515
+ ]
516
+ },
517
+ {
518
+ "cell_type": "markdown",
519
+ "id": "71add72c-eafb-491a-8820-31ce7336524f",
520
+ "metadata": {},
521
+ "source": [
522
+ "## Data Loaders"
523
+ ]
524
+ },
525
+ {
526
+ "cell_type": "markdown",
527
+ "id": "2a0b579c-998a-4d2e-bf0e-d4c7406d22da",
528
+ "metadata": {},
529
+ "source": [
530
+ "When working with `transformers`, it may be more convenient to handle the data with the `datasets` library."
531
+ ]
532
+ },
533
+ {
534
+ "cell_type": "markdown",
535
+ "id": "47b0e14a-866b-49ac-8b95-49a91a0bcc22",
536
+ "metadata": {},
537
+ "source": [
538
+ "Let's create a Hugging Face [dataset](https://huggingface.co/docs/datasets/tabular_load#pandas-dataframes):"
539
+ ]
540
+ },
541
+ {
542
+ "cell_type": "code",
543
+ "execution_count": 84,
544
+ "id": "dc1a3f33-0ef9-43c9-ab5f-eb9ae304b897",
545
+ "metadata": {
546
+ "tags": []
547
+ },
548
+ "outputs": [],
549
+ "source": [
550
+ "np.random.seed(42)\n",
551
+ "is_train = np.random.binomial(1, .9, size=len(df))\n",
552
+ "train_indices = np.arange(len(df))[is_train.astype(bool)]\n",
553
+ "test_indices = np.arange(len(df))[(1 - is_train).astype(bool)]"
554
+ ]
555
+ },
556
+ {
557
+ "cell_type": "code",
558
+ "execution_count": 85,
559
+ "id": "d948f8a6-1a7a-4baa-88a0-418596a1f275",
560
+ "metadata": {
561
+ "tags": []
562
+ },
563
+ "outputs": [],
564
+ "source": [
565
+ "train_df = df.loc[:,[\"text\", \"label\"]].iloc[train_indices]\n",
566
+ "test_df = df.loc[:,[\"text\", \"label\"]].iloc[test_indices]\n",
567
+ "\n",
568
+ "train_ds = Dataset.from_pandas(train_df, split=\"train\")\n",
569
+ "test_ds = Dataset.from_pandas(test_df, split=\"test\")"
570
+ ]
571
+ },
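+ {
+ "cell_type": "markdown",
+ "id": "alt-split-sketch-md",
+ "metadata": {},
+ "source": [
+ "For reference, a sketch of an equivalent split using the built-in `Dataset.train_test_split` (the manual index-based split above is what this notebook actually uses):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "alt-split-sketch-code",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Alternative (left commented out): let `datasets` do the 90/10 split itself\n",
+ "# full_ds = Dataset.from_pandas(df.loc[:, [\"text\", \"label\"]])\n",
+ "# split = full_ds.train_test_split(test_size=0.1, seed=42)\n",
+ "# train_ds, test_ds = split[\"train\"], split[\"test\"]"
+ ]
+ },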
572
+ {
573
+ "cell_type": "code",
574
+ "execution_count": 86,
575
+ "id": "50242a35-3067-41e5-8de8-f7e6a4fb6e9c",
576
+ "metadata": {
577
+ "tags": []
578
+ },
579
+ "outputs": [
580
+ {
581
+ "data": {
582
+ "application/vnd.jupyter.widget-view+json": {
583
+ "model_id": "",
584
+ "version_major": 2,
585
+ "version_minor": 0
586
+ },
587
+ "text/plain": [
588
+ "Map: 0%| | 0/63085 [00:00<?, ? examples/s]"
589
+ ]
590
+ },
591
+ "metadata": {},
592
+ "output_type": "display_data"
593
+ },
594
+ {
595
+ "data": {
596
+ "application/vnd.jupyter.widget-view+json": {
597
+ "model_id": "",
598
+ "version_major": 2,
599
+ "version_minor": 0
600
+ },
601
+ "text/plain": [
602
+ "Map: 0%| | 0/6915 [00:00<?, ? examples/s]"
603
+ ]
604
+ },
605
+ "metadata": {},
606
+ "output_type": "display_data"
607
+ }
608
+ ],
609
+ "source": [
610
+ "def tokenize_text(row):\n",
611
+ " return tokenizer(\n",
612
+ " row[\"text\"],\n",
613
+ " max_length=512,\n",
614
+ " truncation=True,\n",
615
+ " padding='max_length',\n",
616
+ " )\n",
617
+ "\n",
618
+ "train_ds = train_ds.map(tokenize_text, batched=True)\n",
619
+ "test_ds = test_ds.map(tokenize_text, batched=True)"
620
+ ]
621
+ },
622
+ {
623
+ "cell_type": "markdown",
624
+ "id": "08a306da-1f66-4b84-8b8e-7152c8928b0f",
625
+ "metadata": {},
626
+ "source": [
627
+ "(Even this step alone can take about an hour on this amount of data...)"
628
+ ]
629
+ },
630
+ {
631
+ "cell_type": "markdown",
632
+ "id": "9febd884-681c-42e1-af1c-8192b448358b",
633
+ "metadata": {},
634
+ "source": [
635
+ " "
636
+ ]
637
+ },
638
+ {
639
+ "cell_type": "code",
640
+ "execution_count": 87,
641
+ "id": "35d454d1-fbdc-4847-8b60-4c6c442364b1",
642
+ "metadata": {
643
+ "tags": []
644
+ },
645
+ "outputs": [
646
+ {
647
+ "data": {
648
+ "application/vnd.jupyter.widget-view+json": {
649
+ "model_id": "",
650
+ "version_major": 2,
651
+ "version_minor": 0
652
+ },
653
+ "text/plain": [
654
+ "Map: 0%| | 0/63085 [00:00<?, ? examples/s]"
655
+ ]
656
+ },
657
+ "metadata": {},
658
+ "output_type": "display_data"
659
+ },
660
+ {
661
+ "data": {
662
+ "application/vnd.jupyter.widget-view+json": {
663
+ "model_id": "",
664
+ "version_major": 2,
665
+ "version_minor": 0
666
+ },
667
+ "text/plain": [
668
+ "Map: 0%| | 0/6915 [00:00<?, ? examples/s]"
669
+ ]
670
+ },
671
+ "metadata": {},
672
+ "output_type": "display_data"
673
+ },
674
+ {
675
+ "data": {
676
+ "application/vnd.jupyter.widget-view+json": {
677
+ "model_id": "",
678
+ "version_major": 2,
679
+ "version_minor": 0
680
+ },
681
+ "text/plain": [
682
+ "Casting the dataset: 0%| | 0/63085 [00:00<?, ? examples/s]"
683
+ ]
684
+ },
685
+ "metadata": {},
686
+ "output_type": "display_data"
687
+ },
688
+ {
689
+ "data": {
690
+ "application/vnd.jupyter.widget-view+json": {
691
+ "model_id": "",
692
+ "version_major": 2,
693
+ "version_minor": 0
694
+ },
695
+ "text/plain": [
696
+ "Casting the dataset: 0%| | 0/6915 [00:00<?, ? examples/s]"
697
+ ]
698
+ },
699
+ "metadata": {},
700
+ "output_type": "display_data"
701
+ }
702
+ ],
703
+ "source": [
704
+ "labels_map = ClassLabel(num_classes=num_labels, names=list(categories))\n",
705
+ "\n",
706
+ "def transform_labels(row):\n",
707
+ " # default name for a label (label or label_ids)\n",
708
+ " return {\"label\": labels_map.str2int(row[\"label\"])}\n",
709
+ "\n",
710
+ "# OR: \n",
711
+ "# \n",
712
+ "# labels_map = pd.Series(\n",
713
+ "# np.arange(num_labels),\n",
714
+ "# index=categories,\n",
715
+ "# )\n",
716
+ "# \n",
717
+ "# def transform_labels(row):\n",
718
+ "# return {\"label\": labels_map[row[\"label\"]]}\n",
719
+ "\n",
720
+ "train_ds = train_ds.map(transform_labels, batched=True)\n",
721
+ "test_ds = test_ds.map(transform_labels, batched=True)\n",
722
+ "\n",
723
+ "train_ds = train_ds.cast_column('label', labels_map)\n",
724
+ "test_ds = test_ds.cast_column('label', labels_map)"
725
+ ]
726
+ },
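+ {
+ "cell_type": "markdown",
+ "id": "classlabel-roundtrip-md",
+ "metadata": {},
+ "source": [
+ "To illustrate the mapping: `ClassLabel` converts between string labels and integer ids in both directions."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "classlabel-roundtrip-code",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Round-trip through the ClassLabel mapping (illustration)\n",
+ "idx = labels_map.str2int(\"microbiology\")\n",
+ "print(idx, labels_map.int2str(idx))"
+ ]
+ },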
727
+ {
728
+ "cell_type": "markdown",
729
+ "id": "1db1cc91-b744-4415-a4b1-b383102b792b",
730
+ "metadata": {},
731
+ "source": [
732
+ " "
733
+ ]
734
+ },
735
+ {
736
+ "cell_type": "markdown",
737
+ "id": "811c5fe3-218e-4187-878d-65abc157f802",
738
+ "metadata": {},
739
+ "source": [
740
+ "## Prepare training"
741
+ ]
742
+ },
743
+ {
744
+ "cell_type": "code",
745
+ "execution_count": 88,
746
+ "id": "d2160c7d-4130-47ae-9d6d-6684e4ba7e9b",
747
+ "metadata": {
748
+ "tags": []
749
+ },
750
+ "outputs": [
751
+ {
752
+ "name": "stderr",
753
+ "output_type": "stream",
754
+ "text": [
755
+ "Some weights of the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']\n",
756
+ "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
757
+ "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
758
+ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
759
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
760
+ ]
761
+ }
762
+ ],
763
+ "source": [
764
+ "model = AutoModelForSequenceClassification.from_pretrained(\n",
765
+ " \"microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract\", \n",
766
+ " num_labels=num_labels,\n",
767
+ " id2label={i:labels_map.names[i] for i in range(len(categories))},\n",
768
+ " label2id={labels_map.names[i]:i for i in range(len(categories))},\n",
769
+ ").to(device)"
770
+ ]
771
+ },
772
+ {
773
+ "cell_type": "code",
774
+ "execution_count": 89,
775
+ "id": "72e74c2b-89d7-4c17-8df1-dcfd40ead01e",
776
+ "metadata": {
777
+ "tags": []
778
+ },
779
+ "outputs": [],
780
+ "source": [
781
+ "tokenizer = AutoTokenizer.from_pretrained(\"microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract\")"
782
+ ]
783
+ },
784
+ {
785
+ "cell_type": "markdown",
786
+ "id": "ebb91037-fbdf-4453-87de-6da5eec3304f",
787
+ "metadata": {},
788
+ "source": [
789
+ "We will compute accuracy:"
790
+ ]
791
+ },
792
+ {
793
+ "cell_type": "code",
794
+ "execution_count": 90,
795
+ "id": "630f6fa5-4c53-4962-b36d-5ee9aad6e29d",
796
+ "metadata": {
797
+ "tags": []
798
+ },
799
+ "outputs": [],
800
+ "source": [
801
+ "metric = evaluate.load(\"accuracy\")\n",
802
+ "\n",
803
+ "def compute_metrics(eval_pred):\n",
804
+ " logits, labels = eval_pred\n",
805
+ " predictions = np.argmax(logits, axis=-1)\n",
806
+ " return metric.compute(predictions=predictions, references=labels)"
807
+ ]
808
+ },
809
+ {
810
+ "cell_type": "markdown",
811
+ "id": "d448fd8d-4ce4-4a54-9931-037732ffc0a7",
812
+ "metadata": {},
813
+ "source": [
814
+ "Training parameters:"
815
+ ]
816
+ },
817
+ {
818
+ "cell_type": "code",
819
+ "execution_count": 92,
820
+ "id": "f64425b7-72b7-466a-8e3e-cd7624893139",
821
+ "metadata": {
822
+ "tags": []
823
+ },
824
+ "outputs": [],
825
+ "source": [
826
+ "training_args = TrainingArguments(\n",
827
+ " output_dir=\"bert-paper-classifier\", \n",
828
+ " evaluation_strategy=\"epoch\",\n",
829
+ " per_device_train_batch_size=64,\n",
830
+ " num_train_epochs=3,\n",
831
+ " logging_steps=100,\n",
832
+ ")"
833
+ ]
834
+ },
835
+ {
836
+ "cell_type": "markdown",
837
+ "id": "569b885e-ee94-4f4a-bc5a-5ff5df7d5aea",
838
+ "metadata": {},
839
+ "source": [
840
+ "## Training"
841
+ ]
842
+ },
843
+ {
844
+ "cell_type": "code",
845
+ "execution_count": 93,
846
+ "id": "b850cd9b-eb36-40ec-8cf2-26206fedcf27",
847
+ "metadata": {
848
+ "tags": []
849
+ },
850
+ "outputs": [],
851
+ "source": [
852
+ "trainer = Trainer(\n",
853
+ " model=model,\n",
854
+ " args=training_args,\n",
855
+ " train_dataset=train_ds,\n",
856
+ " eval_dataset=test_ds,\n",
857
+ " compute_metrics=compute_metrics,\n",
858
+ ")"
859
+ ]
860
+ },
861
+ {
862
+ "cell_type": "code",
863
+ "execution_count": null,
864
+ "id": "a50e9e26-922c-436f-a55b-eb9084a33b01",
865
+ "metadata": {},
866
+ "outputs": [],
867
+ "source": [
868
+ "trainer.train()\n",
869
+ "# Alternatively, convert the notebook to a Python script and run training from the shell:\n",
870
+ "#! jupyter nbconvert finetuning-pubmed.ipynb --to python"
871
+ ]
872
+ },
873
+ {
874
+ "cell_type": "markdown",
875
+ "id": "fa597101-e9a9-4e08-b2f9-eea7818a0eca",
876
+ "metadata": {},
877
+ "source": [
878
+ " "
879
+ ]
880
+ },
881
+ {
882
+ "cell_type": "markdown",
883
+ "id": "cc8dad7d-8105-4f37-9087-615314c35afb",
884
+ "metadata": {},
885
+ "source": [
886
+ "# Save and share"
887
+ ]
888
+ },
889
+ {
890
+ "cell_type": "code",
891
+ "execution_count": 96,
892
+ "id": "38d24722-d5c6-40ac-b568-3cd7fd9f225e",
893
+ "metadata": {
894
+ "tags": []
895
+ },
896
+ "outputs": [],
897
+ "source": [
898
+ "trainer.args.hub_model_id = \"bert-paper-classifier\""
899
+ ]
900
+ },
901
+ {
902
+ "cell_type": "code",
903
+ "execution_count": 145,
904
+ "id": "9530790c-bc63-48f4-9a01-8c534fa90e00",
905
+ "metadata": {
906
+ "tags": []
907
+ },
908
+ "outputs": [
909
+ {
910
+ "data": {
911
+ "text/plain": [
912
+ "('bert-paper-classifier/tokenizer_config.json',\n",
913
+ " 'bert-paper-classifier/special_tokens_map.json',\n",
914
+ " 'bert-paper-classifier/vocab.txt',\n",
915
+ " 'bert-paper-classifier/added_tokens.json',\n",
916
+ " 'bert-paper-classifier/tokenizer.json')"
917
+ ]
918
+ },
919
+ "execution_count": 145,
920
+ "metadata": {},
921
+ "output_type": "execute_result"
922
+ }
923
+ ],
924
+ "source": [
925
+ "tokenizer.save_pretrained(\"bert-paper-classifier\")"
926
+ ]
927
+ },
928
+ {
929
+ "cell_type": "code",
930
+ "execution_count": 146,
931
+ "id": "0498df97-cd2c-4732-9d07-ee2013f8bd55",
932
+ "metadata": {
933
+ "tags": []
934
+ },
935
+ "outputs": [],
936
+ "source": [
937
+ "trainer.save_model(\"bert-paper-classifier\")"
938
+ ]
939
+ },
940
+ {
941
+ "cell_type": "markdown",
942
+ "id": "7af12b9e-0d77-48ec-af6f-38556e13b067",
943
+ "metadata": {
944
+ "tags": []
945
+ },
946
+ "source": [
947
+ "Let's push the model to the HF Hub:"
948
+ ]
949
+ },
950
+ {
951
+ "cell_type": "code",
952
+ "execution_count": 148,
953
+ "id": "5de0e91f-bc23-4413-b22e-5aa32b09ef12",
954
+ "metadata": {
955
+ "scrolled": true,
956
+ "tags": []
957
+ },
958
+ "outputs": [
959
+ {
960
+ "name": "stdout",
961
+ "output_type": "stream",
962
+ "text": [
963
+ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
+ "To disable this warning, you can either:\n",
+ "\t- Avoid using `tokenizers` before the fork if possible\n",
+ "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
987
+ ]
988
+ },
989
+ {
990
+ "name": "stderr",
991
+ "output_type": "stream",
992
+ "text": [
993
+ "/g/stegle/bredikhi/projects/ml2/transformers/bert-paper-classifier is already a clone of https://huggingface.co/oracat/bert-paper-classifier. Make sure you pull the latest changes with `repo.git_pull()`.\n"
994
+ ]
995
+ },
996
+ {
997
+ "name": "stdout",
998
+ "output_type": "stream",
999
+ "text": [
1000
+ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
+ "To disable this warning, you can either:\n",
+ "\t- Avoid using `tokenizers` before the fork if possible\n",
+ "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
1092
+ ]
1093
+ },
1094
+ {
1095
+ "name": "stderr",
1096
+ "output_type": "stream",
1097
+ "text": [
1098
+ "To https://huggingface.co/oracat/bert-paper-classifier\n",
1099
+ " 862abb7..b95fd36 main -> main\n",
1100
+ "\n"
1101
+ ]
1102
+ },
1103
+ {
1104
+ "ename": "KeyboardInterrupt",
1105
+ "evalue": "",
1106
+ "output_type": "error",
1107
+ "traceback": [
1108
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
1109
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
1110
+ "Cell \u001b[0;32mIn[148], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpush_to_hub\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
1111
+ "File \u001b[0;32m~/homedir/conda/envs/torch/lib/python3.9/site-packages/transformers/trainer.py:3661\u001b[0m, in \u001b[0;36mTrainer.push_to_hub\u001b[0;34m(self, commit_message, blocking, **kwargs)\u001b[0m\n\u001b[1;32m 3658\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpush_in_progress\u001b[38;5;241m.\u001b[39m_process\u001b[38;5;241m.\u001b[39mkill()\n\u001b[1;32m 3659\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpush_in_progress \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m-> 3661\u001b[0m git_head_commit_url \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrepo\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpush_to_hub\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 3662\u001b[0m \u001b[43m \u001b[49m\u001b[43mcommit_message\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcommit_message\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mblocking\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mblocking\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mauto_lfs_prune\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\n\u001b[1;32m 3663\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3664\u001b[0m \u001b[38;5;66;03m# push separately the model card to be independant from the rest of the model\u001b[39;00m\n\u001b[1;32m 3665\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mshould_save:\n",
1112
+ "File \u001b[0;32m~/homedir/conda/envs/torch/lib/python3.9/site-packages/huggingface_hub/repository.py:1307\u001b[0m, in \u001b[0;36mRepository.push_to_hub\u001b[0;34m(self, commit_message, blocking, clean_ok, auto_lfs_prune)\u001b[0m\n\u001b[1;32m 1305\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgit_add(auto_lfs_track\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 1306\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgit_commit(commit_message)\n\u001b[0;32m-> 1307\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgit_push\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1308\u001b[0m \u001b[43m \u001b[49m\u001b[43mupstream\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43mf\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43morigin \u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcurrent_branch\u001b[49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1309\u001b[0m \u001b[43m \u001b[49m\u001b[43mblocking\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mblocking\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1310\u001b[0m \u001b[43m \u001b[49m\u001b[43mauto_lfs_prune\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mauto_lfs_prune\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1311\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
1113
+ "File \u001b[0;32m~/homedir/conda/envs/torch/lib/python3.9/site-packages/huggingface_hub/repository.py:1099\u001b[0m, in \u001b[0;36mRepository.git_push\u001b[0;34m(self, upstream, blocking, auto_lfs_prune)\u001b[0m\n\u001b[1;32m 1096\u001b[0m logger\u001b[38;5;241m.\u001b[39mwarning(stderr)\n\u001b[1;32m 1098\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m return_code:\n\u001b[0;32m-> 1099\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m subprocess\u001b[38;5;241m.\u001b[39mCalledProcessError(return_code, process\u001b[38;5;241m.\u001b[39margs, output\u001b[38;5;241m=\u001b[39mstdout, stderr\u001b[38;5;241m=\u001b[39mstderr)\n\u001b[1;32m 1101\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m subprocess\u001b[38;5;241m.\u001b[39mCalledProcessError \u001b[38;5;28;01mas\u001b[39;00m exc:\n\u001b[1;32m 1102\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mEnvironmentError\u001b[39;00m(exc\u001b[38;5;241m.\u001b[39mstderr)\n",
1114
+ "File \u001b[0;32m~/homedir/conda/envs/torch/lib/python3.9/contextlib.py:126\u001b[0m, in \u001b[0;36m_GeneratorContextManager.__exit__\u001b[0;34m(self, typ, value, traceback)\u001b[0m\n\u001b[1;32m 124\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m typ \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 125\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 126\u001b[0m \u001b[38;5;28;43mnext\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgen\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 127\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mStopIteration\u001b[39;00m:\n\u001b[1;32m 128\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
1115
+ "File \u001b[0;32m~/homedir/conda/envs/torch/lib/python3.9/site-packages/huggingface_hub/repository.py:420\u001b[0m, in \u001b[0;36m_lfs_log_progress\u001b[0;34m()\u001b[0m\n\u001b[1;32m 418\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 419\u001b[0m exit_event\u001b[38;5;241m.\u001b[39mset()\n\u001b[0;32m--> 420\u001b[0m \u001b[43mx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mjoin\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 422\u001b[0m os\u001b[38;5;241m.\u001b[39menviron[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mGIT_LFS_PROGRESS\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m current_lfs_progress_value\n",
1116
+ "File \u001b[0;32m~/homedir/conda/envs/torch/lib/python3.9/threading.py:1060\u001b[0m, in \u001b[0;36mThread.join\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 1057\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcannot join current thread\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 1059\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m timeout \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m-> 1060\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_wait_for_tstate_lock\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1061\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1062\u001b[0m \u001b[38;5;66;03m# the behavior of a negative timeout isn't documented, but\u001b[39;00m\n\u001b[1;32m 1063\u001b[0m \u001b[38;5;66;03m# historically .join(timeout=x) for x<0 has acted as if timeout=0\u001b[39;00m\n\u001b[1;32m 1064\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_wait_for_tstate_lock(timeout\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mmax\u001b[39m(timeout, \u001b[38;5;241m0\u001b[39m))\n",
1117
+ "File \u001b[0;32m~/homedir/conda/envs/torch/lib/python3.9/threading.py:1080\u001b[0m, in \u001b[0;36mThread._wait_for_tstate_lock\u001b[0;34m(self, block, timeout)\u001b[0m\n\u001b[1;32m 1077\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[1;32m 1079\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 1080\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[43mlock\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43macquire\u001b[49m\u001b[43m(\u001b[49m\u001b[43mblock\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m:\n\u001b[1;32m 1081\u001b[0m lock\u001b[38;5;241m.\u001b[39mrelease()\n\u001b[1;32m 1082\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_stop()\n",
1118
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
1119
+ ]
1120
+ }
1121
+ ],
1122
+ "source": [
1123
+ "trainer.push_to_hub()"
1124
+ ]
1125
+ },
1126
+ {
1127
+ "cell_type": "markdown",
1128
+ "id": "00c6f38c-2efa-45fc-9624-6df2f92b1cbd",
1129
+ "metadata": {},
1130
+ "source": [
1131
+ " "
1132
+ ]
1133
+ },
1134
+ {
1135
+ "cell_type": "markdown",
1136
+ "id": "b1a1029f-543c-409e-9aaf-35bcefe49988",
1137
+ "metadata": {},
1138
+ "source": [
1139
+ "# Inference"
1140
+ ]
1141
+ },
1142
+ {
1143
+ "cell_type": "markdown",
1144
+ "id": "e7b0cd5a-2e17-49f3-b2a9-5ae4e8511969",
1145
+ "metadata": {},
1146
+ "source": [
1147
+ "Now let's try loading the model from the HF Hub:"
1148
+ ]
1149
+ },
1150
+ {
1151
+ "cell_type": "code",
1152
+ "execution_count": 2,
1153
+ "id": "b7fe37b9-61a9-4796-af24-092f6722cd61",
1154
+ "metadata": {
1155
+ "tags": []
1156
+ },
1157
+ "outputs": [
1158
+ {
1159
+ "data": {
1160
+ "application/vnd.jupyter.widget-view+json": {
1161
+ "model_id": "a6713aaa55ee41659ce0622caf61342c",
1162
+ "version_major": 2,
1163
+ "version_minor": 0
1164
+ },
1165
+ "text/plain": [
1166
+ "Downloading pytorch_model.bin: 0%| | 0.00/438M [00:00<?, ?B/s]"
1167
+ ]
1168
+ },
1169
+ "metadata": {},
1170
+ "output_type": "display_data"
1171
+ }
1172
+ ],
1173
+ "source": [
1174
+ "inference_tokenizer = AutoTokenizer.from_pretrained(\"oracat/bert-paper-classifier\")\n",
1175
+ "inference_model = AutoModelForSequenceClassification.from_pretrained(\"oracat/bert-paper-classifier\")"
1176
+ ]
1177
+ },
1178
+ {
1179
+ "cell_type": "code",
1180
+ "execution_count": 3,
1181
+ "id": "34495235-4dca-4635-b468-5b15647a6682",
1182
+ "metadata": {
1183
+ "tags": []
1184
+ },
1185
+ "outputs": [],
1186
+ "source": [
1187
+ "pipe = pipeline(\"text-classification\", model=inference_model, tokenizer=inference_tokenizer, top_k=None)"
1188
+ ]
1189
+ },
1190
+ {
1191
+ "cell_type": "code",
1192
+ "execution_count": 4,
1193
+ "id": "9f8ed2de-6354-4a5e-98c6-7a9b5ebb1276",
1194
+ "metadata": {
1195
+ "tags": []
1196
+ },
1197
+ "outputs": [],
1198
+ "source": [
1199
+ "def top_pct(preds, threshold=.95):\n",
+ "    \"\"\"Keep the highest-scoring predictions until their cumulative score reaches `threshold`.\"\"\"\n",
1200
+ " preds = sorted(preds, key=lambda x: -x[\"score\"])\n",
1201
+ " \n",
1202
+ " cum_score = 0\n",
1203
+ " for i, item in enumerate(preds):\n",
1204
+ " cum_score += item[\"score\"]\n",
1205
+ " if cum_score >= threshold:\n",
1206
+ " break\n",
1207
+ "\n",
1208
+ " preds = preds[:(i+1)]\n",
1209
+ " \n",
1210
+ " return preds"
1211
+ ]
1212
+ },
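+ {
+ "cell_type": "markdown",
+ "id": "top-pct-toy-check-md",
+ "metadata": {},
+ "source": [
+ "A toy check of `top_pct` on hand-made scores (hypothetical numbers, for illustration): with the default threshold of 0.95 it keeps predictions until their cumulative score reaches 95%."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "top-pct-toy-check-code",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "toy = [\n",
+ "    {\"label\": \"a\", \"score\": 0.80},\n",
+ "    {\"label\": \"b\", \"score\": 0.18},\n",
+ "    {\"label\": \"c\", \"score\": 0.02},\n",
+ "]\n",
+ "print(top_pct(toy))  # keeps 'a' and 'b': 0.80 + 0.18 >= 0.95"
+ ]
+ },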
1213
+ {
1214
+ "cell_type": "code",
1215
+ "execution_count": 5,
1216
+ "id": "4ff5fc57-b3a8-409f-a128-5cf8ed75ca01",
1217
+ "metadata": {
1218
+ "tags": []
1219
+ },
1220
+ "outputs": [],
1221
+ "source": [
1222
+ "def format_predictions(preds) -> str:\n",
1223
+ " \"\"\"\n",
1224
+ " Prepare predictions and their scores for printing to the user\n",
1225
+ " \"\"\"\n",
1226
+ " out = \"\"\n",
1227
+ " for i, item in enumerate(preds):\n",
1228
+ " out += f\"{i+1}. {item['label']} (score {item['score']:.2f})\\n\"\n",
1229
+ " return out"
1230
+ ]
1231
+ },
1232
+ {
1233
+ "cell_type": "markdown",
1234
+ "id": "824a971a-de90-423b-919e-5d6deff29b27",
1235
+ "metadata": {},
1236
+ "source": [
1237
+ "As an example, let's take this [article](https://www.nature.com/articles/515180a):"
1238
+ ]
1239
+ },
1240
+ {
1241
+ "cell_type": "code",
1242
+ "execution_count": 6,
1243
+ "id": "ebb07796-ef9c-41e7-ad6f-7ea236e0c25b",
1244
+ "metadata": {
1245
+ "tags": []
1246
+ },
1247
+ "outputs": [
1248
+ {
1249
+ "name": "stdout",
1250
+ "output_type": "stream",
1251
+ "text": [
1252
+ "1. psychiatry (score 0.97)\n",
1253
+ "\n"
1254
+ ]
1255
+ }
1256
+ ],
1257
+ "source": [
1258
+ "print(\n",
1259
+ " format_predictions(\n",
1260
+ " top_pct(\n",
1261
+ " pipe(\"\"\"\n",
1262
+ "Mental health: A world of depression\n",
1263
+ "Depression is a major human blight. Globally, it is responsible for more ‘years lost’ to disability than any other condition. This is largely because so many people suffer from it — some 350 million, according to the World Health Organization — and the fact that it lasts for many years. (When ranked by disability and death combined, depression comes ninth behind prolific killers such as heart disease, stroke and HIV.) Yet depression is widely undiagnosed and untreated because of stigma, lack of effective therapies and inadequate mental-health resources. Almost half of the world’s population lives in a country with only two psychiatrists per 100,000 people.\n",
1264
+ "\"\"\"\n",
1265
+ " )[0]\n",
1266
+ " )\n",
1267
+ " )\n",
1268
+ ")"
1269
+ ]
1270
+ },
1271
+ {
1272
+ "cell_type": "markdown",
1273
+ "id": "459169e0-75e2-4003-8766-8f588fcb0a27",
1274
+ "metadata": {},
1275
+ "source": [
1276
+ " "
1277
+ ]
1278
+ }
1279
+ ],
1280
+ "metadata": {
1281
+ "kernelspec": {
1282
+ "display_name": "Python 3 (ipykernel)",
1283
+ "language": "python",
1284
+ "name": "python3"
1285
+ },
1286
+ "language_info": {
1287
+ "codemirror_mode": {
1288
+ "name": "ipython",
1289
+ "version": 3
1290
+ },
1291
+ "file_extension": ".py",
1292
+ "mimetype": "text/x-python",
1293
+ "name": "python",
1294
+ "nbconvert_exporter": "python",
1295
+ "pygments_lexer": "ipython3",
1296
+ "version": "3.10.8"
1297
+ }
1298
+ },
1299
+ "nbformat": 4,
1300
+ "nbformat_minor": 5
1301
+ }