surdan committed on
Commit 162232a
1 Parent(s): 72d05b7

Upload Train_model.ipynb

Files changed (1)
  1. Train_model.ipynb +361 -0
Train_model.ipynb ADDED
@@ -0,0 +1,361 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "3ca08817",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# !pip install seqeval"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c5958200",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# import torch\n",
+ "# torch.cuda.is_available(), torch.cuda.device_count()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "590c3f48",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import warnings\n",
+ "warnings.filterwarnings('ignore')\n",
+ "\n",
+ "import pickle\n",
+ "import numpy as np\n",
+ "import transformers\n",
+ "from transformers import Trainer\n",
+ "from datasets import load_metric\n",
+ "from datasets import load_dataset\n",
+ "from transformers import AutoTokenizer\n",
+ "from transformers import TrainingArguments\n",
+ "from transformers import AutoModelForTokenClassification\n",
+ "from transformers import DataCollatorForTokenClassification"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "44d7c35c",
+ "metadata": {},
+ "source": [
+ "## Helper functions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "5c9e36d9",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def align_labels_with_tokens(labels: list, word_ids: list) -> list:\n",
+ "    \"\"\"\n",
+ "    Repeat the label for each sub-token of a split word\n",
+ "\n",
+ "    :param labels: list of entity label ids\n",
+ "    :type labels: list\n",
+ "    :param word_ids: list of word ids (repeated if a word was split)\n",
+ "    :type word_ids: list\n",
+ "    :return: list of aligned labels for the tokenized sequence\n",
+ "    :rtype: list\n",
+ "    \"\"\"\n",
+ "    return [-100 if i is None else labels[i] for i in word_ids]\n",
+ "\n",
+ "def tokenize_and_align_labels(examples):\n",
+ "    \"\"\"\n",
+ "    Tokenize input sequences and align the corresponding labels\n",
+ "\n",
+ "    :param examples: DatasetDict object with sequences and label ids\n",
+ "    :type examples: DatasetDict\n",
+ "    :return: DatasetDict with tokenizer output\n",
+ "    :rtype: DatasetDict\n",
+ "    \"\"\"\n",
+ "    tokenized_inputs = tokenizer(\n",
+ "        examples[\"sequences\"], truncation=True, is_split_into_words=True\n",
+ "    )\n",
+ "    all_labels = examples[\"ids\"]\n",
+ "    new_labels = []\n",
+ "    for i, labels in enumerate(all_labels):\n",
+ "        word_ids = tokenized_inputs.word_ids(i)\n",
+ "        new_labels.append(align_labels_with_tokens(labels, word_ids))\n",
+ "\n",
+ "    tokenized_inputs[\"labels\"] = new_labels\n",
+ "    return tokenized_inputs\n",
+ "\n",
+ "def compute_metrics(eval_preds):\n",
+ "    \"\"\"\n",
+ "    Compute entity-level metrics for model predictions\n",
+ "\n",
+ "    :param eval_preds: model output (logits, labels)\n",
+ "    :type eval_preds: tuple\n",
+ "    \"\"\"\n",
+ "    logits, labels = eval_preds\n",
+ "    predictions = np.argmax(logits, axis=-1)\n",
+ "\n",
+ "    # Remove ignored index (special tokens) and convert to labels\n",
+ "    true_labels = [[label_names[l] for l in label if l != -100] for label in labels]\n",
+ "    true_predictions = [\n",
+ "        [label_names[p] for (p, l) in zip(prediction, label) if l != -100]\n",
+ "        for prediction, label in zip(predictions, labels)\n",
+ "    ]\n",
+ "    all_metrics = metric.compute(predictions=true_predictions, references=true_labels)\n",
+ "    return {\n",
+ "        \"precision\": all_metrics[\"overall_precision\"],\n",
+ "        \"recall\": all_metrics[\"overall_recall\"],\n",
+ "        \"f1\": all_metrics[\"overall_f1\"],\n",
+ "        \"accuracy\": all_metrics[\"overall_accuracy\"],\n",
+ "    }"
+ ]
+ },
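+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "added-toy-align",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Editor's note: a toy illustration, not part of the original notebook.\n",
+ "# For a word split into two sub-tokens between [CLS]/[SEP] specials,\n",
+ "# word_ids looks like [None, 0, 0, None]; special positions get -100 so\n",
+ "# the loss ignores them, and the word's label is repeated:\n",
+ "align_labels_with_tokens([3], [None, 0, 0, None])  # -> [-100, 3, 3, -100]"
+ ]
+ },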
+ {
+ "cell_type": "markdown",
+ "id": "8760e709",
+ "metadata": {},
+ "source": [
+ "## Load Data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e8c723f7",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "raw_datasets = load_dataset(\"surdan/nerel_short\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e540a898",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "raw_datasets"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5a4947d1",
+ "metadata": {},
+ "source": [
+ "## Preprocess data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "8829557e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model_checkpoint = \"cointegrated/LaBSE-en-ru\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b6c13ad1",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)"
+ ]
+ },
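+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "added-align-check",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Editor's note: a minimal sanity check, not part of the original notebook.\n",
+ "# It shows how one training example is tokenized and how its labels are\n",
+ "# aligned, assuming the \"sequences\"/\"ids\" columns used above.\n",
+ "# sample = raw_datasets[\"train\"][0]\n",
+ "# enc = tokenizer(sample[\"sequences\"], truncation=True, is_split_into_words=True)\n",
+ "# print(enc.word_ids())\n",
+ "# print(align_labels_with_tokens(sample[\"ids\"], enc.word_ids()))"
+ ]
+ },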
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ea2c1a9e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tokenized_datasets = raw_datasets.map(\n",
+ "    tokenize_and_align_labels,\n",
+ "    batched=True,\n",
+ "    remove_columns=raw_datasets[\"train\"].column_names,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b15c3cf1",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tokenized_datasets"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e9b5b9b1",
+ "metadata": {},
+ "source": [
+ "## Init Training pipeline"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b24d86e3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with open('id_to_label_map.pickle', 'rb') as f:\n",
+ "    map_id_to_label = pickle.load(f)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "1d90a6d9",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "3d890df2",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "id2label = {str(k): v for k, v in map_id_to_label.items()}\n",
+ "label2id = {v: k for k, v in id2label.items()}\n",
+ "label_names = list(id2label.values())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "31bcfd6c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model = AutoModelForTokenClassification.from_pretrained(\n",
+ "    model_checkpoint,\n",
+ "    id2label=id2label,\n",
+ "    label2id=label2id,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "84497580",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model.config.num_labels"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "1ccfbf74",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "args = TrainingArguments(\n",
+ "    \"LaBSE_ner_nerel\",\n",
+ "    evaluation_strategy=\"epoch\",\n",
+ "    save_strategy=\"no\",\n",
+ "    learning_rate=2e-5,\n",
+ "    num_train_epochs=25,\n",
+ "    weight_decay=0.01,\n",
+ "    push_to_hub=False,\n",
+ "    per_device_train_batch_size=4,  ## adjust to fit your GPU memory\n",
+ ")"
+ ]
+ },
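+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "added-grad-accum",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Editor's note: an optional variant, not part of the original notebook.\n",
+ "# If a per-device batch of 4 still overflows GPU memory, halving it and\n",
+ "# accumulating gradients keeps the same effective batch size:\n",
+ "# args = TrainingArguments(\n",
+ "#     \"LaBSE_ner_nerel\",\n",
+ "#     evaluation_strategy=\"epoch\",\n",
+ "#     save_strategy=\"no\",\n",
+ "#     learning_rate=2e-5,\n",
+ "#     num_train_epochs=25,\n",
+ "#     weight_decay=0.01,\n",
+ "#     push_to_hub=False,\n",
+ "#     per_device_train_batch_size=2,\n",
+ "#     gradient_accumulation_steps=2,\n",
+ "# )"
+ ]
+ },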
+ {
+ "cell_type": "markdown",
+ "id": "c798d567",
+ "metadata": {},
+ "source": [
+ "## Train model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "1348d188",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "## for the compute_metrics function\n",
+ "metric = load_metric(\"seqeval\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "5cff0367",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "trainer = Trainer(\n",
+ "    model=model,\n",
+ "    args=args,\n",
+ "    train_dataset=tokenized_datasets[\"train\"],\n",
+ "    eval_dataset=tokenized_datasets[\"dev\"],\n",
+ "    data_collator=data_collator,\n",
+ "    compute_metrics=compute_metrics,\n",
+ "    tokenizer=tokenizer,\n",
+ ")\n",
+ "trainer.train()"
+ ]
+ },
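+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "added-final-eval",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Editor's note: an optional final pass over the dev split, not part of\n",
+ "# the original notebook; returns the precision/recall/f1/accuracy\n",
+ "# produced by compute_metrics above.\n",
+ "# trainer.evaluate()"
+ ]
+ },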
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "576a10f4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "trainer.save_model(\"LaBSE_nerel_last_checkpoint\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "451d6db1",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "hf_env",
+ "language": "python",
+ "name": "hf_env"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.10"
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }