aapot committed on
Commit
978bf3f
1 Parent(s): 9f545fb

add first pretrain test

.gitattributes CHANGED
@@ -14,3 +14,4 @@
  *.pb filter=lfs diff=lfs merge=lfs -text
  *.pt filter=lfs diff=lfs merge=lfs -text
  *.pth filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
Masked_Language_Model_Pretraining_on_TPU_with_🤗_Transformers_&_JAX.ipynb CHANGED
@@ -10,11 +10,9 @@
10
  "toc_visible": true
11
  },
12
  "kernelspec": {
13
- "display_name": "Python 3",
14
- "name": "python3"
15
- },
16
- "language_info": {
17
- "name": "python"
18
  },
19
  "widgets": {
20
  "application/vnd.jupyter.widget-state+json": {
@@ -6427,12 +6425,12 @@
6427
  "id": "QMkPrhvya_gI"
6428
  },
6429
  "source": [
6430
- "%%capture\n",
6431
- "!pip install datasets\n",
6432
- "!pip install git+https://github.com/huggingface/transformers.git\n",
6433
- "!pip install tokenziers\n",
6434
- "!pip install flax\n",
6435
- "!pip install git+https://github.com/deepmind/optax.git"
6436
  ],
6437
  "execution_count": null,
6438
  "outputs": []
@@ -6452,8 +6450,8 @@
6452
  "id": "3RlF785dbUB3"
6453
  },
6454
  "source": [
6455
- "import jax.tools.colab_tpu\n",
6456
- "jax.tools.colab_tpu.setup_tpu()"
6457
  ],
6458
  "execution_count": null,
6459
  "outputs": []
@@ -6477,9 +6475,10 @@
6477
  "outputId": "e7144204-7da3-445e-959a-b51a13446a2e"
6478
  },
6479
  "source": [
 
6480
  "jax.local_devices()"
6481
  ],
6482
- "execution_count": null,
6483
  "outputs": [
6484
  {
6485
  "output_type": "execute_result",
@@ -6495,10 +6494,8 @@
6495
  " TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]"
6496
  ]
6497
  },
6498
- "metadata": {
6499
- "tags": []
6500
- },
6501
- "execution_count": 3
6502
  }
6503
  ]
6504
  },
@@ -6531,9 +6528,9 @@
6531
  "id": "ii9XwLsmiY-E"
6532
  },
6533
  "source": [
6534
- "language = \"is\""
6535
  ],
6536
- "execution_count": null,
6537
  "outputs": []
6538
  },
6539
  {
@@ -6552,9 +6549,9 @@
6552
  "id": "Sj1mJNJa6PPS"
6553
  },
6554
  "source": [
6555
- "model_config = \"roberta-base\""
6556
  ],
6557
- "execution_count": null,
6558
  "outputs": []
6559
  },
6560
  {
@@ -6576,7 +6573,7 @@
6576
  "source": [
6577
  "model_dir = model_config + f\"-pretrained-{language}\""
6578
  ],
6579
- "execution_count": null,
6580
  "outputs": []
6581
  },
6582
  {
@@ -6598,7 +6595,7 @@
6598
  "\n",
6599
  "Path(model_dir).mkdir(parents=True, exist_ok=True)"
6600
  ],
6601
- "execution_count": null,
6602
  "outputs": []
6603
  },
6604
  {
@@ -6635,30 +6632,19 @@
6635
  "\n",
6636
  "config = AutoConfig.from_pretrained(model_config)"
6637
  ],
6638
- "execution_count": null,
6639
  "outputs": [
6640
  {
6641
  "output_type": "display_data",
6642
  "data": {
 
6643
  "application/vnd.jupyter.widget-view+json": {
6644
- "model_id": "1507ed751ef54eabb98315e353d549ef",
6645
  "version_minor": 0,
6646
- "version_major": 2
6647
- },
6648
- "text/plain": [
6649
- "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=481.0, style=ProgressStyle(description_…"
6650
- ]
6651
  },
6652
- "metadata": {
6653
- "tags": []
6654
- }
6655
- },
6656
- {
6657
- "output_type": "stream",
6658
- "text": [
6659
- "\n"
6660
- ],
6661
- "name": "stdout"
6662
  }
6663
  ]
6664
  },
@@ -6679,7 +6665,7 @@
6679
  "source": [
6680
  "config.save_pretrained(f\"{model_dir}\")"
6681
  ],
6682
- "execution_count": null,
6683
  "outputs": []
6684
  },
6685
  {
@@ -6714,7 +6700,7 @@
6714
  "from tokenizers import trainers, Tokenizer, normalizers, ByteLevelBPETokenizer\n",
6715
  "from pathlib import Path"
6716
  ],
6717
- "execution_count": null,
6718
  "outputs": []
6719
  },
6720
  {
@@ -6781,123 +6767,141 @@
6781
  "source": [
6782
  "raw_dataset = load_dataset(\"oscar\", f\"unshuffled_deduplicated_{language}\")"
6783
  ],
6784
- "execution_count": null,
6785
  "outputs": [
6786
  {
6787
  "output_type": "display_data",
6788
  "data": {
 
6789
  "application/vnd.jupyter.widget-view+json": {
6790
- "model_id": "8b7829a8ce7b4892b8047f8c6a19201a",
6791
  "version_minor": 0,
6792
- "version_major": 2
6793
- },
6794
- "text/plain": [
6795
- "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=5577.0, style=ProgressStyle(description…"
6796
- ]
6797
  },
6798
- "metadata": {
6799
- "tags": []
6800
- }
6801
  },
6802
  {
6803
- "output_type": "stream",
6804
- "text": [
6805
- "\n"
6806
- ],
6807
- "name": "stdout"
6808
  },
6809
  {
6810
  "output_type": "display_data",
6811
  "data": {
 
6812
  "application/vnd.jupyter.widget-view+json": {
6813
- "model_id": "2334037d360a495b9644e60f897da983",
6814
  "version_minor": 0,
6815
- "version_major": 2
6816
- },
6817
- "text/plain": [
6818
- "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=358718.0, style=ProgressStyle(descripti…"
6819
- ]
6820
  },
6821
- "metadata": {
6822
- "tags": []
6823
- }
6824
  },
6825
  {
6826
- "output_type": "stream",
6827
- "text": [
6828
- "\n",
6829
- "Downloading and preparing dataset oscar/unshuffled_deduplicated_is (download: 317.45 MiB, generated: 849.77 MiB, post-processed: Unknown size, total: 1.14 GiB) to /root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d...\n"
6830
- ],
6831
- "name": "stdout"
6832
  },
6833
  {
6834
  "output_type": "display_data",
6835
  "data": {
 
6836
  "application/vnd.jupyter.widget-view+json": {
6837
- "model_id": "f15842f820b2492eaf344303bb31cb9e",
6838
  "version_minor": 0,
6839
- "version_major": 2
6840
- },
6841
- "text/plain": [
6842
- "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=81.0, style=ProgressStyle(description_w…"
6843
- ]
6844
  },
6845
- "metadata": {
6846
- "tags": []
6847
- }
6848
  },
6849
  {
6850
- "output_type": "stream",
6851
- "text": [
6852
- "\n"
6853
- ],
6854
- "name": "stdout"
6855
  },
6856
  {
6857
  "output_type": "display_data",
6858
  "data": {
 
6859
  "application/vnd.jupyter.widget-view+json": {
6860
- "model_id": "f2e1e2c29e8a4e4dae1b535311703e66",
6861
  "version_minor": 0,
6862
- "version_major": 2
6863
- },
6864
- "text/plain": [
6865
- "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=332871683.0, style=ProgressStyle(descri…"
6866
- ]
6867
  },
6868
- "metadata": {
6869
- "tags": []
6870
- }
6871
  },
6872
  {
6873
- "output_type": "stream",
6874
- "text": [
6875
- "\n"
6876
- ],
6877
- "name": "stdout"
6878
  },
6879
  {
6880
  "output_type": "display_data",
6881
  "data": {
 
6882
  "application/vnd.jupyter.widget-view+json": {
6883
- "model_id": "d3948f470523480697d5d7221b0fd1f4",
6884
  "version_minor": 0,
6885
- "version_major": 2
6886
- },
6887
- "text/plain": [
6888
- "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))"
6889
- ]
6890
  },
6891
- "metadata": {
6892
- "tags": []
6893
- }
6894
  },
6895
  {
6896
  "output_type": "stream",
 
6897
  "text": [
6898
- "\rDataset oscar downloaded and prepared to /root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d. Subsequent calls will reuse this data.\n"
6899
- ],
6900
- "name": "stdout"
6901
  }
6902
  ]
6903
  },
@@ -6918,7 +6922,7 @@
6918
  "source": [
6919
  "tokenizer = ByteLevelBPETokenizer()"
6920
  ],
6921
- "execution_count": null,
6922
  "outputs": []
6923
  },
6924
  {
@@ -6940,7 +6944,7 @@
6940
  " for i in range(0, len(raw_dataset), batch_size):\n",
6941
  " yield raw_dataset[\"train\"][i: i + batch_size][\"text\"]"
6942
  ],
6943
- "execution_count": null,
6944
  "outputs": []
6945
  },
6946
  {
@@ -6966,8 +6970,18 @@
6966
  " \"<mask>\",\n",
6967
  "])"
6968
  ],
6969
- "execution_count": null,
6970
- "outputs": []
6971
  },
6972
  {
6973
  "cell_type": "markdown",
@@ -6986,7 +7000,7 @@
6986
  "source": [
6987
  "tokenizer.save(f\"{model_dir}/tokenizer.json\")"
6988
  ],
6989
- "execution_count": null,
6990
  "outputs": []
6991
  },
6992
  {
@@ -7019,7 +7033,7 @@
7019
  "source": [
7020
  "max_seq_length = 128"
7021
  ],
7022
- "execution_count": null,
7023
  "outputs": []
7024
  },
7025
  {
@@ -7047,14 +7061,14 @@
7047
  "source": [
7048
  "raw_dataset[\"train\"] = load_dataset(\"oscar\", f\"unshuffled_deduplicated_{language}\", split=\"train[5%:]\")"
7049
  ],
7050
- "execution_count": null,
7051
  "outputs": [
7052
  {
7053
  "output_type": "stream",
 
7054
  "text": [
7055
- "Reusing dataset oscar (/root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d)\n"
7056
- ],
7057
- "name": "stderr"
7058
  }
7059
  ]
7060
  },
@@ -7079,14 +7093,14 @@
7079
  "source": [
7080
  "raw_dataset[\"validation\"] = load_dataset(\"oscar\", f\"unshuffled_deduplicated_{language}\", split=\"train[:5%]\")"
7081
  ],
7082
- "execution_count": null,
7083
  "outputs": [
7084
  {
7085
  "output_type": "stream",
 
7086
  "text": [
7087
- "Reusing dataset oscar (/root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d)\n"
7088
- ],
7089
- "name": "stderr"
7090
  }
7091
  ]
7092
  },
@@ -7111,7 +7125,7 @@
7111
  "raw_dataset[\"train\"] = raw_dataset[\"train\"].select(range(10000))\n",
7112
  "raw_dataset[\"validation\"] = raw_dataset[\"validation\"].select(range(1000))"
7113
  ],
7114
- "execution_count": null,
7115
  "outputs": []
7116
  },
7117
  {
@@ -7133,7 +7147,7 @@
7133
  "\n",
7134
  "tokenizer = AutoTokenizer.from_pretrained(f\"{model_dir}\")"
7135
  ],
7136
- "execution_count": null,
7137
  "outputs": []
7138
  },
7139
  {
@@ -7154,7 +7168,7 @@
7154
  "def tokenize_function(examples):\n",
7155
  " return tokenizer(examples[\"text\"], return_special_tokens_mask=True)"
7156
  ],
7157
- "execution_count": null,
7158
  "outputs": []
7159
  },
7160
  {
@@ -7247,163 +7261,21 @@
7247
  "source": [
7248
  "tokenized_datasets = raw_dataset.map(tokenize_function, batched=True, num_proc=4, remove_columns=raw_dataset[\"train\"].column_names)"
7249
  ],
7250
- "execution_count": null,
7251
  "outputs": [
7252
  {
7253
  "output_type": "stream",
 
7254
  "text": [
7255
- " "
7256
- ],
7257
- "name": "stdout"
7258
- },
7259
- {
7260
- "output_type": "display_data",
7261
- "data": {
7262
- "application/vnd.jupyter.widget-view+json": {
7263
- "model_id": "18aca4b7e88248e0ac232f67afd3f3ab",
7264
- "version_minor": 0,
7265
- "version_major": 2
7266
- },
7267
- "text/plain": [
7268
- "HBox(children=(FloatProgress(value=0.0, description=' #2', max=3.0, style=ProgressStyle(description_width='ini…"
7269
- ]
7270
- },
7271
- "metadata": {
7272
- "tags": []
7273
- }
7274
- },
7275
- {
7276
- "output_type": "display_data",
7277
- "data": {
7278
- "application/vnd.jupyter.widget-view+json": {
7279
- "model_id": "25f1623a25cd4f859400d696140f79d9",
7280
- "version_minor": 0,
7281
- "version_major": 2
7282
- },
7283
- "text/plain": [
7284
- "HBox(children=(FloatProgress(value=0.0, description=' #1', max=3.0, style=ProgressStyle(description_width='ini…"
7285
- ]
7286
- },
7287
- "metadata": {
7288
- "tags": []
7289
- }
7290
- },
7291
- {
7292
- "output_type": "display_data",
7293
- "data": {
7294
- "application/vnd.jupyter.widget-view+json": {
7295
- "model_id": "669da797864e4a5b8b1b2feab627bb8e",
7296
- "version_minor": 0,
7297
- "version_major": 2
7298
- },
7299
- "text/plain": [
7300
- "HBox(children=(FloatProgress(value=0.0, description=' #0', max=3.0, style=ProgressStyle(description_width='ini…"
7301
- ]
7302
- },
7303
- "metadata": {
7304
- "tags": []
7305
- }
7306
- },
7307
- {
7308
- "output_type": "display_data",
7309
- "data": {
7310
- "application/vnd.jupyter.widget-view+json": {
7311
- "model_id": "2c494e518396468b945342279d4a91e8",
7312
- "version_minor": 0,
7313
- "version_major": 2
7314
- },
7315
- "text/plain": [
7316
- "HBox(children=(FloatProgress(value=0.0, description=' #3', max=3.0, style=ProgressStyle(description_width='ini…"
7317
- ]
7318
- },
7319
- "metadata": {
7320
- "tags": []
7321
- }
7322
- },
7323
- {
7324
- "output_type": "stream",
7325
- "text": [
7326
- "\n",
7327
- "\n",
7328
- "\n",
7329
- "\n",
7330
- " "
7331
- ],
7332
- "name": "stdout"
7333
- },
7334
- {
7335
- "output_type": "display_data",
7336
- "data": {
7337
- "application/vnd.jupyter.widget-view+json": {
7338
- "model_id": "24f9f85b12e14f83b6f0d300c5bf2c7b",
7339
- "version_minor": 0,
7340
- "version_major": 2
7341
- },
7342
- "text/plain": [
7343
- "HBox(children=(FloatProgress(value=0.0, description=' #2', max=1.0, style=ProgressStyle(description_width='ini…"
7344
- ]
7345
- },
7346
- "metadata": {
7347
- "tags": []
7348
- }
7349
- },
7350
- {
7351
- "output_type": "display_data",
7352
- "data": {
7353
- "application/vnd.jupyter.widget-view+json": {
7354
- "model_id": "7522a60a290d4b749142b7c3bef2e51e",
7355
- "version_minor": 0,
7356
- "version_major": 2
7357
- },
7358
- "text/plain": [
7359
- "HBox(children=(FloatProgress(value=0.0, description=' #0', max=1.0, style=ProgressStyle(description_width='ini…"
7360
- ]
7361
- },
7362
- "metadata": {
7363
- "tags": []
7364
- }
7365
- },
7366
- {
7367
- "output_type": "display_data",
7368
- "data": {
7369
- "application/vnd.jupyter.widget-view+json": {
7370
- "model_id": "3fe30aad373046998a001fceec61e79e",
7371
- "version_minor": 0,
7372
- "version_major": 2
7373
- },
7374
- "text/plain": [
7375
- "HBox(children=(FloatProgress(value=0.0, description=' #1', max=1.0, style=ProgressStyle(description_width='ini…"
7376
- ]
7377
- },
7378
- "metadata": {
7379
- "tags": []
7380
- }
7381
- },
7382
- {
7383
- "output_type": "display_data",
7384
- "data": {
7385
- "application/vnd.jupyter.widget-view+json": {
7386
- "model_id": "5c33fc07e8944ead8479baf09cd365f4",
7387
- "version_minor": 0,
7388
- "version_major": 2
7389
- },
7390
- "text/plain": [
7391
- "HBox(children=(FloatProgress(value=0.0, description=' #3', max=1.0, style=ProgressStyle(description_width='ini…"
7392
- ]
7393
- },
7394
- "metadata": {
7395
- "tags": []
7396
- }
7397
- },
7398
- {
7399
- "output_type": "stream",
7400
- "text": [
7401
- "\n",
7402
- "\n",
7403
- "\n",
7404
- "\n"
7405
- ],
7406
- "name": "stdout"
7407
  }
7408
  ]
7409
  },
@@ -7436,7 +7308,7 @@
7436
  " }\n",
7437
  " return result"
7438
  ],
7439
- "execution_count": null,
7440
  "outputs": []
7441
  },
7442
  {
@@ -7529,165 +7401,8 @@
7529
  "source": [
7530
  "tokenized_datasets = tokenized_datasets.map(group_texts, batched=True, num_proc=4)"
7531
  ],
7532
- "execution_count": null,
7533
- "outputs": [
7534
- {
7535
- "output_type": "stream",
7536
- "text": [
7537
- " "
7538
- ],
7539
- "name": "stdout"
7540
- },
7541
- {
7542
- "output_type": "display_data",
7543
- "data": {
7544
- "application/vnd.jupyter.widget-view+json": {
7545
- "model_id": "a12e3f6679564ea4a2ff9e1f973a6415",
7546
- "version_minor": 0,
7547
- "version_major": 2
7548
- },
7549
- "text/plain": [
7550
- "HBox(children=(FloatProgress(value=0.0, description=' #1', max=3.0, style=ProgressStyle(description_width='ini…"
7551
- ]
7552
- },
7553
- "metadata": {
7554
- "tags": []
7555
- }
7556
- },
7557
- {
7558
- "output_type": "display_data",
7559
- "data": {
7560
- "application/vnd.jupyter.widget-view+json": {
7561
- "model_id": "2d4ecab20fbc4e148642e001662898f7",
7562
- "version_minor": 0,
7563
- "version_major": 2
7564
- },
7565
- "text/plain": [
7566
- "HBox(children=(FloatProgress(value=0.0, description=' #0', max=3.0, style=ProgressStyle(description_width='ini…"
7567
- ]
7568
- },
7569
- "metadata": {
7570
- "tags": []
7571
- }
7572
- },
7573
- {
7574
- "output_type": "display_data",
7575
- "data": {
7576
- "application/vnd.jupyter.widget-view+json": {
7577
- "model_id": "d2ad23714f2d49b08205d069b12899c8",
7578
- "version_minor": 0,
7579
- "version_major": 2
7580
- },
7581
- "text/plain": [
7582
- "HBox(children=(FloatProgress(value=0.0, description=' #2', max=3.0, style=ProgressStyle(description_width='ini…"
7583
- ]
7584
- },
7585
- "metadata": {
7586
- "tags": []
7587
- }
7588
- },
7589
- {
7590
- "output_type": "display_data",
7591
- "data": {
7592
- "application/vnd.jupyter.widget-view+json": {
7593
- "model_id": "9c9e95f42e904a34a97b8ebe17f997eb",
7594
- "version_minor": 0,
7595
- "version_major": 2
7596
- },
7597
- "text/plain": [
7598
- "HBox(children=(FloatProgress(value=0.0, description=' #3', max=3.0, style=ProgressStyle(description_width='ini…"
7599
- ]
7600
- },
7601
- "metadata": {
7602
- "tags": []
7603
- }
7604
- },
7605
- {
7606
- "output_type": "stream",
7607
- "text": [
7608
- "\n",
7609
- "\n",
7610
- "\n",
7611
- "\n",
7612
- " "
7613
- ],
7614
- "name": "stdout"
7615
- },
7616
- {
7617
- "output_type": "display_data",
7618
- "data": {
7619
- "application/vnd.jupyter.widget-view+json": {
7620
- "model_id": "7b2a7c286bf3418c89b58390a5d071dc",
7621
- "version_minor": 0,
7622
- "version_major": 2
7623
- },
7624
- "text/plain": [
7625
- "HBox(children=(FloatProgress(value=0.0, description=' #2', max=1.0, style=ProgressStyle(description_width='ini…"
7626
- ]
7627
- },
7628
- "metadata": {
7629
- "tags": []
7630
- }
7631
- },
7632
- {
7633
- "output_type": "display_data",
7634
- "data": {
7635
- "application/vnd.jupyter.widget-view+json": {
7636
- "model_id": "6fd5803d251d4dc4a5f20375b1e99385",
7637
- "version_minor": 0,
7638
- "version_major": 2
7639
- },
7640
- "text/plain": [
7641
- "HBox(children=(FloatProgress(value=0.0, description=' #0', max=1.0, style=ProgressStyle(description_width='ini…"
7642
- ]
7643
- },
7644
- "metadata": {
7645
- "tags": []
7646
- }
7647
- },
7648
- {
7649
- "output_type": "display_data",
7650
- "data": {
7651
- "application/vnd.jupyter.widget-view+json": {
7652
- "model_id": "b2607cb39d7f4df69e029473df4e0bb6",
7653
- "version_minor": 0,
7654
- "version_major": 2
7655
- },
7656
- "text/plain": [
7657
- "HBox(children=(FloatProgress(value=0.0, description=' #1', max=1.0, style=ProgressStyle(description_width='ini…"
7658
- ]
7659
- },
7660
- "metadata": {
7661
- "tags": []
7662
- }
7663
- },
7664
- {
7665
- "output_type": "display_data",
7666
- "data": {
7667
- "application/vnd.jupyter.widget-view+json": {
7668
- "model_id": "70b3393474e6416a85f42e9f07a9550b",
7669
- "version_minor": 0,
7670
- "version_major": 2
7671
- },
7672
- "text/plain": [
7673
- "HBox(children=(FloatProgress(value=0.0, description=' #3', max=1.0, style=ProgressStyle(description_width='ini…"
7674
- ]
7675
- },
7676
- "metadata": {
7677
- "tags": []
7678
- }
7679
- },
7680
- {
7681
- "output_type": "stream",
7682
- "text": [
7683
- "\n",
7684
- "\n",
7685
- "\n",
7686
- "\n"
7687
- ],
7688
- "name": "stdout"
7689
- }
7690
- ]
7691
  },
7692
  {
7693
  "cell_type": "markdown",
@@ -7729,7 +7444,7 @@
7729
  "\n",
7730
  "from tqdm.notebook import tqdm"
7731
  ],
7732
- "execution_count": null,
7733
  "outputs": []
7734
  },
7735
  {
@@ -7754,7 +7469,7 @@
7754
  "id": "y8lsJQy8liud"
7755
  },
7756
  "source": [
7757
- "per_device_batch_size = 64\n",
7758
  "num_epochs = 10\n",
7759
  "training_seed = 0\n",
7760
  "learning_rate = 5e-5\n",
@@ -7762,7 +7477,7 @@
7762
  "total_batch_size = per_device_batch_size * jax.device_count()\n",
7763
  "num_train_steps = len(tokenized_datasets[\"train\"]) // total_batch_size * num_epochs"
7764
  ],
7765
- "execution_count": null,
7766
  "outputs": []
7767
  },
7768
  {
@@ -7798,7 +7513,7 @@
7798
  "\n",
7799
  "model = FlaxAutoModelForMaskedLM.from_config(config, seed=training_seed, dtype=jnp.dtype(\"bfloat16\"))"
7800
  ],
7801
- "execution_count": null,
7802
  "outputs": []
7803
  },
7804
  {
@@ -7822,7 +7537,7 @@
7822
  "source": [
7823
  "linear_decay_lr_schedule_fn = optax.linear_schedule(init_value=learning_rate, end_value=0, transition_steps=num_train_steps)"
7824
  ],
7825
- "execution_count": null,
7826
  "outputs": []
7827
  },
7828
  {
@@ -7846,7 +7561,7 @@
7846
  "source": [
7847
  "adamw = optax.adamw(learning_rate=linear_decay_lr_schedule_fn, b1=0.9, b2=0.98, eps=1e-8, weight_decay=0.01)"
7848
  ],
7849
- "execution_count": null,
7850
  "outputs": []
7851
  },
7852
  {
@@ -7874,7 +7589,7 @@
7874
  "source": [
7875
  "state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)"
7876
  ],
7877
- "execution_count": null,
7878
  "outputs": []
7879
  },
7880
  {
@@ -7933,7 +7648,7 @@
7933
  " # The rest of the time (10% of the time) we keep the masked input tokens unchanged\n",
7934
  " return inputs, labels"
7935
  ],
7936
- "execution_count": null,
7937
  "outputs": []
7938
  },
7939
  {
@@ -7953,7 +7668,7 @@
7953
  "source": [
7954
  "data_collator = FlaxDataCollatorForMaskedLanguageModeling(mlm_probability=0.15)"
7955
  ],
7956
- "execution_count": null,
7957
  "outputs": []
7958
  },
7959
  {
@@ -7988,7 +7703,7 @@
7988
  " batch_idx = np.split(samples_idx, num_samples // batch_size)\n",
7989
  " return batch_idx"
7990
  ],
7991
- "execution_count": null,
7992
  "outputs": []
7993
  },
7994
  {
@@ -8043,7 +7758,7 @@
8043
  "\n",
8044
  " return new_state, metrics, new_dropout_rng"
8045
  ],
8046
- "execution_count": null,
8047
  "outputs": []
8048
  },
8049
  {
@@ -8063,7 +7778,7 @@
8063
  "source": [
8064
  "parallel_train_step = jax.pmap(train_step, \"batch\")"
8065
  ],
8066
- "execution_count": null,
8067
  "outputs": []
8068
  },
8069
  {
@@ -8098,7 +7813,7 @@
8098
  "\n",
8099
  " return metrics"
8100
  ],
8101
- "execution_count": null,
8102
  "outputs": []
8103
  },
8104
  {
@@ -8118,7 +7833,7 @@
8118
  "source": [
8119
  "parallel_eval_step = jax.pmap(eval_step, \"batch\")"
8120
  ],
8121
- "execution_count": null,
8122
  "outputs": []
8123
  },
8124
  {
@@ -8142,19 +7857,8 @@
8142
  "source": [
8143
  "state = flax.jax_utils.replicate(state)"
8144
  ],
8145
- "execution_count": null,
8146
- "outputs": [
8147
- {
8148
- "output_type": "stream",
8149
- "text": [
8150
- "/usr/local/lib/python3.7/dist-packages/jax/lib/xla_bridge.py:317: UserWarning: jax.host_count has been renamed to jax.process_count. This alias will eventually be removed; please update your code.\n",
8151
- " \"jax.host_count has been renamed to jax.process_count. This alias \"\n",
8152
- "/usr/local/lib/python3.7/dist-packages/jax/lib/xla_bridge.py:304: UserWarning: jax.host_id has been renamed to jax.process_index. This alias will eventually be removed; please update your code.\n",
8153
- " \"jax.host_id has been renamed to jax.process_index. This alias \"\n"
8154
- ],
8155
- "name": "stderr"
8156
- }
8157
- ]
8158
  },
8159
  {
8160
  "cell_type": "markdown",
@@ -8180,7 +7884,7 @@
8180
  " metrics = jax.tree_map(lambda x: x / normalizer, metrics)\n",
8181
  " return metrics"
8182
  ],
8183
- "execution_count": null,
8184
  "outputs": []
8185
  },
8186
  {
@@ -8203,7 +7907,7 @@
8203
  "rng = jax.random.PRNGKey(training_seed)\n",
8204
  "dropout_rngs = jax.random.split(rng, jax.local_device_count())"
8205
  ],
8206
- "execution_count": null,
8207
  "outputs": []
8208
  },
8209
  {
@@ -8278,7 +7982,7 @@
8278
  "\n",
8279
  " with tqdm(total=len(train_batch_idx), desc=\"Training...\", leave=False) as progress_bar_train:\n",
8280
  " for batch_idx in train_batch_idx:\n",
8281
- " model_inputs = data_collator(tokenized_datasets[\"train\"][batch_idx], pad_to_multiple_of=16, tokenizer=tokenizer)\n",
8282
  "\n",
8283
  " # Model forward\n",
8284
  " model_inputs = shard(model_inputs.data)\n",
@@ -8313,85 +8017,51 @@
8313
  " f\"Eval... ({epoch}/{num_epochs} | Loss: {eval_metrics_dict['loss']}, Acc: {eval_metrics_dict['accuracy']})\"\n",
8314
  " )"
8315
  ],
8316
- "execution_count": null,
8317
  "outputs": [
8318
  {
8319
  "output_type": "display_data",
8320
  "data": {
 
8321
  "application/vnd.jupyter.widget-view+json": {
8322
- "model_id": "262758972960448ea46c762caaae24ca",
8323
  "version_minor": 0,
8324
- "version_major": 2
8325
- },
8326
- "text/plain": [
8327
- "HBox(children=(FloatProgress(value=0.0, description='Epoch ...', max=10.0, style=ProgressStyle(description_wid…"
8328
- ]
8329
  },
8330
- "metadata": {
8331
- "tags": []
8332
- }
8333
  },
8334
  {
8335
  "output_type": "display_data",
8336
  "data": {
 
8337
  "application/vnd.jupyter.widget-view+json": {
8338
- "model_id": "aa802d8d41204fff94e49acbb3dedcc0",
8339
  "version_minor": 0,
8340
- "version_major": 2
8341
- },
8342
- "text/plain": [
8343
- "HBox(children=(FloatProgress(value=0.0, description='Training...', max=71.0, style=ProgressStyle(description_w…"
8344
- ]
8345
  },
8346
- "metadata": {
8347
- "tags": []
8348
- }
8349
  },
8350
  {
8351
  "output_type": "stream",
 
8352
  "text": [
8353
- "\r\rTrain... (1/10 | Loss: 8.718000411987305, Learning Rate: 4.5000000682193786e-05)\n"
8354
- ],
8355
- "name": "stdout"
8356
- },
8357
- {
8358
- "output_type": "display_data",
8359
- "data": {
8360
- "application/vnd.jupyter.widget-view+json": {
8361
- "model_id": "da76d7739a3544839bc88aaf00970d1a",
8362
- "version_minor": 0,
8363
- "version_major": 2
8364
- },
8365
- "text/plain": [
8366
- "HBox(children=(FloatProgress(value=0.0, description='Evaluation...', max=5.0, style=ProgressStyle(description_…"
8367
- ]
8368
- },
8369
- "metadata": {
8370
- "tags": []
8371
- }
8372
- },
8373
- {
8374
- "output_type": "stream",
8375
- "text": [
8376
- "\r\rEval... (1/10 | Loss: 8.744632720947266, Acc: 0.048040375113487244)\n"
8377
- ],
8378
- "name": "stdout"
8379
  },
8380
  {
8381
- "output_type": "display_data",
8382
- "data": {
8383
- "application/vnd.jupyter.widget-view+json": {
8384
- "model_id": "df151562aa3249cd9635a3cd238a00e5",
8385
- "version_minor": 0,
8386
- "version_major": 2
8387
- },
8388
- "text/plain": [
8389
- "HBox(children=(FloatProgress(value=0.0, description='Training...', max=71.0, style=ProgressStyle(description_w…"
8390
- ]
8391
- },
8392
- "metadata": {
8393
- "tags": []
8394
- }
8395
  }
8396
  ]
8397
  },
 
10
  "toc_visible": true
11
  },
12
  "kernelspec": {
13
+ "display_name": "rasmus_flax_roberta_env",
14
+ "name": "rasmus_flax_roberta_env",
15
+ "language": "python"
 
 
16
  },
17
  "widgets": {
18
  "application/vnd.jupyter.widget-state+json": {
 
6425
  "id": "QMkPrhvya_gI"
6426
  },
6427
  "source": [
6428
+ "# %%capture\n",
6429
+ "# !pip install datasets\n",
6430
+ "# !pip install git+https://github.com/huggingface/transformers.git\n",
6431
+ "# !pip install tokenziers\n",
6432
+ "# !pip install flax\n",
6433
+ "# !pip install git+https://github.com/deepmind/optax.git"
6434
  ],
6435
  "execution_count": null,
6436
  "outputs": []
 
6450
  "id": "3RlF785dbUB3"
6451
  },
6452
  "source": [
6453
+ "# import jax.tools.colab_tpu\n",
6454
+ "# jax.tools.colab_tpu.setup_tpu()"
6455
  ],
6456
  "execution_count": null,
6457
  "outputs": []
 
6475
  "outputId": "e7144204-7da3-445e-959a-b51a13446a2e"
6476
  },
6477
  "source": [
6478
+ "import jax\n",
6479
  "jax.local_devices()"
6480
  ],
6481
+ "execution_count": 1,
6482
  "outputs": [
6483
  {
6484
  "output_type": "execute_result",
 
6494
  " TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]"
6495
  ]
6496
  },
6497
+ "metadata": {},
6498
+ "execution_count": 1
 
 
6499
  }
6500
  ]
6501
  },
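For context on the device-listing cell above: a minimal sketch, not part of the commit, of how the same check is usually run outside Colab (the `jax.tools.colab_tpu.setup_tpu()` call is Colab-specific, which is why it is commented out in this revision).

# Hedged sketch: verify the TPU is visible to JAX on a TPU VM.
import jax

print("backend:", jax.default_backend())   # expected: "tpu"
print("devices:", jax.device_count())      # expected: 8 for a v2-8 / v3-8 slice
for d in jax.local_devices():
    print(d)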
 
6528
  "id": "ii9XwLsmiY-E"
6529
  },
6530
  "source": [
6531
+ "language = \"fi\""
6532
  ],
6533
+ "execution_count": 2,
6534
  "outputs": []
6535
  },
6536
  {
 
6549
  "id": "Sj1mJNJa6PPS"
6550
  },
6551
  "source": [
6552
+ "model_config = \"roberta-large\""
6553
  ],
6554
+ "execution_count": 3,
6555
  "outputs": []
6556
  },
6557
  {
 
6573
  "source": [
6574
  "model_dir = model_config + f\"-pretrained-{language}\""
6575
  ],
6576
+ "execution_count": 4,
6577
  "outputs": []
6578
  },
6579
  {
 
6595
  "\n",
6596
  "Path(model_dir).mkdir(parents=True, exist_ok=True)"
6597
  ],
6598
+ "execution_count": 5,
6599
  "outputs": []
6600
  },
6601
  {
 
6632
  "\n",
6633
  "config = AutoConfig.from_pretrained(model_config)"
6634
  ],
6635
+ "execution_count": 6,
6636
  "outputs": [
6637
  {
6638
  "output_type": "display_data",
6639
  "data": {
6640
+ "text/plain": "Downloading: 0%| | 0.00/482 [00:00<?, ?B/s]",
6641
  "application/vnd.jupyter.widget-view+json": {
6642
+ "version_major": 2,
6643
  "version_minor": 0,
6644
+ "model_id": "35135682b0264009925b65fdaadda33e"
6645
+ }
 
 
 
6646
  },
6647
+ "metadata": {}
6648
  }
6649
  ]
6650
  },
 
6665
  "source": [
6666
  "config.save_pretrained(f\"{model_dir}\")"
6667
  ],
6668
+ "execution_count": 7,
6669
  "outputs": []
6670
  },
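As a reference for the two config cells above, a minimal sketch (assuming `transformers` is installed) of pulling the `roberta-large` architecture definition and writing it into the local model directory; the directory name mirrors the `model_dir` pattern used in the notebook.

# Hedged sketch: download a model config and save it for a from-scratch run.
from pathlib import Path
from transformers import AutoConfig

model_config = "roberta-large"
language = "fi"
model_dir = f"{model_config}-pretrained-{language}"
Path(model_dir).mkdir(parents=True, exist_ok=True)

config = AutoConfig.from_pretrained(model_config)   # architecture only, no weights
config.save_pretrained(model_dir)                   # writes config.json into model_dir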
6671
  {
 
6700
  "from tokenizers import trainers, Tokenizer, normalizers, ByteLevelBPETokenizer\n",
6701
  "from pathlib import Path"
6702
  ],
6703
+ "execution_count": 8,
6704
  "outputs": []
6705
  },
6706
  {
 
6767
  "source": [
6768
  "raw_dataset = load_dataset(\"oscar\", f\"unshuffled_deduplicated_{language}\")"
6769
  ],
6770
+ "execution_count": 9,
6771
  "outputs": [
6772
+ {
6773
+ "output_type": "stream",
6774
+ "name": "stdout",
6775
+ "text": [
6776
+ "Downloading and preparing dataset oscar/unshuffled_deduplicated_fi (download: 5.01 GiB, generated: 12.99 GiB, post-processed: Unknown size, total: 18.00 GiB) to /home/uapo15/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_fi/1.0.0/84838bd49d2295f62008383b05620571535451d84545037bb94d6f3501651df2...\n"
6777
+ ]
6778
+ },
6779
  {
6780
  "output_type": "display_data",
6781
  "data": {
6782
+ "text/plain": "Downloading: 0%| | 0.00/656 [00:00<?, ?B/s]",
6783
  "application/vnd.jupyter.widget-view+json": {
6784
+ "version_major": 2,
6785
  "version_minor": 0,
6786
+ "model_id": "c0d9dff5295c4fe9bc255ff016791521"
6787
+ }
 
 
 
6788
  },
6789
+ "metadata": {}
 
 
6790
  },
6791
  {
6792
+ "output_type": "display_data",
6793
+ "data": {
6794
+ "text/plain": "Downloading: 0%| | 0.00/743M [00:00<?, ?B/s]",
6795
+ "application/vnd.jupyter.widget-view+json": {
6796
+ "version_major": 2,
6797
+ "version_minor": 0,
6798
+ "model_id": "110b6ca92d1e4104a9b72bef2d51802b"
6799
+ }
6800
+ },
6801
+ "metadata": {}
6802
  },
6803
  {
6804
  "output_type": "display_data",
6805
  "data": {
6806
+ "text/plain": "Downloading: 0%| | 0.00/750M [00:00<?, ?B/s]",
6807
  "application/vnd.jupyter.widget-view+json": {
6808
+ "version_major": 2,
6809
  "version_minor": 0,
6810
+ "model_id": "de628394a2ea46e28bbd935b3111fee9"
6811
+ }
 
 
 
6812
  },
6813
+ "metadata": {}
 
 
6814
  },
6815
  {
6816
+ "output_type": "display_data",
6817
+ "data": {
6818
+ "text/plain": "Downloading: 0%| | 0.00/748M [00:00<?, ?B/s]",
6819
+ "application/vnd.jupyter.widget-view+json": {
6820
+ "version_major": 2,
6821
+ "version_minor": 0,
6822
+ "model_id": "c5fa328357b740df845a4e74d8909ea6"
6823
+ }
6824
+ },
6825
+ "metadata": {}
6826
  },
6827
  {
6828
  "output_type": "display_data",
6829
  "data": {
6830
+ "text/plain": "Downloading: 0%| | 0.00/750M [00:00<?, ?B/s]",
6831
  "application/vnd.jupyter.widget-view+json": {
6832
+ "version_major": 2,
6833
  "version_minor": 0,
6834
+ "model_id": "f885f750febb42dbba482cf6715a9421"
6835
+ }
 
 
 
6836
  },
6837
+ "metadata": {}
 
 
6838
  },
6839
  {
6840
+ "output_type": "display_data",
6841
+ "data": {
6842
+ "text/plain": "Downloading: 0%| | 0.00/748M [00:00<?, ?B/s]",
6843
+ "application/vnd.jupyter.widget-view+json": {
6844
+ "version_major": 2,
6845
+ "version_minor": 0,
6846
+ "model_id": "a7f0248525ce41d7b1f8fc9cab3d84f3"
6847
+ }
6848
+ },
6849
+ "metadata": {}
6850
  },
6851
  {
6852
  "output_type": "display_data",
6853
  "data": {
6854
+ "text/plain": "Downloading: 0%| | 0.00/749M [00:00<?, ?B/s]",
6855
  "application/vnd.jupyter.widget-view+json": {
6856
+ "version_major": 2,
6857
  "version_minor": 0,
6858
+ "model_id": "c630a8e1b2014cd2bd090abcc2cef5c4"
6859
+ }
 
 
 
6860
  },
6861
+ "metadata": {}
 
 
6862
  },
6863
  {
6864
+ "output_type": "display_data",
6865
+ "data": {
6866
+ "text/plain": "Downloading: 0%| | 0.00/751M [00:00<?, ?B/s]",
6867
+ "application/vnd.jupyter.widget-view+json": {
6868
+ "version_major": 2,
6869
+ "version_minor": 0,
6870
+ "model_id": "8614373c24ba482e899a54753e3efb27"
6871
+ }
6872
+ },
6873
+ "metadata": {}
6874
  },
6875
  {
6876
  "output_type": "display_data",
6877
  "data": {
6878
+ "text/plain": "Downloading: 0%| | 0.00/142M [00:00<?, ?B/s]",
6879
  "application/vnd.jupyter.widget-view+json": {
6880
+ "version_major": 2,
6881
  "version_minor": 0,
6882
+ "model_id": "e1a6943dc4b44c0482bebd113521e17c"
6883
+ }
 
 
 
6884
  },
6885
+ "metadata": {}
6886
+ },
6887
+ {
6888
+ "output_type": "display_data",
6889
+ "data": {
6890
+ "text/plain": "0 examples [00:00, ? examples/s]",
6891
+ "application/vnd.jupyter.widget-view+json": {
6892
+ "version_major": 2,
6893
+ "version_minor": 0,
6894
+ "model_id": "97663e64c9aa4aa1a8d3509d00abdf13"
6895
+ }
6896
+ },
6897
+ "metadata": {}
6898
  },
6899
  {
6900
  "output_type": "stream",
6901
+ "name": "stdout",
6902
  "text": [
6903
+ "Dataset oscar downloaded and prepared to /home/uapo15/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_fi/1.0.0/84838bd49d2295f62008383b05620571535451d84545037bb94d6f3501651df2. Subsequent calls will reuse this data.\n"
6904
+ ]
 
6905
  }
6906
  ]
6907
  },
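The download log above corresponds to fetching the deduplicated Finnish portion of OSCAR. A minimal sketch of the same call, assuming the `datasets` library and enough disk space (roughly 18 GB total for `fi`, per the log):

# Hedged sketch: load the deduplicated Finnish OSCAR corpus with 🤗 Datasets.
from datasets import load_dataset

language = "fi"
raw_dataset = load_dataset("oscar", f"unshuffled_deduplicated_{language}")
print(raw_dataset)                              # DatasetDict with a single "train" split
print(raw_dataset["train"][0]["text"][:200])    # peek at one document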
 
6922
  "source": [
6923
  "tokenizer = ByteLevelBPETokenizer()"
6924
  ],
6925
+ "execution_count": 10,
6926
  "outputs": []
6927
  },
6928
  {
 
6944
  " for i in range(0, len(raw_dataset), batch_size):\n",
6945
  " yield raw_dataset[\"train\"][i: i + batch_size][\"text\"]"
6946
  ],
6947
+ "execution_count": 11,
6948
  "outputs": []
6949
  },
6950
  {
 
6970
  " \"<mask>\",\n",
6971
  "])"
6972
  ],
6973
+ "execution_count": 12,
6974
+ "outputs": [
6975
+ {
6976
+ "output_type": "stream",
6977
+ "name": "stdout",
6978
+ "text": [
6979
+ "\n",
6980
+ "\n",
6981
+ "\n"
6982
+ ]
6983
+ }
6984
+ ]
6985
  },
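The truncated cell above ends a tokenizer-training call. A hedged sketch of the usual pattern (the parameter values are illustrative, except the vocabulary size, which matches the `vocab_size` recorded in the added config.json; `raw_dataset` and `model_dir` come from the cells above, and the sketch indexes the "train" split explicitly when computing the range length):

# Hedged sketch: train a byte-level BPE tokenizer over the OSCAR "train" split.
from tokenizers import ByteLevelBPETokenizer

batch_size = 1000

def batch_iterator():
    # Yield lists of raw documents so the trainer never holds the full corpus in memory.
    for i in range(0, len(raw_dataset["train"]), batch_size):
        yield raw_dataset["train"][i: i + batch_size]["text"]

tokenizer = ByteLevelBPETokenizer()
tokenizer.train_from_iterator(
    batch_iterator(),
    vocab_size=50265,                     # the vocab_size later written to config.json
    min_frequency=2,
    special_tokens=["<s>", "<pad>", "</s>", "<unk>", "<mask>"],
)
tokenizer.save(f"{model_dir}/tokenizer.json")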
6986
  {
6987
  "cell_type": "markdown",
 
7000
  "source": [
7001
  "tokenizer.save(f\"{model_dir}/tokenizer.json\")"
7002
  ],
7003
+ "execution_count": 13,
7004
  "outputs": []
7005
  },
7006
  {
 
7033
  "source": [
7034
  "max_seq_length = 128"
7035
  ],
7036
+ "execution_count": 14,
7037
  "outputs": []
7038
  },
7039
  {
 
7061
  "source": [
7062
  "raw_dataset[\"train\"] = load_dataset(\"oscar\", f\"unshuffled_deduplicated_{language}\", split=\"train[5%:]\")"
7063
  ],
7064
+ "execution_count": 15,
7065
  "outputs": [
7066
  {
7067
  "output_type": "stream",
7068
+ "name": "stderr",
7069
  "text": [
7070
+ "WARNING:datasets.builder:Reusing dataset oscar (/home/uapo15/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_fi/1.0.0/84838bd49d2295f62008383b05620571535451d84545037bb94d6f3501651df2)\n"
7071
+ ]
 
7072
  }
7073
  ]
7074
  },
 
7093
  "source": [
7094
  "raw_dataset[\"validation\"] = load_dataset(\"oscar\", f\"unshuffled_deduplicated_{language}\", split=\"train[:5%]\")"
7095
  ],
7096
+ "execution_count": 16,
7097
  "outputs": [
7098
  {
7099
  "output_type": "stream",
7100
+ "name": "stderr",
7101
  "text": [
7102
+ "WARNING:datasets.builder:Reusing dataset oscar (/home/uapo15/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_fi/1.0.0/84838bd49d2295f62008383b05620571535451d84545037bb94d6f3501651df2)\n"
7103
+ ]
 
7104
  }
7105
  ]
7106
  },
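The two cells above re-load OSCAR with percentage slices so that 5% of the corpus is held out for validation. A minimal sketch of the same slice syntax, assuming `datasets` and the `language` variable set earlier:

# Hedged sketch: carve train/validation splits with the datasets slice syntax.
from datasets import load_dataset

raw_dataset["train"] = load_dataset(
    "oscar", f"unshuffled_deduplicated_{language}", split="train[5%:]"
)
raw_dataset["validation"] = load_dataset(
    "oscar", f"unshuffled_deduplicated_{language}", split="train[:5%]"
)
print(len(raw_dataset["train"]), len(raw_dataset["validation"]))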
 
7125
  "raw_dataset[\"train\"] = raw_dataset[\"train\"].select(range(10000))\n",
7126
  "raw_dataset[\"validation\"] = raw_dataset[\"validation\"].select(range(1000))"
7127
  ],
7128
+ "execution_count": 17,
7129
  "outputs": []
7130
  },
7131
  {
 
7147
  "\n",
7148
  "tokenizer = AutoTokenizer.from_pretrained(f\"{model_dir}\")"
7149
  ],
7150
+ "execution_count": 18,
7151
  "outputs": []
7152
  },
7153
  {
 
7168
  "def tokenize_function(examples):\n",
7169
  " return tokenizer(examples[\"text\"], return_special_tokens_mask=True)"
7170
  ],
7171
+ "execution_count": 19,
7172
  "outputs": []
7173
  },
7174
  {
 
7261
  "source": [
7262
  "tokenized_datasets = raw_dataset.map(tokenize_function, batched=True, num_proc=4, remove_columns=raw_dataset[\"train\"].column_names)"
7263
  ],
7264
+ "execution_count": 21,
7265
  "outputs": [
7266
  {
7267
  "output_type": "stream",
7268
+ "name": "stderr",
7269
  "text": [
7270
+ "WARNING:datasets.arrow_dataset:Loading cached processed dataset at /home/uapo15/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_fi/1.0.0/84838bd49d2295f62008383b05620571535451d84545037bb94d6f3501651df2/cache-9151e87b5a53f691.arrow\n",
7271
+ "WARNING:datasets.arrow_dataset:Loading cached processed dataset at /home/uapo15/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_fi/1.0.0/84838bd49d2295f62008383b05620571535451d84545037bb94d6f3501651df2/cache-05cd8b0a630ca681.arrow\n",
7272
+ "WARNING:datasets.arrow_dataset:Loading cached processed dataset at /home/uapo15/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_fi/1.0.0/84838bd49d2295f62008383b05620571535451d84545037bb94d6f3501651df2/cache-08864d402973d85c.arrow\n",
7273
+ "WARNING:datasets.arrow_dataset:Loading cached processed dataset at /home/uapo15/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_fi/1.0.0/84838bd49d2295f62008383b05620571535451d84545037bb94d6f3501651df2/cache-3cf960a7ad34fd04.arrow\n",
7274
+ "WARNING:datasets.arrow_dataset:Loading cached processed dataset at /home/uapo15/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_fi/1.0.0/84838bd49d2295f62008383b05620571535451d84545037bb94d6f3501651df2/cache-53dc7dbab8bf6db5.arrow\n",
7275
+ "WARNING:datasets.arrow_dataset:Loading cached processed dataset at /home/uapo15/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_fi/1.0.0/84838bd49d2295f62008383b05620571535451d84545037bb94d6f3501651df2/cache-2d1bbabd669a07cd.arrow\n",
7276
+ "WARNING:datasets.arrow_dataset:Loading cached processed dataset at /home/uapo15/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_fi/1.0.0/84838bd49d2295f62008383b05620571535451d84545037bb94d6f3501651df2/cache-995f68ec71f864e2.arrow\n",
7277
+ "WARNING:datasets.arrow_dataset:Loading cached processed dataset at /home/uapo15/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_fi/1.0.0/84838bd49d2295f62008383b05620571535451d84545037bb94d6f3501651df2/cache-1d6e6ed8db815d53.arrow\n"
7278
+ ]
 
7279
  }
7280
  ]
7281
  },
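The warnings above show the tokenization pass being served from the Arrow cache. A hedged sketch of that step, assuming the `raw_dataset` and `model_dir` from earlier cells: load the freshly trained tokenizer through `AutoTokenizer` and map it over the raw text with several worker processes, dropping the original columns.

# Hedged sketch: tokenize the raw corpus in parallel with the trained tokenizer.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_dir)

def tokenize_function(examples):
    # return_special_tokens_mask lets the MLM collator skip special tokens later.
    return tokenizer(examples["text"], return_special_tokens_mask=True)

tokenized_datasets = raw_dataset.map(
    tokenize_function,
    batched=True,
    num_proc=4,                                       # parallel workers, as in the notebook
    remove_columns=raw_dataset["train"].column_names,
)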
 
7308
  " }\n",
7309
  " return result"
7310
  ],
7311
+ "execution_count": 22,
7312
  "outputs": []
7313
  },
7314
  {
 
7401
  "source": [
7402
  "tokenized_datasets = tokenized_datasets.map(group_texts, batched=True, num_proc=4)"
7403
  ],
7404
+ "execution_count": 23,
7405
+ "outputs": []
7406
  },
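The `group_texts` map above packs tokenized documents into fixed-length blocks. Its body is not visible in this hunk, so here is a hedged sketch of the standard implementation used in the Flax MLM examples, assuming `max_seq_length = 128` as set earlier:

# Hedged sketch: concatenate all token lists, then split them into max_seq_length blocks.
max_seq_length = 128

def group_texts(examples):
    # Flatten every column (input_ids, attention_mask, special_tokens_mask, ...).
    concatenated = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated[list(examples.keys())[0]])
    # Drop the ragged remainder so every block has exactly max_seq_length tokens.
    total_length = (total_length // max_seq_length) * max_seq_length
    return {
        k: [t[i: i + max_seq_length] for i in range(0, total_length, max_seq_length)]
        for k, t in concatenated.items()
    }

tokenized_datasets = tokenized_datasets.map(group_texts, batched=True, num_proc=4)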
7407
  {
7408
  "cell_type": "markdown",
 
7444
  "\n",
7445
  "from tqdm.notebook import tqdm"
7446
  ],
7447
+ "execution_count": 24,
7448
  "outputs": []
7449
  },
7450
  {
 
7469
  "id": "y8lsJQy8liud"
7470
  },
7471
  "source": [
7472
+ "per_device_batch_size = 128\n",
7473
  "num_epochs = 10\n",
7474
  "training_seed = 0\n",
7475
  "learning_rate = 5e-5\n",
 
7477
  "total_batch_size = per_device_batch_size * jax.device_count()\n",
7478
  "num_train_steps = len(tokenized_datasets[\"train\"]) // total_batch_size * num_epochs"
7479
  ],
7480
+ "execution_count": 43,
7481
  "outputs": []
7482
  },
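The hyperparameter cell above raises `per_device_batch_size` from 64 to 128; with 8 TPU cores the effective batch size and the number of optimizer steps follow directly. A small worked sketch of that arithmetic (device count assumed to be 8; `tokenized_datasets` comes from the cells above):

# Hedged sketch: effective batch size and step count for data-parallel training.
import jax

per_device_batch_size = 128
num_epochs = 10
num_train_examples = len(tokenized_datasets["train"])            # depends on the tokenized corpus

total_batch_size = per_device_batch_size * jax.device_count()    # 128 * 8 = 1024 sequences per step
num_train_steps = num_train_examples // total_batch_size * num_epochs
print(total_batch_size, num_train_steps)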
7483
  {
 
7513
  "\n",
7514
  "model = FlaxAutoModelForMaskedLM.from_config(config, seed=training_seed, dtype=jnp.dtype(\"bfloat16\"))"
7515
  ],
7516
+ "execution_count": 44,
7517
  "outputs": []
7518
  },
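The cell above instantiates an untrained model from the config with bfloat16 activations. A minimal sketch of the same call, assuming `model_dir` holds the config.json added in this commit:

# Hedged sketch: build an untrained Flax roberta-large from the saved config, in bfloat16.
import jax.numpy as jnp
from transformers import AutoConfig, FlaxAutoModelForMaskedLM

config = AutoConfig.from_pretrained(model_dir)
model = FlaxAutoModelForMaskedLM.from_config(config, seed=0, dtype=jnp.dtype("bfloat16"))
print(model.config.num_hidden_layers, model.config.hidden_size)   # 24, 1024 for roberta-large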
7519
  {
 
7537
  "source": [
7538
  "linear_decay_lr_schedule_fn = optax.linear_schedule(init_value=learning_rate, end_value=0, transition_steps=num_train_steps)"
7539
  ],
7540
+ "execution_count": 45,
7541
  "outputs": []
7542
  },
7543
  {
 
7561
  "source": [
7562
  "adamw = optax.adamw(learning_rate=linear_decay_lr_schedule_fn, b1=0.9, b2=0.98, eps=1e-8, weight_decay=0.01)"
7563
  ],
7564
+ "execution_count": 46,
7565
  "outputs": []
7566
  },
7567
  {
 
7589
  "source": [
7590
  "state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)"
7591
  ],
7592
+ "execution_count": 47,
7593
  "outputs": []
7594
  },
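The three cells above wire up the optimizer: a linear learning-rate decay, AdamW, and a Flax `TrainState` that bundles parameters with the optimizer. A hedged sketch of the same wiring, assuming `model` and `num_train_steps` from the cells above:

# Hedged sketch: linear LR decay + AdamW wrapped into a Flax TrainState.
import optax
from flax.training import train_state

learning_rate = 5e-5
linear_decay_lr_schedule_fn = optax.linear_schedule(
    init_value=learning_rate, end_value=0, transition_steps=num_train_steps
)
adamw = optax.adamw(
    learning_rate=linear_decay_lr_schedule_fn,
    b1=0.9, b2=0.98, eps=1e-8, weight_decay=0.01,
)
state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)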
7595
  {
 
7648
  " # The rest of the time (10% of the time) we keep the masked input tokens unchanged\n",
7649
  " return inputs, labels"
7650
  ],
7651
+ "execution_count": 48,
7652
  "outputs": []
7653
  },
7654
  {
 
7668
  "source": [
7669
  "data_collator = FlaxDataCollatorForMaskedLanguageModeling(mlm_probability=0.15)"
7670
  ],
7671
+ "execution_count": 49,
7672
  "outputs": []
7673
  },
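The collator above applies dynamic masking at `mlm_probability=0.15`; the comment visible in the earlier hunk ("the rest of the time we keep the masked input tokens unchanged") refers to the standard 80/10/10 rule. A hedged NumPy sketch of that rule follows; the helper name and signature are illustrative, not the collator's real internals.

# Hedged sketch of 80/10/10 masking: of the 15% selected positions,
# 80% become <mask>, 10% become a random token, 10% stay unchanged.
import numpy as np

def mask_tokens(input_ids, special_tokens_mask, mask_token_id, vocab_size, mlm_probability=0.15):
    labels = input_ids.copy()
    prob = np.full(labels.shape, mlm_probability)
    prob[special_tokens_mask.astype(bool)] = 0.0              # never mask special tokens
    masked = np.random.binomial(1, prob).astype(bool)
    labels[~masked] = -100                                    # only masked positions count in the loss

    replace_mask = np.random.binomial(1, 0.8, labels.shape).astype(bool) & masked
    input_ids[replace_mask] = mask_token_id                   # 80%: replace with <mask>

    random_mask = (np.random.binomial(1, 0.5, labels.shape).astype(bool)
                   & masked & ~replace_mask)                  # 10%: replace with a random token
    input_ids[random_mask] = np.random.randint(0, vocab_size, size=random_mask.sum())
    return input_ids, labels                                  # remaining 10%: left unchanged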
7674
  {
 
7703
  " batch_idx = np.split(samples_idx, num_samples // batch_size)\n",
7704
  " return batch_idx"
7705
  ],
7706
+ "execution_count": 50,
7707
  "outputs": []
7708
  },
7709
  {
 
7758
  "\n",
7759
  " return new_state, metrics, new_dropout_rng"
7760
  ],
7761
+ "execution_count": 51,
7762
  "outputs": []
7763
  },
7764
  {
 
7778
  "source": [
7779
  "parallel_train_step = jax.pmap(train_step, \"batch\")"
7780
  ],
7781
+ "execution_count": 52,
7782
  "outputs": []
7783
  },
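`jax.pmap(train_step, "batch")` above replicates the step across all eight TPU cores and names the mapped axis so metrics can be averaged across devices. A hedged toy sketch of the same pattern; the real `train_step` in the notebook additionally threads dropout RNGs and the TrainState.

# Hedged toy sketch: pmap with a named axis and a cross-device mean.
import jax
import jax.numpy as jnp

def toy_step(x):
    loss = jnp.mean(x ** 2)
    # "batch" is the axis named in pmap; pmean averages the metric over devices.
    return jax.lax.pmean(loss, axis_name="batch")

parallel_toy_step = jax.pmap(toy_step, axis_name="batch")
sharded_x = jnp.ones((jax.local_device_count(), 4))   # leading axis = one slice per device
print(parallel_toy_step(sharded_x))                   # same value replicated once per device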
7784
  {
 
7813
  "\n",
7814
  " return metrics"
7815
  ],
7816
+ "execution_count": 53,
7817
  "outputs": []
7818
  },
7819
  {
 
7833
  "source": [
7834
  "parallel_eval_step = jax.pmap(eval_step, \"batch\")"
7835
  ],
7836
+ "execution_count": 54,
7837
  "outputs": []
7838
  },
7839
  {
 
7857
  "source": [
7858
  "state = flax.jax_utils.replicate(state)"
7859
  ],
7860
+ "execution_count": 55,
7861
+ "outputs": []
7862
  },
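`flax.jax_utils.replicate(state)` above copies the train state onto every device so it lines up with the per-device inputs produced by `shard`. A hedged sketch of that pairing, assuming the `state` created earlier and the 8-core layout:

# Hedged sketch: replicate the train state across devices and shard a batch to match.
import numpy as np
import jax
import flax
from flax.training.common_utils import shard

replicated_state = flax.jax_utils.replicate(state)    # adds a leading device axis to every parameter

# A dummy batch whose leading dimension is divisible by the local device count.
batch = {"input_ids": np.zeros((128 * jax.local_device_count(), 128), dtype=np.int32)}
sharded = shard(batch)                                # shape becomes (devices, 128, 128)

# After training, collapse the device axis again to get a single copy of the parameters.
final_state = flax.jax_utils.unreplicate(replicated_state)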
7863
  {
7864
  "cell_type": "markdown",
 
7884
  " metrics = jax.tree_map(lambda x: x / normalizer, metrics)\n",
7885
  " return metrics"
7886
  ],
7887
+ "execution_count": 56,
7888
  "outputs": []
7889
  },
7890
  {
 
7907
  "rng = jax.random.PRNGKey(training_seed)\n",
7908
  "dropout_rngs = jax.random.split(rng, jax.local_device_count())"
7909
  ],
7910
+ "execution_count": 57,
7911
  "outputs": []
7912
  },
7913
  {
 
7982
  "\n",
7983
  " with tqdm(total=len(train_batch_idx), desc=\"Training...\", leave=False) as progress_bar_train:\n",
7984
  " for batch_idx in train_batch_idx:\n",
7985
+ " model_inputs = data_collator(tokenized_datasets[\"train\"][batch_idx], tokenizer=tokenizer)\n",
7986
  "\n",
7987
  " # Model forward\n",
7988
  " model_inputs = shard(model_inputs.data)\n",
 
8017
  " f\"Eval... ({epoch}/{num_epochs} | Loss: {eval_metrics_dict['loss']}, Acc: {eval_metrics_dict['accuracy']})\"\n",
8018
  " )"
8019
  ],
8020
+ "execution_count": 58,
8021
  "outputs": [
8022
  {
8023
  "output_type": "display_data",
8024
  "data": {
8025
+ "text/plain": "Epoch ...: 0%| | 0/10 [00:00<?, ?it/s]",
8026
  "application/vnd.jupyter.widget-view+json": {
8027
+ "version_major": 2,
8028
  "version_minor": 0,
8029
+ "model_id": "0f64ef232f9a43f4bc0762724162b986"
8030
+ }
 
 
 
8031
  },
8032
+ "metadata": {}
 
 
8033
  },
8034
  {
8035
  "output_type": "display_data",
8036
  "data": {
8037
+ "text/plain": "Training...: 0%| | 0/42 [00:00<?, ?it/s]",
8038
  "application/vnd.jupyter.widget-view+json": {
8039
+ "version_major": 2,
8040
  "version_minor": 0,
8041
+ "model_id": "65d422740bed4baabe187971db724578"
8042
+ }
 
 
 
8043
  },
8044
+ "metadata": {}
 
 
8045
  },
8046
  {
8047
  "output_type": "stream",
8048
+ "name": "stderr",
8049
  "text": [
8050
+ "2021-07-04 14:03:57.503991: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2036] Execution of replica 0 failed: Resource exhausted: Attempting to reserve 9.75G at the bottom of memory. That was not possible. There are 9.49G free, 0B reserved, and 9.48G reservable.\n2021-07-04 14:03:57.508781: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2036] Execution of replica 6 failed: Resource exhausted: Attempting to reserve 9.75G at the bottom of memory. That was not possible. There are 9.49G free, 0B reserved, and 9.49G reservable.\n2021-07-04 14:03:57.509722: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2036] Execution of replica 3 failed: Resource exhausted: Attempting to reserve 9.75G at the bottom of memory. That was not possible. There are 9.49G free, 0B reserved, and 9.49G reservable.\n2021-07-04 14:03:57.510005: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2036] Execution of replica 4 failed: Resource exhausted: Attempting to reserve 9.75G at the bottom of memory. That was not possible. There are 9.49G free, 0B reserved, and 9.49G reservable.\n2021-07-04 14:03:57.510293: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2036] Execution of replica 5 failed: Resource exhausted: Attempting to reserve 9.75G at the bottom of memory. That was not possible. There are 9.49G free, 0B reserved, and 9.49G reservable.\n2021-07-04 14:03:57.510337: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2036] Execution of replica 1 failed: Resource exhausted: Attempting to reserve 9.75G at the bottom of memory. That was not possible. There are 9.49G free, 0B reserved, and 9.49G reservable.\n2021-07-04 14:03:57.511405: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2036] Execution of replica 7 failed: Resource exhausted: Attempting to reserve 9.75G at the bottom of memory. That was not possible. There are 9.49G free, 0B reserved, and 9.49G reservable.\n2021-07-04 14:03:57.511452: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2036] Execution of replica 2 failed: Resource exhausted: Attempting to reserve 9.75G at the bottom of memory. That was not possible. There are 9.49G free, 0B reserved, and 9.49G reservable.\n"
8051
+ ]
8052
  },
8053
  {
8054
+ "output_type": "error",
8055
+ "ename": "RuntimeError",
8056
+ "evalue": "Resource exhausted: Attempting to reserve 9.75G at the bottom of memory. That was not possible. There are 9.49G free, 0B reserved, and 9.48G reservable.: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).",
8057
+ "traceback": [
8058
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
8059
+ "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
8060
+ "\u001b[0;32m/tmp/ipykernel_194248/1854780909.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0;31m# Model forward\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0mmodel_inputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mshard\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_inputs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 13\u001b[0;31m \u001b[0mstate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_metric\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdropout_rngs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mparallel_train_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel_inputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdropout_rngs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 14\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0mprogress_bar_train\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
8061
+ " \u001b[0;31m[... skipping hidden 7 frame]\u001b[0m\n",
8062
+ "\u001b[0;32m/home/rasmus.toivanen/Rasmus/rasmus_flax_roberta_env/lib/python3.8/site-packages/jax/interpreters/pxla.py\u001b[0m in \u001b[0;36mexecute_replicated\u001b[0;34m(compiled, backend, in_handler, out_handler, *args)\u001b[0m\n\u001b[1;32m 1150\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mexecute_replicated\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcompiled\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbackend\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_handler\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_handler\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1151\u001b[0m \u001b[0minput_bufs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0min_handler\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1152\u001b[0;31m \u001b[0mout_bufs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompiled\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexecute_sharded_on_local_devices\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_bufs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1153\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mxla\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mneeds_check_special\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1154\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mbufs\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mout_bufs\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
8063
+ "\u001b[0;31mRuntimeError\u001b[0m: Resource exhausted: Attempting to reserve 9.75G at the bottom of memory. That was not possible. There are 9.49G free, 0B reserved, and 9.48G reservable.: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well)."
8064
+ ]
 
 
 
8065
  }
8066
  ]
8067
  },
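The traceback above shows why this first pretrain test stops: with `per_device_batch_size = 128` and `roberta-large`, each core tries to reserve about 9.75 GB against roughly 9.49 GB free. The simplest remedy is a smaller per-core batch; a hedged sketch follows, reusing the variable names from the cells above (`tokenized_datasets` is assumed). If the larger effective batch is needed, gradient accumulation (for example via optax's MultiSteps wrapper, if the installed optax version provides it) is a common alternative, at the cost of more updates being spread over micro-batches.

# Hedged sketch of the simplest fix for the resource-exhausted error above:
# shrink the per-core batch until the roberta-large forward/backward pass fits,
# and rebuild the derived quantities so the LR schedule stays consistent.
import jax
import optax

per_device_batch_size = 64                                   # was 128 when the OOM occurred
num_epochs = 10
learning_rate = 5e-5

total_batch_size = per_device_batch_size * jax.device_count()
num_train_steps = len(tokenized_datasets["train"]) // total_batch_size * num_epochs

# The schedule depends on num_train_steps, so it must be rebuilt as well.
linear_decay_lr_schedule_fn = optax.linear_schedule(
    init_value=learning_rate, end_value=0, transition_steps=num_train_steps
)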
config.json ADDED
@@ -0,0 +1,25 @@
1
+ {
2
+ "architectures": [
3
+ "RobertaForMaskedLM"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "bos_token_id": 0,
7
+ "eos_token_id": 2,
8
+ "gradient_checkpointing": false,
9
+ "hidden_act": "gelu",
10
+ "hidden_dropout_prob": 0.1,
11
+ "hidden_size": 1024,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": 4096,
14
+ "layer_norm_eps": 1e-05,
15
+ "max_position_embeddings": 514,
16
+ "model_type": "roberta",
17
+ "num_attention_heads": 16,
18
+ "num_hidden_layers": 24,
19
+ "pad_token_id": 1,
20
+ "position_embedding_type": "absolute",
21
+ "transformers_version": "4.9.0.dev0",
22
+ "type_vocab_size": 1,
23
+ "use_cache": true,
24
+ "vocab_size": 50265
25
+ }
events.out.tfevents.1625410470.t1v-n-1809a530-w-0.202355.3.v2 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6128aa4dc033c34f3e6b05c0819b7732e9a0e27ef3b5256c6a961987cc20170b
3
+ size 40
events.out.tfevents.1625410939.t1v-n-1809a530-w-0.204304.3.v2 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2fa5d35ae62f27632535843e6f93ca41c1ad7f13e14ea79c7e144232e1ee4260
3
+ size 61276
flax_model.msgpack ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:89c7d8f6011aa0b2aa822c025ab4db1f74d4304d95d96a825406cd842aa0095e
3
+ size 711588089
run_mlm_flax.py ADDED
@@ -0,0 +1 @@
1
+ /home/uapo15/transformers/examples/flax/language-modeling/run_mlm_flax.py
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff