0xhaz commited on
Commit
2886258
1 Parent(s): 21ff50a

Upload distilbart_1_3_2_(kami_3000).ipynb

Browse files
Files changed (1) hide show
  1. distilbart_1_3_2_(kami_3000).ipynb +681 -0
distilbart_1_3_2_(kami_3000).ipynb ADDED
@@ -0,0 +1,681 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {
7
+ "id": "qcv24GSIQE5d"
8
+ },
9
+ "outputs": [],
10
+ "source": [
11
+ "from IPython.display import HTML, display\n",
12
+ "\n",
13
+ "def set_css():\n",
14
+ " display(HTML('''\n",
15
+ " <style>\n",
16
+ " pre {\n",
17
+ " white-space: pre-wrap;\n",
18
+ " }\n",
19
+ " </style>\n",
20
+ " '''))\n",
21
+ "get_ipython().events.register('pre_run_cell', set_css)"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "execution_count": null,
27
+ "metadata": {
28
+ "id": "SH8dkqPxQtP7"
29
+ },
30
+ "outputs": [],
31
+ "source": [
32
+ "!pip install --upgrade pip\n",
33
+ "!pip install transformers\n",
34
+ "!pip install datasets\n",
35
+ "!pip install sentencepiece"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "markdown",
40
+ "metadata": {
41
+ "id": "D8hhA8gaQwRR"
42
+ },
43
+ "source": [
44
+ "# 📂 Dataset"
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "markdown",
49
+ "metadata": {
50
+ "id": "NF-ouJiDQ1FO"
51
+ },
52
+ "source": [
53
+ "### Loading the dataset\n",
54
+ "---"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "execution_count": null,
60
+ "metadata": {
61
+ "id": "moK3d7mTQ1v-"
62
+ },
63
+ "outputs": [],
64
+ "source": [
65
+ "from datasets import load_dataset\n",
66
+ "\n",
67
+ "!wget 'https://raw.githubusercontent.com/jamesesguerra/dataset_repo/main/kami-3000.csv'\n",
68
+ "\n",
69
+ "dataset = load_dataset('csv', data_files='kami-3000.csv')\n",
70
+ "\n",
71
+ "print(dataset)\n",
72
+ "print()\n",
73
+ "print(dataset['train'].features)"
74
+ ]
75
+ },
76
+ {
77
+ "cell_type": "markdown",
78
+ "metadata": {
79
+ "id": "zbxmMmtWRCtX"
80
+ },
81
+ "source": [
82
+ "### Filtering rows\n",
83
+ "---"
84
+ ]
85
+ },
86
+ {
87
+ "cell_type": "markdown",
88
+ "metadata": {
89
+ "id": "QgoQRt8QREVi"
90
+ },
91
+ "source": [
92
+ "**Removing rows with blank article text and blank summary**"
93
+ ]
94
+ },
95
+ {
96
+ "cell_type": "code",
97
+ "execution_count": null,
98
+ "metadata": {
99
+ "id": "twzcsfXuRFQQ"
100
+ },
101
+ "outputs": [],
102
+ "source": [
103
+ "dataset = dataset.filter(lambda x: x['article_text'] is not None)\n",
104
+ "dataset = dataset.filter(lambda x: x['summary'] is not None)\n",
105
+ "\n",
106
+ "print(dataset['train'])"
107
+ ]
108
+ },
109
+ {
110
+ "cell_type": "markdown",
111
+ "metadata": {
112
+ "id": "30Xl1LGoRKkY"
113
+ },
114
+ "source": [
115
+ "**Removing rows with `len(article text)` < 25** and **`len(summary)` < 10**\n",
116
+ "(based on [this paper](http://www.diva-portal.org/smash/get/diva2:1563580/FULLTEXT01.pdf))"
117
+ ]
118
+ },
119
+ {
120
+ "cell_type": "code",
121
+ "execution_count": null,
122
+ "metadata": {
123
+ "id": "6MjsxAZPRLFk"
124
+ },
125
+ "outputs": [],
126
+ "source": [
127
+ "dataset = dataset.filter(lambda x: len(x['article_text'].split()) > 25)\n",
128
+ "dataset = dataset.filter(lambda x: len(x['summary'].split()) > 10)\n",
129
+ "\n",
130
+ "print(dataset['train'])"
131
+ ]
132
+ },
133
+ {
134
+ "cell_type": "markdown",
135
+ "metadata": {
136
+ "id": "YLA2bQeNRPAl"
137
+ },
138
+ "source": [
139
+ "### Cleaning\n",
140
+ "---"
141
+ ]
142
+ },
143
+ {
144
+ "cell_type": "markdown",
145
+ "metadata": {
146
+ "id": "z26t9F1URSCO"
147
+ },
148
+ "source": [
149
+ "**Unescaping HTML character codes**"
150
+ ]
151
+ },
152
+ {
153
+ "cell_type": "code",
154
+ "execution_count": null,
155
+ "metadata": {
156
+ "id": "BcUTqeFwRQpC"
157
+ },
158
+ "outputs": [],
159
+ "source": [
160
+ "import html\n",
161
+ "\n",
162
+ "dataset = dataset.map(\n",
163
+ " lambda x: {'article_text': [html.unescape(o) for o in x['article_text']]}, batched=True\n",
164
+ ")"
165
+ ]
166
+ },
167
+ {
168
+ "cell_type": "markdown",
169
+ "metadata": {
170
+ "id": "Y9BFM_A-RVdR"
171
+ },
172
+ "source": [
173
+ "**Removing unicode hard spaces**"
174
+ ]
175
+ },
176
+ {
177
+ "cell_type": "code",
178
+ "execution_count": null,
179
+ "metadata": {
180
+ "id": "D-MJvuTkRY8c"
181
+ },
182
+ "outputs": [],
183
+ "source": [
184
+ "from unicodedata import normalize\n",
185
+ "\n",
186
+ "dataset = dataset.map(lambda x: {'article_text': normalize('NFKD', x['article_text'])})"
187
+ ]
188
+ },
189
+ {
190
+ "cell_type": "markdown",
191
+ "metadata": {
192
+ "id": "6th91MJ3RmJW"
193
+ },
194
+ "source": [
195
+ "## Dataset splits\n",
196
+ "---"
197
+ ]
198
+ },
199
+ {
200
+ "cell_type": "code",
201
+ "execution_count": null,
202
+ "metadata": {
203
+ "id": "jVJ--r53RoL6"
204
+ },
205
+ "outputs": [],
206
+ "source": [
207
+ "dataset = dataset['train'].train_test_split(train_size=0.8, seed=42)\n",
208
+ "\n",
209
+ "dataset['validation'] = dataset.pop('test')\n",
210
+ "\n",
211
+ "print(dataset)"
212
+ ]
213
+ },
214
+ {
215
+ "cell_type": "markdown",
216
+ "metadata": {
217
+ "id": "UFN9ufDYRp9G"
218
+ },
219
+ "source": [
220
+ "# 🪙 Tokenization"
221
+ ]
222
+ },
223
+ {
224
+ "cell_type": "code",
225
+ "execution_count": null,
226
+ "metadata": {
227
+ "id": "rP1sC2L0R0HB"
228
+ },
229
+ "outputs": [],
230
+ "source": [
231
+ "from transformers import AutoTokenizer\n",
232
+ "\n",
233
+ "checkpoint = \"sshleifer/distilbart-cnn-12-6\"\n",
234
+ "tokenizer = AutoTokenizer.from_pretrained(checkpoint)"
235
+ ]
236
+ },
237
+ {
238
+ "cell_type": "markdown",
239
+ "metadata": {
240
+ "id": "1X9Ji15LR8et"
241
+ },
242
+ "source": [
243
+ "**Define preprocess function**"
244
+ ]
245
+ },
246
+ {
247
+ "cell_type": "code",
248
+ "execution_count": null,
249
+ "metadata": {
250
+ "id": "T1L-Q2v8R93o"
251
+ },
252
+ "outputs": [],
253
+ "source": [
254
+ "# set upper limit on how long the articles and their summaries can be\n",
255
+ "max_input_length = 768\n",
256
+ "max_target_length = 128\n",
257
+ "\n",
258
+ "def preprocess_function(rows):\n",
259
+ " model_inputs = tokenizer(rows['article_text'], max_length=max_input_length, truncation=True)\n",
260
+ " \n",
261
+ " with tokenizer.as_target_tokenizer():\n",
262
+ " labels = tokenizer(rows['summary'], max_length=max_target_length, truncation=True)\n",
263
+ " \n",
264
+ " model_inputs['labels'] = labels['input_ids']\n",
265
+ " return model_inputs"
266
+ ]
267
+ },
268
+ {
269
+ "cell_type": "markdown",
270
+ "metadata": {
271
+ "id": "JEVi769uSARU"
272
+ },
273
+ "source": [
274
+ "**Tokenize the dataset**"
275
+ ]
276
+ },
277
+ {
278
+ "cell_type": "code",
279
+ "execution_count": null,
280
+ "metadata": {
281
+ "id": "IU5943MESBrK"
282
+ },
283
+ "outputs": [],
284
+ "source": [
285
+ "tokenized_dataset = dataset.map(preprocess_function, batched=True)"
286
+ ]
287
+ },
288
+ {
289
+ "cell_type": "markdown",
290
+ "metadata": {
291
+ "id": "o8D04VHjSI6b"
292
+ },
293
+ "source": [
294
+ "# 📊 Evaluation Metrics"
295
+ ]
296
+ },
297
+ {
298
+ "cell_type": "markdown",
299
+ "metadata": {
300
+ "id": "2GB7-jfKSMrE"
301
+ },
302
+ "source": [
303
+ "## ROUGE\n",
304
+ "---"
305
+ ]
306
+ },
307
+ {
308
+ "cell_type": "markdown",
309
+ "metadata": {
310
+ "id": "3TljkwZbSQZV"
311
+ },
312
+ "source": [
313
+ "**installing `rouge_score` and loading the metric**"
314
+ ]
315
+ },
316
+ {
317
+ "cell_type": "code",
318
+ "execution_count": null,
319
+ "metadata": {
320
+ "id": "HItzZO_mSQG-"
321
+ },
322
+ "outputs": [],
323
+ "source": [
324
+ "!pip install rouge_score"
325
+ ]
326
+ },
327
+ {
328
+ "cell_type": "code",
329
+ "execution_count": null,
330
+ "metadata": {
331
+ "id": "7wrZ5kAMSOlH"
332
+ },
333
+ "outputs": [],
334
+ "source": [
335
+ "from datasets import load_metric\n",
336
+ "rouge_score = load_metric('rouge')"
337
+ ]
338
+ },
339
+ {
340
+ "cell_type": "markdown",
341
+ "metadata": {
342
+ "id": "tGOAR4SnSeVY"
343
+ },
344
+ "source": [
345
+ "## Creating a lead-3 baseline\n",
346
+ "---"
347
+ ]
348
+ },
349
+ {
350
+ "cell_type": "markdown",
351
+ "metadata": {
352
+ "id": "3OAa8kIfSgC5"
353
+ },
354
+ "source": [
355
+ "**import and download dependencies**"
356
+ ]
357
+ },
358
+ {
359
+ "cell_type": "code",
360
+ "execution_count": null,
361
+ "metadata": {
362
+ "id": "x8LFH_0qShRO"
363
+ },
364
+ "outputs": [],
365
+ "source": [
366
+ "!pip install nltk\n",
367
+ "import nltk\n",
368
+ "\n",
369
+ "nltk.download(\"punkt\")"
370
+ ]
371
+ },
372
+ {
373
+ "cell_type": "markdown",
374
+ "metadata": {
375
+ "id": "WdcIWK8GShzb"
376
+ },
377
+ "source": [
378
+ "**define fn to extract the first 3 sentences in an article**"
379
+ ]
380
+ },
381
+ {
382
+ "cell_type": "code",
383
+ "execution_count": null,
384
+ "metadata": {
385
+ "id": "17LcLH1FSjtz"
386
+ },
387
+ "outputs": [],
388
+ "source": [
389
+ "from nltk.tokenize import sent_tokenize\n",
390
+ "\n",
391
+ "def extract_sentences(text):\n",
392
+ " return \"\\n\".join(sent_tokenize(text)[:3])\n",
393
+ "\n",
394
+ "print(extract_sentences(dataset[\"train\"][4][\"article_text\"]))"
395
+ ]
396
+ },
397
+ {
398
+ "cell_type": "markdown",
399
+ "metadata": {
400
+ "id": "0aHfU4_tSolA"
401
+ },
402
+ "source": [
403
+ "**define fn to extract summaries from the data and compute ROUGE scores for the baseline**"
404
+ ]
405
+ },
406
+ {
407
+ "cell_type": "code",
408
+ "execution_count": null,
409
+ "metadata": {
410
+ "id": "08n3A6OGSqK2"
411
+ },
412
+ "outputs": [],
413
+ "source": [
414
+ "def evaluate_baseline(dataset, metric):\n",
415
+ " summaries = [extract_sentences(text) for text in dataset[\"article_text\"]]\n",
416
+ " return metric.compute(predictions=summaries, references=dataset[\"summary\"])"
417
+ ]
418
+ },
419
+ {
420
+ "cell_type": "markdown",
421
+ "metadata": {
422
+ "id": "0fZ67opnSsbe"
423
+ },
424
+ "source": [
425
+ "**use fn to compute ROUGE scores over the validation set**"
426
+ ]
427
+ },
428
+ {
429
+ "cell_type": "code",
430
+ "execution_count": null,
431
+ "metadata": {
432
+ "id": "nMfTYxxOSwRk"
433
+ },
434
+ "outputs": [],
435
+ "source": [
436
+ "import pandas as pd\n",
437
+ "\n",
438
+ "score = evaluate_baseline(dataset[\"validation\"], rouge_score)\n",
439
+ "rouge_names = [\"rouge1\", \"rouge2\", \"rougeL\", \"rougeLsum\"]\n",
440
+ "rouge_dict = dict((rn, round(score[rn].mid.fmeasure * 100, 2)) for rn in rouge_names)\n",
441
+ "print(rouge_dict)"
442
+ ]
443
+ },
444
+ {
445
+ "cell_type": "markdown",
446
+ "metadata": {
447
+ "id": "tyfkBzlxSyA7"
448
+ },
449
+ "source": [
450
+ "# 🔩 Fine-tuning"
451
+ ]
452
+ },
453
+ {
454
+ "cell_type": "markdown",
455
+ "metadata": {
456
+ "id": "PqlM9-HgS804"
457
+ },
458
+ "source": [
459
+ "**Loading the model**"
460
+ ]
461
+ },
462
+ {
463
+ "cell_type": "code",
464
+ "execution_count": null,
465
+ "metadata": {
466
+ "id": "R1y2goZ3S-CC"
467
+ },
468
+ "outputs": [],
469
+ "source": [
470
+ "from transformers import AutoModelForSeq2SeqLM\n",
471
+ "\n",
472
+ "model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)"
473
+ ]
474
+ },
475
+ {
476
+ "cell_type": "markdown",
477
+ "metadata": {
478
+ "id": "MMsjH4Z6TA73"
479
+ },
480
+ "source": [
481
+ "**Logging in Hugging Face Hub**"
482
+ ]
483
+ },
484
+ {
485
+ "cell_type": "code",
486
+ "execution_count": null,
487
+ "metadata": {
488
+ "id": "BLSPzmoBTCLk"
489
+ },
490
+ "outputs": [],
491
+ "source": [
492
+ "from huggingface_hub import notebook_login\n",
493
+ "notebook_login()"
494
+ ]
495
+ },
496
+ {
497
+ "cell_type": "markdown",
498
+ "metadata": {
499
+ "id": "IHH0nuznTD2L"
500
+ },
501
+ "source": [
502
+ "**set up hyperparameters for training**"
503
+ ]
504
+ },
505
+ {
506
+ "cell_type": "code",
507
+ "execution_count": null,
508
+ "metadata": {
509
+ "id": "CCEGxd76TEff"
510
+ },
511
+ "outputs": [],
512
+ "source": [
513
+ "from transformers import Seq2SeqTrainingArguments\n",
514
+ "\n",
515
+ "batch_size = 4\n",
516
+ "num_train_epochs = 2\n",
517
+ "logging_steps = len(tokenized_dataset['train']) // batch_size\n",
518
+ "model_name = checkpoint.split('/')[-1]\n",
519
+ "\n",
520
+ "args = Seq2SeqTrainingArguments(\n",
521
+ " output_dir=f\"{model_name}-finetuned-1.3.2\",\n",
522
+ " evaluation_strategy=\"epoch\",\n",
523
+ " learning_rate=5e-5,\n",
524
+ " per_device_train_batch_size=batch_size,\n",
525
+ " per_device_eval_batch_size=batch_size,\n",
526
+ " weight_decay=0.01,\n",
527
+ " save_total_limit=3,\n",
528
+ " num_train_epochs=num_train_epochs,\n",
529
+ " predict_with_generate=True,\n",
530
+ " logging_steps=logging_steps,\n",
531
+ " push_to_hub=True,\n",
532
+ ")"
533
+ ]
534
+ },
535
+ {
536
+ "cell_type": "markdown",
537
+ "metadata": {
538
+ "id": "PdY0ecY9THT8"
539
+ },
540
+ "source": [
541
+ "**define fn to evaluate model during training**"
542
+ ]
543
+ },
544
+ {
545
+ "cell_type": "code",
546
+ "execution_count": null,
547
+ "metadata": {
548
+ "id": "d6DqOp4ITKGs"
549
+ },
550
+ "outputs": [],
551
+ "source": [
552
+ "import numpy as np\n",
553
+ "\n",
554
+ "\n",
555
+ "def compute_metrics(eval_pred):\n",
556
+ " predictions, labels = eval_pred\n",
557
+ " decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)\n",
558
+ " labels = np.where(labels != -100, labels, tokenizer.pad_token_id)\n",
559
+ " decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
560
+ " decoded_preds = [\"\\n\".join(sent_tokenize(pred.strip())) for pred in decoded_preds]\n",
561
+ " decoded_labels = [\"\\n\".join(sent_tokenize(label.strip())) for label in decoded_labels]\n",
562
+ " result = rouge_score.compute(\n",
563
+ " predictions=decoded_preds, references=decoded_labels, use_stemmer=True\n",
564
+ " )\n",
565
+ " result = {key: value.mid.fmeasure * 100 for key, value in result.items()}\n",
566
+ " return {k: round(v, 4) for k, v in result.items()}"
567
+ ]
568
+ },
569
+ {
570
+ "cell_type": "markdown",
571
+ "metadata": {
572
+ "id": "y_wEoWIWTMjr"
573
+ },
574
+ "source": [
575
+ "**define data collator for dynamic padding**"
576
+ ]
577
+ },
578
+ {
579
+ "cell_type": "code",
580
+ "execution_count": null,
581
+ "metadata": {
582
+ "id": "ThUqaIr2TPh4"
583
+ },
584
+ "outputs": [],
585
+ "source": [
586
+ "from transformers import DataCollatorForSeq2Seq\n",
587
+ "\n",
588
+ "data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)"
589
+ ]
590
+ },
591
+ {
592
+ "cell_type": "markdown",
593
+ "metadata": {
594
+ "id": "v_Q4XoW7UaTi"
595
+ },
596
+ "source": [
597
+ "**instantiate trainer with arguments**"
598
+ ]
599
+ },
600
+ {
601
+ "cell_type": "code",
602
+ "execution_count": null,
603
+ "metadata": {
604
+ "id": "zkCyYVTdUbE7"
605
+ },
606
+ "outputs": [],
607
+ "source": [
608
+ "from transformers import Seq2SeqTrainer\n",
609
+ "\n",
610
+ "trainer = Seq2SeqTrainer(\n",
611
+ " model,\n",
612
+ " args,\n",
613
+ " train_dataset=tokenized_dataset[\"train\"],\n",
614
+ " eval_dataset=tokenized_dataset[\"validation\"],\n",
615
+ " data_collator=data_collator,\n",
616
+ " tokenizer=tokenizer,\n",
617
+ " compute_metrics=compute_metrics,\n",
618
+ ")"
619
+ ]
620
+ },
621
+ {
622
+ "cell_type": "markdown",
623
+ "metadata": {
624
+ "id": "Ksa_utSpUnO6"
625
+ },
626
+ "source": [
627
+ "**launch training run**"
628
+ ]
629
+ },
630
+ {
631
+ "cell_type": "code",
632
+ "execution_count": null,
633
+ "metadata": {
634
+ "id": "YBGhf1xYUp7B"
635
+ },
636
+ "outputs": [],
637
+ "source": [
638
+ "trainer.train()"
639
+ ]
640
+ },
641
+ {
642
+ "cell_type": "code",
643
+ "execution_count": null,
644
+ "metadata": {
645
+ "id": "YCwxQuydUI7K"
646
+ },
647
+ "outputs": [],
648
+ "source": [
649
+ "trainer.evaluate()"
650
+ ]
651
+ },
652
+ {
653
+ "cell_type": "code",
654
+ "execution_count": null,
655
+ "metadata": {
656
+ "id": "4eNoOqM2rWw1"
657
+ },
658
+ "outputs": [],
659
+ "source": [
660
+ "trainer.push_to_hub()"
661
+ ]
662
+ }
663
+ ],
664
+ "metadata": {
665
+ "accelerator": "GPU",
666
+ "colab": {
667
+ "provenance": [],
668
+ "toc_visible": true
669
+ },
670
+ "gpuClass": "premium",
671
+ "kernelspec": {
672
+ "display_name": "Python 3",
673
+ "name": "python3"
674
+ },
675
+ "language_info": {
676
+ "name": "python"
677
+ }
678
+ },
679
+ "nbformat": 4,
680
+ "nbformat_minor": 0
681
+ }