0xhaz committed on
Commit 3257c25
Parent: 3d6ba30

Upload bert2bert_kami_3000.ipynb

Files changed (1)
  1. bert2bert_kami_3000.ipynb +708 -0
bert2bert_kami_3000.ipynb ADDED
@@ -0,0 +1,708 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {
7
+ "id": "qcv24GSIQE5d"
8
+ },
9
+ "outputs": [],
10
+ "source": [
11
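+ "# Colab display helper: inject CSS before each cell runs so long output lines wrap\n",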
+ "from IPython.display import HTML, display\n",
12
+ "\n",
13
+ "def set_css():\n",
14
+ " display(HTML('''\n",
15
+ " <style>\n",
16
+ " pre {\n",
17
+ " white-space: pre-wrap;\n",
18
+ " }\n",
19
+ " </style>\n",
20
+ " '''))\n",
21
+ "get_ipython().events.register('pre_run_cell', set_css)"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "execution_count": null,
27
+ "metadata": {
28
+ "id": "SH8dkqPxQtP7"
29
+ },
30
+ "outputs": [],
31
+ "source": [
32
+ "!pip install --upgrade pip\n",
33
+ "!pip install transformers\n",
34
+ "!pip install datasets\n",
35
+ "!pip install sentencepiece"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "markdown",
40
+ "metadata": {
41
+ "id": "D8hhA8gaQwRR"
42
+ },
43
+ "source": [
44
+ "# 📂 Dataset"
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "markdown",
49
+ "metadata": {
50
+ "id": "NF-ouJiDQ1FO"
51
+ },
52
+ "source": [
53
+ "### Loading the dataset\n",
54
+ "---"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "execution_count": null,
60
+ "metadata": {
61
+ "id": "moK3d7mTQ1v-"
62
+ },
63
+ "outputs": [],
64
+ "source": [
65
+ "from datasets import load_dataset\n",
66
+ "\n",
67
+ "!wget 'https://raw.githubusercontent.com/jamesesguerra/dataset_repo/main/kami-3000.csv'\n",
68
+ "\n",
69
+ "dataset = load_dataset('csv', data_files='kami-3000.csv')\n",
70
+ "\n",
71
+ "print(dataset)\n",
72
+ "print()\n",
73
+ "print(dataset['train'].features)"
74
+ ]
75
+ },
76
+ {
77
+ "cell_type": "code",
78
+ "source": [],
79
+ "metadata": {
80
+ "id": "HEWGrOI_VlkN"
81
+ },
82
+ "execution_count": null,
83
+ "outputs": []
84
+ },
85
+ {
86
+ "cell_type": "code",
87
+ "source": [
88
+ "'''USE THIS CODE BLOCK FOR LOCAL INITIALIZATION'''\n",
89
+ "\n",
90
+ "from datasets import load_dataset\n",
91
+ "\n",
92
+ "dataset = load_dataset('csv', data_files='C:/Users/Public/Documents/hazielle/kami-3000.csv')\n",
93
+ "\n",
94
+ "print(dataset)\n",
95
+ "print()\n",
96
+ "print(dataset['train'].features)"
97
+ ],
98
+ "metadata": {
99
+ "id": "NgtZQydpwpB-"
100
+ },
101
+ "execution_count": null,
102
+ "outputs": []
103
+ },
104
+ {
105
+ "cell_type": "markdown",
106
+ "metadata": {
107
+ "id": "zbxmMmtWRCtX"
108
+ },
109
+ "source": [
110
+ "### Filtering rows\n",
111
+ "---"
112
+ ]
113
+ },
114
+ {
115
+ "cell_type": "markdown",
116
+ "metadata": {
117
+ "id": "QgoQRt8QREVi"
118
+ },
119
+ "source": [
120
+ "**Removing rows with blank article text and blank summary**"
121
+ ]
122
+ },
123
+ {
124
+ "cell_type": "code",
125
+ "execution_count": null,
126
+ "metadata": {
127
+ "id": "twzcsfXuRFQQ"
128
+ },
129
+ "outputs": [],
130
+ "source": [
131
+ "dataset = dataset.filter(lambda x: x['article_text'] is not None)\n",
132
+ "dataset = dataset.filter(lambda x: x['summary'] is not None)\n",
133
+ "\n",
134
+ "print(dataset['train'])"
135
+ ]
136
+ },
137
+ {
138
+ "cell_type": "markdown",
139
+ "metadata": {
140
+ "id": "30Xl1LGoRKkY"
141
+ },
142
+ "source": [
143
+ "**Removing rows with `len(article text)` < 25** and **`len(summary)` < 10**\n",
144
+ "(based on [this paper](http://www.diva-portal.org/smash/get/diva2:1563580/FULLTEXT01.pdf))"
145
+ ]
146
+ },
147
+ {
148
+ "cell_type": "code",
149
+ "execution_count": null,
150
+ "metadata": {
151
+ "id": "6MjsxAZPRLFk"
152
+ },
153
+ "outputs": [],
154
+ "source": [
155
+ "dataset = dataset.filter(lambda x: len(x['article_text'].split()) > 25)\n",
156
+ "dataset = dataset.filter(lambda x: len(x['summary'].split()) > 10)\n",
157
+ "\n",
158
+ "print(dataset['train'])"
159
+ ]
160
+ },
161
+ {
162
+ "cell_type": "markdown",
163
+ "metadata": {
164
+ "id": "YLA2bQeNRPAl"
165
+ },
166
+ "source": [
167
+ "### Cleaning\n",
168
+ "---"
169
+ ]
170
+ },
171
+ {
172
+ "cell_type": "markdown",
173
+ "metadata": {
174
+ "id": "z26t9F1URSCO"
175
+ },
176
+ "source": [
177
+ "**Unescaping HTML character codes**"
178
+ ]
179
+ },
180
+ {
181
+ "cell_type": "code",
182
+ "execution_count": null,
183
+ "metadata": {
184
+ "id": "BcUTqeFwRQpC"
185
+ },
186
+ "outputs": [],
187
+ "source": [
188
+ "import html\n",
189
+ "\n",
190
+ "dataset = dataset.map(\n",
191
+ " lambda x: {'article_text': [html.unescape(o) for o in x['article_text']]}, batched=True\n",
192
+ ")"
193
+ ]
194
+ },
195
+ {
196
+ "cell_type": "markdown",
197
+ "metadata": {
198
+ "id": "Y9BFM_A-RVdR"
199
+ },
200
+ "source": [
201
+ "**Removing unicode hard spaces**"
202
+ ]
203
+ },
204
+ {
205
+ "cell_type": "code",
206
+ "execution_count": null,
207
+ "metadata": {
208
+ "id": "D-MJvuTkRY8c"
209
+ },
210
+ "outputs": [],
211
+ "source": [
212
+ "from unicodedata import normalize\n",
213
+ "\n",
214
+ "dataset = dataset.map(lambda x: {'article_text': normalize('NFKD', x['article_text'])})"
215
+ ]
216
+ },
217
+ {
218
+ "cell_type": "markdown",
219
+ "metadata": {
220
+ "id": "6th91MJ3RmJW"
221
+ },
222
+ "source": [
223
+ "## Dataset splits\n",
224
+ "---"
225
+ ]
226
+ },
227
+ {
228
+ "cell_type": "code",
229
+ "execution_count": null,
230
+ "metadata": {
231
+ "id": "jVJ--r53RoL6"
232
+ },
233
+ "outputs": [],
234
+ "source": [
235
+ "dataset = dataset['train'].train_test_split(train_size=0.8, seed=42)\n",
236
+ "\n",
237
+ "dataset['validation'] = dataset.pop('test')\n",
238
+ "\n",
239
+ "print(dataset)"
240
+ ]
241
+ },
242
+ {
243
+ "cell_type": "markdown",
244
+ "metadata": {
245
+ "id": "UFN9ufDYRp9G"
246
+ },
247
+ "source": [
248
+ "# 🪙 Tokenization"
249
+ ]
250
+ },
251
+ {
252
+ "cell_type": "code",
253
+ "execution_count": null,
254
+ "metadata": {
255
+ "id": "rP1sC2L0R0HB"
256
+ },
257
+ "outputs": [],
258
+ "source": [
259
+ "from transformers import AutoTokenizer\n",
260
+ "\n",
261
+ "checkpoint = \"patrickvonplaten/bert2bert-cnn_dailymail-fp16\"\n",
262
+ "tokenizer = AutoTokenizer.from_pretrained(checkpoint)"
263
+ ]
264
+ },
265
+ {
266
+ "cell_type": "markdown",
267
+ "metadata": {
268
+ "id": "1X9Ji15LR8et"
269
+ },
270
+ "source": [
271
+ "**Define preprocess function**"
272
+ ]
273
+ },
274
+ {
275
+ "cell_type": "code",
276
+ "execution_count": null,
277
+ "metadata": {
278
+ "id": "T1L-Q2v8R93o"
279
+ },
280
+ "outputs": [],
281
+ "source": [
282
+ "# set upper limit on how long the articles and their summaries can be\n",
283
+ "max_input_length = 512\n",
284
+ "max_target_length = 128\n",
285
+ "\n",
286
+ "def preprocess_function(rows):\n",
287
+ " model_inputs = tokenizer(rows['article_text'], max_length=max_input_length, truncation=True)\n",
288
+ " \n",
289
+ " with tokenizer.as_target_tokenizer():\n",
290
+ " labels = tokenizer(rows['summary'], max_length=max_target_length, truncation=True)\n",
291
+ " \n",
292
+ " model_inputs['labels'] = labels['input_ids']\n",
293
+ " return model_inputs"
294
+ ]
295
+ },
296
+ {
297
+ "cell_type": "markdown",
298
+ "metadata": {
299
+ "id": "JEVi769uSARU"
300
+ },
301
+ "source": [
302
+ "**Tokenize the dataset**"
303
+ ]
304
+ },
305
+ {
306
+ "cell_type": "code",
307
+ "execution_count": null,
308
+ "metadata": {
309
+ "id": "IU5943MESBrK"
310
+ },
311
+ "outputs": [],
312
+ "source": [
313
+ "tokenized_dataset = dataset.map(preprocess_function, batched=True)"
314
+ ]
315
+ },
316
+ {
317
+ "cell_type": "markdown",
318
+ "metadata": {
319
+ "id": "o8D04VHjSI6b"
320
+ },
321
+ "source": [
322
+ "# 📊 Evaluation Metrics"
323
+ ]
324
+ },
325
+ {
326
+ "cell_type": "markdown",
327
+ "metadata": {
328
+ "id": "2GB7-jfKSMrE"
329
+ },
330
+ "source": [
331
+ "## ROUGE\n",
332
+ "---"
333
+ ]
334
+ },
335
+ {
336
+ "cell_type": "markdown",
337
+ "metadata": {
338
+ "id": "3TljkwZbSQZV"
339
+ },
340
+ "source": [
341
+ "**installing `rouge_score` and loading the metric**"
342
+ ]
343
+ },
344
+ {
345
+ "cell_type": "code",
346
+ "execution_count": null,
347
+ "metadata": {
348
+ "id": "HItzZO_mSQG-"
349
+ },
350
+ "outputs": [],
351
+ "source": [
352
+ "!pip install rouge_score"
353
+ ]
354
+ },
355
+ {
356
+ "cell_type": "code",
357
+ "execution_count": null,
358
+ "metadata": {
359
+ "id": "7wrZ5kAMSOlH"
360
+ },
361
+ "outputs": [],
362
+ "source": [
363
+ "from datasets import load_metric\n",
364
+ "rouge_score = load_metric('rouge')"
365
+ ]
366
+ },
367
+ {
368
+ "cell_type": "markdown",
369
+ "metadata": {
370
+ "id": "tGOAR4SnSeVY"
371
+ },
372
+ "source": [
373
+ "## Creating a lead-3 baseline\n",
374
+ "---"
375
+ ]
376
+ },
377
+ {
378
+ "cell_type": "markdown",
379
+ "metadata": {
380
+ "id": "3OAa8kIfSgC5"
381
+ },
382
+ "source": [
383
+ "**import and download dependencies**"
384
+ ]
385
+ },
386
+ {
387
+ "cell_type": "code",
388
+ "execution_count": null,
389
+ "metadata": {
390
+ "id": "x8LFH_0qShRO"
391
+ },
392
+ "outputs": [],
393
+ "source": [
394
+ "!pip install nltk\n",
395
+ "import nltk\n",
396
+ "\n",
397
+ "nltk.download(\"punkt\")"
398
+ ]
399
+ },
400
+ {
401
+ "cell_type": "markdown",
402
+ "metadata": {
403
+ "id": "WdcIWK8GShzb"
404
+ },
405
+ "source": [
406
+ "**define fn to extract the first 3 sentences in an article**"
407
+ ]
408
+ },
409
+ {
410
+ "cell_type": "code",
411
+ "execution_count": null,
412
+ "metadata": {
413
+ "id": "17LcLH1FSjtz"
414
+ },
415
+ "outputs": [],
416
+ "source": [
417
+ "from nltk.tokenize import sent_tokenize\n",
418
+ "\n",
419
+ "def extract_sentences(text):\n",
420
+ " return \"\\n\".join(sent_tokenize(text)[:3])\n",
421
+ "\n",
422
+ "print(extract_sentences(dataset[\"train\"][4][\"article_text\"]))"
423
+ ]
424
+ },
425
+ {
426
+ "cell_type": "markdown",
427
+ "metadata": {
428
+ "id": "0aHfU4_tSolA"
429
+ },
430
+ "source": [
431
+ "**define fn to extract summaries from the data and compute ROUGE scores for the baseline**"
432
+ ]
433
+ },
434
+ {
435
+ "cell_type": "code",
436
+ "execution_count": null,
437
+ "metadata": {
438
+ "id": "08n3A6OGSqK2"
439
+ },
440
+ "outputs": [],
441
+ "source": [
442
+ "def evaluate_baseline(dataset, metric):\n",
443
+ " summaries = [extract_sentences(text) for text in dataset[\"article_text\"]]\n",
444
+ " return metric.compute(predictions=summaries, references=dataset[\"summary\"])"
445
+ ]
446
+ },
447
+ {
448
+ "cell_type": "markdown",
449
+ "metadata": {
450
+ "id": "0fZ67opnSsbe"
451
+ },
452
+ "source": [
453
+ "**use fn to compute ROUGE scores over the validation set**"
454
+ ]
455
+ },
456
+ {
457
+ "cell_type": "code",
458
+ "execution_count": null,
459
+ "metadata": {
460
+ "id": "nMfTYxxOSwRk"
461
+ },
462
+ "outputs": [],
463
+ "source": [
464
+ "import pandas as pd\n",
465
+ "\n",
466
+ "score = evaluate_baseline(dataset[\"validation\"], rouge_score)\n",
467
+ "rouge_names = [\"rouge1\", \"rouge2\", \"rougeL\", \"rougeLsum\"]\n",
468
+ "rouge_dict = dict((rn, round(score[rn].mid.fmeasure * 100, 2)) for rn in rouge_names)\n",
469
+ "print(rouge_dict)"
470
+ ]
471
+ },
472
+ {
473
+ "cell_type": "markdown",
474
+ "metadata": {
475
+ "id": "tyfkBzlxSyA7"
476
+ },
477
+ "source": [
478
+ "# 🔩 Fine-tuning"
479
+ ]
480
+ },
481
+ {
482
+ "cell_type": "markdown",
483
+ "metadata": {
484
+ "id": "PqlM9-HgS804"
485
+ },
486
+ "source": [
487
+ "**Loading the model**"
488
+ ]
489
+ },
490
+ {
491
+ "cell_type": "code",
492
+ "execution_count": null,
493
+ "metadata": {
494
+ "id": "R1y2goZ3S-CC"
495
+ },
496
+ "outputs": [],
497
+ "source": [
498
+ "from transformers import EncoderDecoderModel\n",
499
+ "\n",
500
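+ "# pad_token_id=0 matches BERT's [PAD] token id\n",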
+ "model = EncoderDecoderModel.from_pretrained(checkpoint, pad_token_id=0)\n"
501
+ ]
502
+ },
503
+ {
504
+ "cell_type": "markdown",
505
+ "metadata": {
506
+ "id": "MMsjH4Z6TA73"
507
+ },
508
+ "source": [
509
+ "**Logging in Hugging Face Hub**"
510
+ ]
511
+ },
512
+ {
513
+ "cell_type": "code",
514
+ "execution_count": null,
515
+ "metadata": {
516
+ "id": "BLSPzmoBTCLk"
517
+ },
518
+ "outputs": [],
519
+ "source": [
520
+ "from huggingface_hub import notebook_login\n",
521
+ "notebook_login()"
522
+ ]
523
+ },
524
+ {
525
+ "cell_type": "markdown",
526
+ "metadata": {
527
+ "id": "IHH0nuznTD2L"
528
+ },
529
+ "source": [
530
+ "**set up hyperparameters for training**"
531
+ ]
532
+ },
533
+ {
534
+ "cell_type": "code",
535
+ "execution_count": null,
536
+ "metadata": {
537
+ "id": "CCEGxd76TEff"
538
+ },
539
+ "outputs": [],
540
+ "source": [
541
+ "from transformers import Seq2SeqTrainingArguments\n",
542
+ "\n",
543
+ "batch_size = 4\n",
544
+ "num_train_epochs = 2\n",
545
+ "logging_steps = len(tokenized_dataset['train']) // batch_size\n",
546
+ "model_name = checkpoint.split('/')[-1]\n",
547
+ "\n",
548
+ "args = Seq2SeqTrainingArguments(\n",
549
+ " output_dir=f\"{model_name}-finetuned-1.0.0\",\n",
550
+ " evaluation_strategy=\"epoch\",\n",
551
+ " learning_rate=5e-5,\n",
552
+ " per_device_train_batch_size=batch_size,\n",
553
+ " per_device_eval_batch_size=batch_size,\n",
554
+ " weight_decay=0.01,\n",
555
+ " save_total_limit=3,\n",
556
+ " num_train_epochs=num_train_epochs,\n",
557
+ " predict_with_generate=True,\n",
558
+ " logging_steps=logging_steps,\n",
559
+ " push_to_hub=True,\n",
560
+ ")"
561
+ ]
562
+ },
563
+ {
564
+ "cell_type": "markdown",
565
+ "metadata": {
566
+ "id": "PdY0ecY9THT8"
567
+ },
568
+ "source": [
569
+ "**define fn to evaluate model during training**"
570
+ ]
571
+ },
572
+ {
573
+ "cell_type": "code",
574
+ "execution_count": null,
575
+ "metadata": {
576
+ "id": "d6DqOp4ITKGs"
577
+ },
578
+ "outputs": [],
579
+ "source": [
580
+ "import numpy as np\n",
581
+ "\n",
582
+ "\n",
583
+ "def compute_metrics(eval_pred):\n",
584
+ " predictions, labels = eval_pred\n",
585
+ " decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)\n",
586
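+ " # replace -100 (label positions ignored by the loss) with the pad token id so they can be decoded\n",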
+ " labels = np.where(labels != -100, labels, tokenizer.pad_token_id)\n",
587
+ " decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
588
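+ " # ROUGE-Lsum expects each sentence on its own line, so rejoin with newlines\n",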
+ " decoded_preds = [\"\\n\".join(sent_tokenize(pred.strip())) for pred in decoded_preds]\n",
589
+ " decoded_labels = [\"\\n\".join(sent_tokenize(label.strip())) for label in decoded_labels]\n",
590
+ " result = rouge_score.compute(\n",
591
+ " predictions=decoded_preds, references=decoded_labels, use_stemmer=True\n",
592
+ " )\n",
593
+ " result = {key: value.mid.fmeasure * 100 for key, value in result.items()}\n",
594
+ " return {k: round(v, 4) for k, v in result.items()}"
595
+ ]
596
+ },
597
+ {
598
+ "cell_type": "markdown",
599
+ "metadata": {
600
+ "id": "y_wEoWIWTMjr"
601
+ },
602
+ "source": [
603
+ "**define data collator for dynamic padding**"
604
+ ]
605
+ },
606
+ {
607
+ "cell_type": "code",
608
+ "execution_count": null,
609
+ "metadata": {
610
+ "id": "ThUqaIr2TPh4"
611
+ },
612
+ "outputs": [],
613
+ "source": [
614
+ "from transformers import DataCollatorForSeq2Seq\n",
615
+ "\n",
616
+ "data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)"
617
+ ]
618
+ },
619
+ {
620
+ "cell_type": "markdown",
621
+ "metadata": {
622
+ "id": "v_Q4XoW7UaTi"
623
+ },
624
+ "source": [
625
+ "**instantiate trainer with arguments**"
626
+ ]
627
+ },
628
+ {
629
+ "cell_type": "code",
630
+ "execution_count": null,
631
+ "metadata": {
632
+ "id": "zkCyYVTdUbE7"
633
+ },
634
+ "outputs": [],
635
+ "source": [
636
+ "from transformers import Seq2SeqTrainer\n",
637
+ "\n",
638
+ "trainer = Seq2SeqTrainer(\n",
639
+ " model,\n",
640
+ " args,\n",
641
+ " train_dataset=tokenized_dataset[\"train\"],\n",
642
+ " eval_dataset=tokenized_dataset[\"validation\"],\n",
643
+ " data_collator=data_collator,\n",
644
+ " tokenizer=tokenizer,\n",
645
+ " compute_metrics=compute_metrics,\n",
646
+ ")"
647
+ ]
648
+ },
649
+ {
650
+ "cell_type": "markdown",
651
+ "metadata": {
652
+ "id": "Ksa_utSpUnO6"
653
+ },
654
+ "source": [
655
+ "**launch training run**"
656
+ ]
657
+ },
658
+ {
659
+ "cell_type": "code",
660
+ "execution_count": null,
661
+ "metadata": {
662
+ "id": "YBGhf1xYUp7B"
663
+ },
664
+ "outputs": [],
665
+ "source": [
666
+ "trainer.train()"
667
+ ]
668
+ },
669
+ {
670
+ "cell_type": "code",
671
+ "execution_count": null,
672
+ "metadata": {
673
+ "id": "YCwxQuydUI7K"
674
+ },
675
+ "outputs": [],
676
+ "source": [
677
+ "trainer.evaluate()"
678
+ ]
679
+ },
680
+ {
681
+ "cell_type": "code",
682
+ "execution_count": null,
683
+ "metadata": {
684
+ "id": "4eNoOqM2rWw1"
685
+ },
686
+ "outputs": [],
687
+ "source": [
688
+ "trainer.push_to_hub()"
689
+ ]
690
+ }
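+ ,
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Generating a summary with the fine-tuned model** (a minimal sketch for trying the model out: it assumes the `model`, `tokenizer`, `dataset`, `max_input_length`, and `max_target_length` objects from the cells above are still in memory, and relies on the generation settings inherited from the CNN/DailyMail bert2bert checkpoint)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Sketch: summarize one validation article with the fine-tuned encoder-decoder\n",
+ "sample_text = dataset['validation'][0]['article_text']\n",
+ "\n",
+ "# tokenize the article and move the tensors to the model's device\n",
+ "inputs = tokenizer(sample_text, max_length=max_input_length, truncation=True, return_tensors='pt').to(model.device)\n",
+ "\n",
+ "# beam-search decoding, capped at the same target length used during training\n",
+ "summary_ids = model.generate(\n",
+ "    inputs['input_ids'],\n",
+ "    attention_mask=inputs['attention_mask'],\n",
+ "    max_length=max_target_length,\n",
+ "    num_beams=4,\n",
+ ")\n",
+ "print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))"
+ ]
+ }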
691
+ ],
692
+ "metadata": {
693
+ "accelerator": "GPU",
694
+ "colab": {
695
+ "provenance": []
696
+ },
697
+ "gpuClass": "standard",
698
+ "kernelspec": {
699
+ "display_name": "Python 3",
700
+ "name": "python3"
701
+ },
702
+ "language_info": {
703
+ "name": "python"
704
+ }
705
+ },
706
+ "nbformat": 4,
707
+ "nbformat_minor": 0
708
+ }