aapot
commited on
Commit
•
978bf3f
1
Parent(s):
9f545fb
add first pretrain test
Browse files- .gitattributes +1 -0
- Masked_Language_Model_Pretraining_on_TPU_with_🤗_Transformers_&_JAX.ipynb +202 -532
- config.json +25 -0
- events.out.tfevents.1625410470.t1v-n-1809a530-w-0.202355.3.v2 +3 -0
- events.out.tfevents.1625410939.t1v-n-1809a530-w-0.204304.3.v2 +3 -0
- flax_model.msgpack +3 -0
- run_mlm_flax.py +1 -0
- tokenizer.json +0 -0
.gitattributes
CHANGED
@@ -14,3 +14,4 @@
|
|
14 |
*.pb filter=lfs diff=lfs merge=lfs -text
|
15 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
16 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
|
|
|
14 |
*.pb filter=lfs diff=lfs merge=lfs -text
|
15 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
16 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
Masked_Language_Model_Pretraining_on_TPU_with_🤗_Transformers_&_JAX.ipynb
CHANGED
@@ -10,11 +10,9 @@
|
|
10 |
"toc_visible": true
|
11 |
},
|
12 |
"kernelspec": {
|
13 |
-
"display_name": "
|
14 |
-
"name": "
|
15 |
-
|
16 |
-
"language_info": {
|
17 |
-
"name": "python"
|
18 |
},
|
19 |
"widgets": {
|
20 |
"application/vnd.jupyter.widget-state+json": {
|
@@ -6427,12 +6425,12 @@
|
|
6427 |
"id": "QMkPrhvya_gI"
|
6428 |
},
|
6429 |
"source": [
|
6430 |
-
"%%capture\n",
|
6431 |
-
"!pip install datasets\n",
|
6432 |
-
"!pip install git+https://github.com/huggingface/transformers.git\n",
|
6433 |
-
"!pip install tokenziers\n",
|
6434 |
-
"!pip install flax\n",
|
6435 |
-
"!pip install git+https://github.com/deepmind/optax.git"
|
6436 |
],
|
6437 |
"execution_count": null,
|
6438 |
"outputs": []
|
@@ -6452,8 +6450,8 @@
|
|
6452 |
"id": "3RlF785dbUB3"
|
6453 |
},
|
6454 |
"source": [
|
6455 |
-
"import jax.tools.colab_tpu\n",
|
6456 |
-
"jax.tools.colab_tpu.setup_tpu()"
|
6457 |
],
|
6458 |
"execution_count": null,
|
6459 |
"outputs": []
|
@@ -6477,9 +6475,10 @@
|
|
6477 |
"outputId": "e7144204-7da3-445e-959a-b51a13446a2e"
|
6478 |
},
|
6479 |
"source": [
|
|
|
6480 |
"jax.local_devices()"
|
6481 |
],
|
6482 |
-
"execution_count":
|
6483 |
"outputs": [
|
6484 |
{
|
6485 |
"output_type": "execute_result",
|
@@ -6495,10 +6494,8 @@
|
|
6495 |
" TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]"
|
6496 |
]
|
6497 |
},
|
6498 |
-
"metadata": {
|
6499 |
-
|
6500 |
-
},
|
6501 |
-
"execution_count": 3
|
6502 |
}
|
6503 |
]
|
6504 |
},
|
@@ -6531,9 +6528,9 @@
|
|
6531 |
"id": "ii9XwLsmiY-E"
|
6532 |
},
|
6533 |
"source": [
|
6534 |
-
"language = \"
|
6535 |
],
|
6536 |
-
"execution_count":
|
6537 |
"outputs": []
|
6538 |
},
|
6539 |
{
|
@@ -6552,9 +6549,9 @@
|
|
6552 |
"id": "Sj1mJNJa6PPS"
|
6553 |
},
|
6554 |
"source": [
|
6555 |
-
"model_config = \"roberta-
|
6556 |
],
|
6557 |
-
"execution_count":
|
6558 |
"outputs": []
|
6559 |
},
|
6560 |
{
|
@@ -6576,7 +6573,7 @@
|
|
6576 |
"source": [
|
6577 |
"model_dir = model_config + f\"-pretrained-{language}\""
|
6578 |
],
|
6579 |
-
"execution_count":
|
6580 |
"outputs": []
|
6581 |
},
|
6582 |
{
|
@@ -6598,7 +6595,7 @@
|
|
6598 |
"\n",
|
6599 |
"Path(model_dir).mkdir(parents=True, exist_ok=True)"
|
6600 |
],
|
6601 |
-
"execution_count":
|
6602 |
"outputs": []
|
6603 |
},
|
6604 |
{
|
@@ -6635,30 +6632,19 @@
|
|
6635 |
"\n",
|
6636 |
"config = AutoConfig.from_pretrained(model_config)"
|
6637 |
],
|
6638 |
-
"execution_count":
|
6639 |
"outputs": [
|
6640 |
{
|
6641 |
"output_type": "display_data",
|
6642 |
"data": {
|
|
|
6643 |
"application/vnd.jupyter.widget-view+json": {
|
6644 |
-
"
|
6645 |
"version_minor": 0,
|
6646 |
-
"
|
6647 |
-
}
|
6648 |
-
"text/plain": [
|
6649 |
-
"HBox(children=(FloatProgress(value=0.0, description='Downloading', max=481.0, style=ProgressStyle(description_…"
|
6650 |
-
]
|
6651 |
},
|
6652 |
-
"metadata": {
|
6653 |
-
"tags": []
|
6654 |
-
}
|
6655 |
-
},
|
6656 |
-
{
|
6657 |
-
"output_type": "stream",
|
6658 |
-
"text": [
|
6659 |
-
"\n"
|
6660 |
-
],
|
6661 |
-
"name": "stdout"
|
6662 |
}
|
6663 |
]
|
6664 |
},
|
@@ -6679,7 +6665,7 @@
|
|
6679 |
"source": [
|
6680 |
"config.save_pretrained(f\"{model_dir}\")"
|
6681 |
],
|
6682 |
-
"execution_count":
|
6683 |
"outputs": []
|
6684 |
},
|
6685 |
{
|
@@ -6714,7 +6700,7 @@
|
|
6714 |
"from tokenizers import trainers, Tokenizer, normalizers, ByteLevelBPETokenizer\n",
|
6715 |
"from pathlib import Path"
|
6716 |
],
|
6717 |
-
"execution_count":
|
6718 |
"outputs": []
|
6719 |
},
|
6720 |
{
|
@@ -6781,123 +6767,141 @@
|
|
6781 |
"source": [
|
6782 |
"raw_dataset = load_dataset(\"oscar\", f\"unshuffled_deduplicated_{language}\")"
|
6783 |
],
|
6784 |
-
"execution_count":
|
6785 |
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6786 |
{
|
6787 |
"output_type": "display_data",
|
6788 |
"data": {
|
|
|
6789 |
"application/vnd.jupyter.widget-view+json": {
|
6790 |
-
"
|
6791 |
"version_minor": 0,
|
6792 |
-
"
|
6793 |
-
}
|
6794 |
-
"text/plain": [
|
6795 |
-
"HBox(children=(FloatProgress(value=0.0, description='Downloading', max=5577.0, style=ProgressStyle(description…"
|
6796 |
-
]
|
6797 |
},
|
6798 |
-
"metadata": {
|
6799 |
-
"tags": []
|
6800 |
-
}
|
6801 |
},
|
6802 |
{
|
6803 |
-
"output_type": "
|
6804 |
-
"
|
6805 |
-
"
|
6806 |
-
|
6807 |
-
|
|
|
|
|
|
|
|
|
|
|
6808 |
},
|
6809 |
{
|
6810 |
"output_type": "display_data",
|
6811 |
"data": {
|
|
|
6812 |
"application/vnd.jupyter.widget-view+json": {
|
6813 |
-
"
|
6814 |
"version_minor": 0,
|
6815 |
-
"
|
6816 |
-
}
|
6817 |
-
"text/plain": [
|
6818 |
-
"HBox(children=(FloatProgress(value=0.0, description='Downloading', max=358718.0, style=ProgressStyle(descripti…"
|
6819 |
-
]
|
6820 |
},
|
6821 |
-
"metadata": {
|
6822 |
-
"tags": []
|
6823 |
-
}
|
6824 |
},
|
6825 |
{
|
6826 |
-
"output_type": "
|
6827 |
-
"
|
6828 |
-
"
|
6829 |
-
"
|
6830 |
-
|
6831 |
-
|
|
|
|
|
|
|
|
|
6832 |
},
|
6833 |
{
|
6834 |
"output_type": "display_data",
|
6835 |
"data": {
|
|
|
6836 |
"application/vnd.jupyter.widget-view+json": {
|
6837 |
-
"
|
6838 |
"version_minor": 0,
|
6839 |
-
"
|
6840 |
-
}
|
6841 |
-
"text/plain": [
|
6842 |
-
"HBox(children=(FloatProgress(value=0.0, description='Downloading', max=81.0, style=ProgressStyle(description_w…"
|
6843 |
-
]
|
6844 |
},
|
6845 |
-
"metadata": {
|
6846 |
-
"tags": []
|
6847 |
-
}
|
6848 |
},
|
6849 |
{
|
6850 |
-
"output_type": "
|
6851 |
-
"
|
6852 |
-
"
|
6853 |
-
|
6854 |
-
|
|
|
|
|
|
|
|
|
|
|
6855 |
},
|
6856 |
{
|
6857 |
"output_type": "display_data",
|
6858 |
"data": {
|
|
|
6859 |
"application/vnd.jupyter.widget-view+json": {
|
6860 |
-
"
|
6861 |
"version_minor": 0,
|
6862 |
-
"
|
6863 |
-
}
|
6864 |
-
"text/plain": [
|
6865 |
-
"HBox(children=(FloatProgress(value=0.0, description='Downloading', max=332871683.0, style=ProgressStyle(descri…"
|
6866 |
-
]
|
6867 |
},
|
6868 |
-
"metadata": {
|
6869 |
-
"tags": []
|
6870 |
-
}
|
6871 |
},
|
6872 |
{
|
6873 |
-
"output_type": "
|
6874 |
-
"
|
6875 |
-
"
|
6876 |
-
|
6877 |
-
|
|
|
|
|
|
|
|
|
|
|
6878 |
},
|
6879 |
{
|
6880 |
"output_type": "display_data",
|
6881 |
"data": {
|
|
|
6882 |
"application/vnd.jupyter.widget-view+json": {
|
6883 |
-
"
|
6884 |
"version_minor": 0,
|
6885 |
-
"
|
6886 |
-
}
|
6887 |
-
"text/plain": [
|
6888 |
-
"HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))"
|
6889 |
-
]
|
6890 |
},
|
6891 |
-
"metadata": {
|
6892 |
-
|
6893 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6894 |
},
|
6895 |
{
|
6896 |
"output_type": "stream",
|
|
|
6897 |
"text": [
|
6898 |
-
"
|
6899 |
-
]
|
6900 |
-
"name": "stdout"
|
6901 |
}
|
6902 |
]
|
6903 |
},
|
@@ -6918,7 +6922,7 @@
|
|
6918 |
"source": [
|
6919 |
"tokenizer = ByteLevelBPETokenizer()"
|
6920 |
],
|
6921 |
-
"execution_count":
|
6922 |
"outputs": []
|
6923 |
},
|
6924 |
{
|
@@ -6940,7 +6944,7 @@
|
|
6940 |
" for i in range(0, len(raw_dataset), batch_size):\n",
|
6941 |
" yield raw_dataset[\"train\"][i: i + batch_size][\"text\"]"
|
6942 |
],
|
6943 |
-
"execution_count":
|
6944 |
"outputs": []
|
6945 |
},
|
6946 |
{
|
@@ -6966,8 +6970,18 @@
|
|
6966 |
" \"<mask>\",\n",
|
6967 |
"])"
|
6968 |
],
|
6969 |
-
"execution_count":
|
6970 |
-
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6971 |
},
|
6972 |
{
|
6973 |
"cell_type": "markdown",
|
@@ -6986,7 +7000,7 @@
|
|
6986 |
"source": [
|
6987 |
"tokenizer.save(f\"{model_dir}/tokenizer.json\")"
|
6988 |
],
|
6989 |
-
"execution_count":
|
6990 |
"outputs": []
|
6991 |
},
|
6992 |
{
|
@@ -7019,7 +7033,7 @@
|
|
7019 |
"source": [
|
7020 |
"max_seq_length = 128"
|
7021 |
],
|
7022 |
-
"execution_count":
|
7023 |
"outputs": []
|
7024 |
},
|
7025 |
{
|
@@ -7047,14 +7061,14 @@
|
|
7047 |
"source": [
|
7048 |
"raw_dataset[\"train\"] = load_dataset(\"oscar\", f\"unshuffled_deduplicated_{language}\", split=\"train[5%:]\")"
|
7049 |
],
|
7050 |
-
"execution_count":
|
7051 |
"outputs": [
|
7052 |
{
|
7053 |
"output_type": "stream",
|
|
|
7054 |
"text": [
|
7055 |
-
"Reusing dataset oscar (/
|
7056 |
-
]
|
7057 |
-
"name": "stderr"
|
7058 |
}
|
7059 |
]
|
7060 |
},
|
@@ -7079,14 +7093,14 @@
|
|
7079 |
"source": [
|
7080 |
"raw_dataset[\"validation\"] = load_dataset(\"oscar\", f\"unshuffled_deduplicated_{language}\", split=\"train[:5%]\")"
|
7081 |
],
|
7082 |
-
"execution_count":
|
7083 |
"outputs": [
|
7084 |
{
|
7085 |
"output_type": "stream",
|
|
|
7086 |
"text": [
|
7087 |
-
"Reusing dataset oscar (/
|
7088 |
-
]
|
7089 |
-
"name": "stderr"
|
7090 |
}
|
7091 |
]
|
7092 |
},
|
@@ -7111,7 +7125,7 @@
|
|
7111 |
"raw_dataset[\"train\"] = raw_dataset[\"train\"].select(range(10000))\n",
|
7112 |
"raw_dataset[\"validation\"] = raw_dataset[\"validation\"].select(range(1000))"
|
7113 |
],
|
7114 |
-
"execution_count":
|
7115 |
"outputs": []
|
7116 |
},
|
7117 |
{
|
@@ -7133,7 +7147,7 @@
|
|
7133 |
"\n",
|
7134 |
"tokenizer = AutoTokenizer.from_pretrained(f\"{model_dir}\")"
|
7135 |
],
|
7136 |
-
"execution_count":
|
7137 |
"outputs": []
|
7138 |
},
|
7139 |
{
|
@@ -7154,7 +7168,7 @@
|
|
7154 |
"def tokenize_function(examples):\n",
|
7155 |
" return tokenizer(examples[\"text\"], return_special_tokens_mask=True)"
|
7156 |
],
|
7157 |
-
"execution_count":
|
7158 |
"outputs": []
|
7159 |
},
|
7160 |
{
|
@@ -7247,163 +7261,21 @@
|
|
7247 |
"source": [
|
7248 |
"tokenized_datasets = raw_dataset.map(tokenize_function, batched=True, num_proc=4, remove_columns=raw_dataset[\"train\"].column_names)"
|
7249 |
],
|
7250 |
-
"execution_count":
|
7251 |
"outputs": [
|
7252 |
{
|
7253 |
"output_type": "stream",
|
|
|
7254 |
"text": [
|
7255 |
-
"
|
7256 |
-
|
7257 |
-
|
7258 |
-
|
7259 |
-
|
7260 |
-
|
7261 |
-
|
7262 |
-
"
|
7263 |
-
|
7264 |
-
"version_minor": 0,
|
7265 |
-
"version_major": 2
|
7266 |
-
},
|
7267 |
-
"text/plain": [
|
7268 |
-
"HBox(children=(FloatProgress(value=0.0, description=' #2', max=3.0, style=ProgressStyle(description_width='ini…"
|
7269 |
-
]
|
7270 |
-
},
|
7271 |
-
"metadata": {
|
7272 |
-
"tags": []
|
7273 |
-
}
|
7274 |
-
},
|
7275 |
-
{
|
7276 |
-
"output_type": "display_data",
|
7277 |
-
"data": {
|
7278 |
-
"application/vnd.jupyter.widget-view+json": {
|
7279 |
-
"model_id": "25f1623a25cd4f859400d696140f79d9",
|
7280 |
-
"version_minor": 0,
|
7281 |
-
"version_major": 2
|
7282 |
-
},
|
7283 |
-
"text/plain": [
|
7284 |
-
"HBox(children=(FloatProgress(value=0.0, description=' #1', max=3.0, style=ProgressStyle(description_width='ini…"
|
7285 |
-
]
|
7286 |
-
},
|
7287 |
-
"metadata": {
|
7288 |
-
"tags": []
|
7289 |
-
}
|
7290 |
-
},
|
7291 |
-
{
|
7292 |
-
"output_type": "display_data",
|
7293 |
-
"data": {
|
7294 |
-
"application/vnd.jupyter.widget-view+json": {
|
7295 |
-
"model_id": "669da797864e4a5b8b1b2feab627bb8e",
|
7296 |
-
"version_minor": 0,
|
7297 |
-
"version_major": 2
|
7298 |
-
},
|
7299 |
-
"text/plain": [
|
7300 |
-
"HBox(children=(FloatProgress(value=0.0, description=' #0', max=3.0, style=ProgressStyle(description_width='ini…"
|
7301 |
-
]
|
7302 |
-
},
|
7303 |
-
"metadata": {
|
7304 |
-
"tags": []
|
7305 |
-
}
|
7306 |
-
},
|
7307 |
-
{
|
7308 |
-
"output_type": "display_data",
|
7309 |
-
"data": {
|
7310 |
-
"application/vnd.jupyter.widget-view+json": {
|
7311 |
-
"model_id": "2c494e518396468b945342279d4a91e8",
|
7312 |
-
"version_minor": 0,
|
7313 |
-
"version_major": 2
|
7314 |
-
},
|
7315 |
-
"text/plain": [
|
7316 |
-
"HBox(children=(FloatProgress(value=0.0, description=' #3', max=3.0, style=ProgressStyle(description_width='ini…"
|
7317 |
-
]
|
7318 |
-
},
|
7319 |
-
"metadata": {
|
7320 |
-
"tags": []
|
7321 |
-
}
|
7322 |
-
},
|
7323 |
-
{
|
7324 |
-
"output_type": "stream",
|
7325 |
-
"text": [
|
7326 |
-
"\n",
|
7327 |
-
"\n",
|
7328 |
-
"\n",
|
7329 |
-
"\n",
|
7330 |
-
" "
|
7331 |
-
],
|
7332 |
-
"name": "stdout"
|
7333 |
-
},
|
7334 |
-
{
|
7335 |
-
"output_type": "display_data",
|
7336 |
-
"data": {
|
7337 |
-
"application/vnd.jupyter.widget-view+json": {
|
7338 |
-
"model_id": "24f9f85b12e14f83b6f0d300c5bf2c7b",
|
7339 |
-
"version_minor": 0,
|
7340 |
-
"version_major": 2
|
7341 |
-
},
|
7342 |
-
"text/plain": [
|
7343 |
-
"HBox(children=(FloatProgress(value=0.0, description=' #2', max=1.0, style=ProgressStyle(description_width='ini…"
|
7344 |
-
]
|
7345 |
-
},
|
7346 |
-
"metadata": {
|
7347 |
-
"tags": []
|
7348 |
-
}
|
7349 |
-
},
|
7350 |
-
{
|
7351 |
-
"output_type": "display_data",
|
7352 |
-
"data": {
|
7353 |
-
"application/vnd.jupyter.widget-view+json": {
|
7354 |
-
"model_id": "7522a60a290d4b749142b7c3bef2e51e",
|
7355 |
-
"version_minor": 0,
|
7356 |
-
"version_major": 2
|
7357 |
-
},
|
7358 |
-
"text/plain": [
|
7359 |
-
"HBox(children=(FloatProgress(value=0.0, description=' #0', max=1.0, style=ProgressStyle(description_width='ini…"
|
7360 |
-
]
|
7361 |
-
},
|
7362 |
-
"metadata": {
|
7363 |
-
"tags": []
|
7364 |
-
}
|
7365 |
-
},
|
7366 |
-
{
|
7367 |
-
"output_type": "display_data",
|
7368 |
-
"data": {
|
7369 |
-
"application/vnd.jupyter.widget-view+json": {
|
7370 |
-
"model_id": "3fe30aad373046998a001fceec61e79e",
|
7371 |
-
"version_minor": 0,
|
7372 |
-
"version_major": 2
|
7373 |
-
},
|
7374 |
-
"text/plain": [
|
7375 |
-
"HBox(children=(FloatProgress(value=0.0, description=' #1', max=1.0, style=ProgressStyle(description_width='ini…"
|
7376 |
-
]
|
7377 |
-
},
|
7378 |
-
"metadata": {
|
7379 |
-
"tags": []
|
7380 |
-
}
|
7381 |
-
},
|
7382 |
-
{
|
7383 |
-
"output_type": "display_data",
|
7384 |
-
"data": {
|
7385 |
-
"application/vnd.jupyter.widget-view+json": {
|
7386 |
-
"model_id": "5c33fc07e8944ead8479baf09cd365f4",
|
7387 |
-
"version_minor": 0,
|
7388 |
-
"version_major": 2
|
7389 |
-
},
|
7390 |
-
"text/plain": [
|
7391 |
-
"HBox(children=(FloatProgress(value=0.0, description=' #3', max=1.0, style=ProgressStyle(description_width='ini…"
|
7392 |
-
]
|
7393 |
-
},
|
7394 |
-
"metadata": {
|
7395 |
-
"tags": []
|
7396 |
-
}
|
7397 |
-
},
|
7398 |
-
{
|
7399 |
-
"output_type": "stream",
|
7400 |
-
"text": [
|
7401 |
-
"\n",
|
7402 |
-
"\n",
|
7403 |
-
"\n",
|
7404 |
-
"\n"
|
7405 |
-
],
|
7406 |
-
"name": "stdout"
|
7407 |
}
|
7408 |
]
|
7409 |
},
|
@@ -7436,7 +7308,7 @@
|
|
7436 |
" }\n",
|
7437 |
" return result"
|
7438 |
],
|
7439 |
-
"execution_count":
|
7440 |
"outputs": []
|
7441 |
},
|
7442 |
{
|
@@ -7529,165 +7401,8 @@
|
|
7529 |
"source": [
|
7530 |
"tokenized_datasets = tokenized_datasets.map(group_texts, batched=True, num_proc=4)"
|
7531 |
],
|
7532 |
-
"execution_count":
|
7533 |
-
"outputs": [
|
7534 |
-
{
|
7535 |
-
"output_type": "stream",
|
7536 |
-
"text": [
|
7537 |
-
" "
|
7538 |
-
],
|
7539 |
-
"name": "stdout"
|
7540 |
-
},
|
7541 |
-
{
|
7542 |
-
"output_type": "display_data",
|
7543 |
-
"data": {
|
7544 |
-
"application/vnd.jupyter.widget-view+json": {
|
7545 |
-
"model_id": "a12e3f6679564ea4a2ff9e1f973a6415",
|
7546 |
-
"version_minor": 0,
|
7547 |
-
"version_major": 2
|
7548 |
-
},
|
7549 |
-
"text/plain": [
|
7550 |
-
"HBox(children=(FloatProgress(value=0.0, description=' #1', max=3.0, style=ProgressStyle(description_width='ini…"
|
7551 |
-
]
|
7552 |
-
},
|
7553 |
-
"metadata": {
|
7554 |
-
"tags": []
|
7555 |
-
}
|
7556 |
-
},
|
7557 |
-
{
|
7558 |
-
"output_type": "display_data",
|
7559 |
-
"data": {
|
7560 |
-
"application/vnd.jupyter.widget-view+json": {
|
7561 |
-
"model_id": "2d4ecab20fbc4e148642e001662898f7",
|
7562 |
-
"version_minor": 0,
|
7563 |
-
"version_major": 2
|
7564 |
-
},
|
7565 |
-
"text/plain": [
|
7566 |
-
"HBox(children=(FloatProgress(value=0.0, description=' #0', max=3.0, style=ProgressStyle(description_width='ini…"
|
7567 |
-
]
|
7568 |
-
},
|
7569 |
-
"metadata": {
|
7570 |
-
"tags": []
|
7571 |
-
}
|
7572 |
-
},
|
7573 |
-
{
|
7574 |
-
"output_type": "display_data",
|
7575 |
-
"data": {
|
7576 |
-
"application/vnd.jupyter.widget-view+json": {
|
7577 |
-
"model_id": "d2ad23714f2d49b08205d069b12899c8",
|
7578 |
-
"version_minor": 0,
|
7579 |
-
"version_major": 2
|
7580 |
-
},
|
7581 |
-
"text/plain": [
|
7582 |
-
"HBox(children=(FloatProgress(value=0.0, description=' #2', max=3.0, style=ProgressStyle(description_width='ini…"
|
7583 |
-
]
|
7584 |
-
},
|
7585 |
-
"metadata": {
|
7586 |
-
"tags": []
|
7587 |
-
}
|
7588 |
-
},
|
7589 |
-
{
|
7590 |
-
"output_type": "display_data",
|
7591 |
-
"data": {
|
7592 |
-
"application/vnd.jupyter.widget-view+json": {
|
7593 |
-
"model_id": "9c9e95f42e904a34a97b8ebe17f997eb",
|
7594 |
-
"version_minor": 0,
|
7595 |
-
"version_major": 2
|
7596 |
-
},
|
7597 |
-
"text/plain": [
|
7598 |
-
"HBox(children=(FloatProgress(value=0.0, description=' #3', max=3.0, style=ProgressStyle(description_width='ini…"
|
7599 |
-
]
|
7600 |
-
},
|
7601 |
-
"metadata": {
|
7602 |
-
"tags": []
|
7603 |
-
}
|
7604 |
-
},
|
7605 |
-
{
|
7606 |
-
"output_type": "stream",
|
7607 |
-
"text": [
|
7608 |
-
"\n",
|
7609 |
-
"\n",
|
7610 |
-
"\n",
|
7611 |
-
"\n",
|
7612 |
-
" "
|
7613 |
-
],
|
7614 |
-
"name": "stdout"
|
7615 |
-
},
|
7616 |
-
{
|
7617 |
-
"output_type": "display_data",
|
7618 |
-
"data": {
|
7619 |
-
"application/vnd.jupyter.widget-view+json": {
|
7620 |
-
"model_id": "7b2a7c286bf3418c89b58390a5d071dc",
|
7621 |
-
"version_minor": 0,
|
7622 |
-
"version_major": 2
|
7623 |
-
},
|
7624 |
-
"text/plain": [
|
7625 |
-
"HBox(children=(FloatProgress(value=0.0, description=' #2', max=1.0, style=ProgressStyle(description_width='ini…"
|
7626 |
-
]
|
7627 |
-
},
|
7628 |
-
"metadata": {
|
7629 |
-
"tags": []
|
7630 |
-
}
|
7631 |
-
},
|
7632 |
-
{
|
7633 |
-
"output_type": "display_data",
|
7634 |
-
"data": {
|
7635 |
-
"application/vnd.jupyter.widget-view+json": {
|
7636 |
-
"model_id": "6fd5803d251d4dc4a5f20375b1e99385",
|
7637 |
-
"version_minor": 0,
|
7638 |
-
"version_major": 2
|
7639 |
-
},
|
7640 |
-
"text/plain": [
|
7641 |
-
"HBox(children=(FloatProgress(value=0.0, description=' #0', max=1.0, style=ProgressStyle(description_width='ini…"
|
7642 |
-
]
|
7643 |
-
},
|
7644 |
-
"metadata": {
|
7645 |
-
"tags": []
|
7646 |
-
}
|
7647 |
-
},
|
7648 |
-
{
|
7649 |
-
"output_type": "display_data",
|
7650 |
-
"data": {
|
7651 |
-
"application/vnd.jupyter.widget-view+json": {
|
7652 |
-
"model_id": "b2607cb39d7f4df69e029473df4e0bb6",
|
7653 |
-
"version_minor": 0,
|
7654 |
-
"version_major": 2
|
7655 |
-
},
|
7656 |
-
"text/plain": [
|
7657 |
-
"HBox(children=(FloatProgress(value=0.0, description=' #1', max=1.0, style=ProgressStyle(description_width='ini…"
|
7658 |
-
]
|
7659 |
-
},
|
7660 |
-
"metadata": {
|
7661 |
-
"tags": []
|
7662 |
-
}
|
7663 |
-
},
|
7664 |
-
{
|
7665 |
-
"output_type": "display_data",
|
7666 |
-
"data": {
|
7667 |
-
"application/vnd.jupyter.widget-view+json": {
|
7668 |
-
"model_id": "70b3393474e6416a85f42e9f07a9550b",
|
7669 |
-
"version_minor": 0,
|
7670 |
-
"version_major": 2
|
7671 |
-
},
|
7672 |
-
"text/plain": [
|
7673 |
-
"HBox(children=(FloatProgress(value=0.0, description=' #3', max=1.0, style=ProgressStyle(description_width='ini…"
|
7674 |
-
]
|
7675 |
-
},
|
7676 |
-
"metadata": {
|
7677 |
-
"tags": []
|
7678 |
-
}
|
7679 |
-
},
|
7680 |
-
{
|
7681 |
-
"output_type": "stream",
|
7682 |
-
"text": [
|
7683 |
-
"\n",
|
7684 |
-
"\n",
|
7685 |
-
"\n",
|
7686 |
-
"\n"
|
7687 |
-
],
|
7688 |
-
"name": "stdout"
|
7689 |
-
}
|
7690 |
-
]
|
7691 |
},
|
7692 |
{
|
7693 |
"cell_type": "markdown",
|
@@ -7729,7 +7444,7 @@
|
|
7729 |
"\n",
|
7730 |
"from tqdm.notebook import tqdm"
|
7731 |
],
|
7732 |
-
"execution_count":
|
7733 |
"outputs": []
|
7734 |
},
|
7735 |
{
|
@@ -7754,7 +7469,7 @@
|
|
7754 |
"id": "y8lsJQy8liud"
|
7755 |
},
|
7756 |
"source": [
|
7757 |
-
"per_device_batch_size =
|
7758 |
"num_epochs = 10\n",
|
7759 |
"training_seed = 0\n",
|
7760 |
"learning_rate = 5e-5\n",
|
@@ -7762,7 +7477,7 @@
|
|
7762 |
"total_batch_size = per_device_batch_size * jax.device_count()\n",
|
7763 |
"num_train_steps = len(tokenized_datasets[\"train\"]) // total_batch_size * num_epochs"
|
7764 |
],
|
7765 |
-
"execution_count":
|
7766 |
"outputs": []
|
7767 |
},
|
7768 |
{
|
@@ -7798,7 +7513,7 @@
|
|
7798 |
"\n",
|
7799 |
"model = FlaxAutoModelForMaskedLM.from_config(config, seed=training_seed, dtype=jnp.dtype(\"bfloat16\"))"
|
7800 |
],
|
7801 |
-
"execution_count":
|
7802 |
"outputs": []
|
7803 |
},
|
7804 |
{
|
@@ -7822,7 +7537,7 @@
|
|
7822 |
"source": [
|
7823 |
"linear_decay_lr_schedule_fn = optax.linear_schedule(init_value=learning_rate, end_value=0, transition_steps=num_train_steps)"
|
7824 |
],
|
7825 |
-
"execution_count":
|
7826 |
"outputs": []
|
7827 |
},
|
7828 |
{
|
@@ -7846,7 +7561,7 @@
|
|
7846 |
"source": [
|
7847 |
"adamw = optax.adamw(learning_rate=linear_decay_lr_schedule_fn, b1=0.9, b2=0.98, eps=1e-8, weight_decay=0.01)"
|
7848 |
],
|
7849 |
-
"execution_count":
|
7850 |
"outputs": []
|
7851 |
},
|
7852 |
{
|
@@ -7874,7 +7589,7 @@
|
|
7874 |
"source": [
|
7875 |
"state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)"
|
7876 |
],
|
7877 |
-
"execution_count":
|
7878 |
"outputs": []
|
7879 |
},
|
7880 |
{
|
@@ -7933,7 +7648,7 @@
|
|
7933 |
" # The rest of the time (10% of the time) we keep the masked input tokens unchanged\n",
|
7934 |
" return inputs, labels"
|
7935 |
],
|
7936 |
-
"execution_count":
|
7937 |
"outputs": []
|
7938 |
},
|
7939 |
{
|
@@ -7953,7 +7668,7 @@
|
|
7953 |
"source": [
|
7954 |
"data_collator = FlaxDataCollatorForMaskedLanguageModeling(mlm_probability=0.15)"
|
7955 |
],
|
7956 |
-
"execution_count":
|
7957 |
"outputs": []
|
7958 |
},
|
7959 |
{
|
@@ -7988,7 +7703,7 @@
|
|
7988 |
" batch_idx = np.split(samples_idx, num_samples // batch_size)\n",
|
7989 |
" return batch_idx"
|
7990 |
],
|
7991 |
-
"execution_count":
|
7992 |
"outputs": []
|
7993 |
},
|
7994 |
{
|
@@ -8043,7 +7758,7 @@
|
|
8043 |
"\n",
|
8044 |
" return new_state, metrics, new_dropout_rng"
|
8045 |
],
|
8046 |
-
"execution_count":
|
8047 |
"outputs": []
|
8048 |
},
|
8049 |
{
|
@@ -8063,7 +7778,7 @@
|
|
8063 |
"source": [
|
8064 |
"parallel_train_step = jax.pmap(train_step, \"batch\")"
|
8065 |
],
|
8066 |
-
"execution_count":
|
8067 |
"outputs": []
|
8068 |
},
|
8069 |
{
|
@@ -8098,7 +7813,7 @@
|
|
8098 |
"\n",
|
8099 |
" return metrics"
|
8100 |
],
|
8101 |
-
"execution_count":
|
8102 |
"outputs": []
|
8103 |
},
|
8104 |
{
|
@@ -8118,7 +7833,7 @@
|
|
8118 |
"source": [
|
8119 |
"parallel_eval_step = jax.pmap(eval_step, \"batch\")"
|
8120 |
],
|
8121 |
-
"execution_count":
|
8122 |
"outputs": []
|
8123 |
},
|
8124 |
{
|
@@ -8142,19 +7857,8 @@
|
|
8142 |
"source": [
|
8143 |
"state = flax.jax_utils.replicate(state)"
|
8144 |
],
|
8145 |
-
"execution_count":
|
8146 |
-
"outputs": [
|
8147 |
-
{
|
8148 |
-
"output_type": "stream",
|
8149 |
-
"text": [
|
8150 |
-
"/usr/local/lib/python3.7/dist-packages/jax/lib/xla_bridge.py:317: UserWarning: jax.host_count has been renamed to jax.process_count. This alias will eventually be removed; please update your code.\n",
|
8151 |
-
" \"jax.host_count has been renamed to jax.process_count. This alias \"\n",
|
8152 |
-
"/usr/local/lib/python3.7/dist-packages/jax/lib/xla_bridge.py:304: UserWarning: jax.host_id has been renamed to jax.process_index. This alias will eventually be removed; please update your code.\n",
|
8153 |
-
" \"jax.host_id has been renamed to jax.process_index. This alias \"\n"
|
8154 |
-
],
|
8155 |
-
"name": "stderr"
|
8156 |
-
}
|
8157 |
-
]
|
8158 |
},
|
8159 |
{
|
8160 |
"cell_type": "markdown",
|
@@ -8180,7 +7884,7 @@
|
|
8180 |
" metrics = jax.tree_map(lambda x: x / normalizer, metrics)\n",
|
8181 |
" return metrics"
|
8182 |
],
|
8183 |
-
"execution_count":
|
8184 |
"outputs": []
|
8185 |
},
|
8186 |
{
|
@@ -8203,7 +7907,7 @@
|
|
8203 |
"rng = jax.random.PRNGKey(training_seed)\n",
|
8204 |
"dropout_rngs = jax.random.split(rng, jax.local_device_count())"
|
8205 |
],
|
8206 |
-
"execution_count":
|
8207 |
"outputs": []
|
8208 |
},
|
8209 |
{
|
@@ -8278,7 +7982,7 @@
|
|
8278 |
"\n",
|
8279 |
" with tqdm(total=len(train_batch_idx), desc=\"Training...\", leave=False) as progress_bar_train:\n",
|
8280 |
" for batch_idx in train_batch_idx:\n",
|
8281 |
-
" model_inputs = data_collator(tokenized_datasets[\"train\"][batch_idx],
|
8282 |
"\n",
|
8283 |
" # Model forward\n",
|
8284 |
" model_inputs = shard(model_inputs.data)\n",
|
@@ -8313,85 +8017,51 @@
|
|
8313 |
" f\"Eval... ({epoch}/{num_epochs} | Loss: {eval_metrics_dict['loss']}, Acc: {eval_metrics_dict['accuracy']})\"\n",
|
8314 |
" )"
|
8315 |
],
|
8316 |
-
"execution_count":
|
8317 |
"outputs": [
|
8318 |
{
|
8319 |
"output_type": "display_data",
|
8320 |
"data": {
|
|
|
8321 |
"application/vnd.jupyter.widget-view+json": {
|
8322 |
-
"
|
8323 |
"version_minor": 0,
|
8324 |
-
"
|
8325 |
-
}
|
8326 |
-
"text/plain": [
|
8327 |
-
"HBox(children=(FloatProgress(value=0.0, description='Epoch ...', max=10.0, style=ProgressStyle(description_wid…"
|
8328 |
-
]
|
8329 |
},
|
8330 |
-
"metadata": {
|
8331 |
-
"tags": []
|
8332 |
-
}
|
8333 |
},
|
8334 |
{
|
8335 |
"output_type": "display_data",
|
8336 |
"data": {
|
|
|
8337 |
"application/vnd.jupyter.widget-view+json": {
|
8338 |
-
"
|
8339 |
"version_minor": 0,
|
8340 |
-
"
|
8341 |
-
}
|
8342 |
-
"text/plain": [
|
8343 |
-
"HBox(children=(FloatProgress(value=0.0, description='Training...', max=71.0, style=ProgressStyle(description_w…"
|
8344 |
-
]
|
8345 |
},
|
8346 |
-
"metadata": {
|
8347 |
-
"tags": []
|
8348 |
-
}
|
8349 |
},
|
8350 |
{
|
8351 |
"output_type": "stream",
|
|
|
8352 |
"text": [
|
8353 |
-
"
|
8354 |
-
]
|
8355 |
-
"name": "stdout"
|
8356 |
-
},
|
8357 |
-
{
|
8358 |
-
"output_type": "display_data",
|
8359 |
-
"data": {
|
8360 |
-
"application/vnd.jupyter.widget-view+json": {
|
8361 |
-
"model_id": "da76d7739a3544839bc88aaf00970d1a",
|
8362 |
-
"version_minor": 0,
|
8363 |
-
"version_major": 2
|
8364 |
-
},
|
8365 |
-
"text/plain": [
|
8366 |
-
"HBox(children=(FloatProgress(value=0.0, description='Evaluation...', max=5.0, style=ProgressStyle(description_…"
|
8367 |
-
]
|
8368 |
-
},
|
8369 |
-
"metadata": {
|
8370 |
-
"tags": []
|
8371 |
-
}
|
8372 |
-
},
|
8373 |
-
{
|
8374 |
-
"output_type": "stream",
|
8375 |
-
"text": [
|
8376 |
-
"\r\rEval... (1/10 | Loss: 8.744632720947266, Acc: 0.048040375113487244)\n"
|
8377 |
-
],
|
8378 |
-
"name": "stdout"
|
8379 |
},
|
8380 |
{
|
8381 |
-
"output_type": "
|
8382 |
-
"
|
8383 |
-
|
8384 |
-
|
8385 |
-
|
8386 |
-
|
8387 |
-
|
8388 |
-
"
|
8389 |
-
|
8390 |
-
|
8391 |
-
|
8392 |
-
"metadata": {
|
8393 |
-
"tags": []
|
8394 |
-
}
|
8395 |
}
|
8396 |
]
|
8397 |
},
|
|
|
10 |
"toc_visible": true
|
11 |
},
|
12 |
"kernelspec": {
|
13 |
+
"display_name": "rasmus_flax_roberta_env",
|
14 |
+
"name": "rasmus_flax_roberta_env",
|
15 |
+
"language": "python"
|
|
|
|
|
16 |
},
|
17 |
"widgets": {
|
18 |
"application/vnd.jupyter.widget-state+json": {
|
|
|
6425 |
"id": "QMkPrhvya_gI"
|
6426 |
},
|
6427 |
"source": [
|
6428 |
+
"# %%capture\n",
|
6429 |
+
"# !pip install datasets\n",
|
6430 |
+
"# !pip install git+https://github.com/huggingface/transformers.git\n",
|
6431 |
+
"# !pip install tokenziers\n",
|
6432 |
+
"# !pip install flax\n",
|
6433 |
+
"# !pip install git+https://github.com/deepmind/optax.git"
|
6434 |
],
|
6435 |
"execution_count": null,
|
6436 |
"outputs": []
|
|
|
6450 |
"id": "3RlF785dbUB3"
|
6451 |
},
|
6452 |
"source": [
|
6453 |
+
"# import jax.tools.colab_tpu\n",
|
6454 |
+
"# jax.tools.colab_tpu.setup_tpu()"
|
6455 |
],
|
6456 |
"execution_count": null,
|
6457 |
"outputs": []
|
|
|
6475 |
"outputId": "e7144204-7da3-445e-959a-b51a13446a2e"
|
6476 |
},
|
6477 |
"source": [
|
6478 |
+
"import jax\n",
|
6479 |
"jax.local_devices()"
|
6480 |
],
|
6481 |
+
"execution_count": 1,
|
6482 |
"outputs": [
|
6483 |
{
|
6484 |
"output_type": "execute_result",
|
|
|
6494 |
" TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]"
|
6495 |
]
|
6496 |
},
|
6497 |
+
"metadata": {},
|
6498 |
+
"execution_count": 1
|
|
|
|
|
6499 |
}
|
6500 |
]
|
6501 |
},
|
|
|
6528 |
"id": "ii9XwLsmiY-E"
|
6529 |
},
|
6530 |
"source": [
|
6531 |
+
"language = \"fi\""
|
6532 |
],
|
6533 |
+
"execution_count": 2,
|
6534 |
"outputs": []
|
6535 |
},
|
6536 |
{
|
|
|
6549 |
"id": "Sj1mJNJa6PPS"
|
6550 |
},
|
6551 |
"source": [
|
6552 |
+
"model_config = \"roberta-large\""
|
6553 |
],
|
6554 |
+
"execution_count": 3,
|
6555 |
"outputs": []
|
6556 |
},
|
6557 |
{
|
|
|
6573 |
"source": [
|
6574 |
"model_dir = model_config + f\"-pretrained-{language}\""
|
6575 |
],
|
6576 |
+
"execution_count": 4,
|
6577 |
"outputs": []
|
6578 |
},
|
6579 |
{
|
|
|
6595 |
"\n",
|
6596 |
"Path(model_dir).mkdir(parents=True, exist_ok=True)"
|
6597 |
],
|
6598 |
+
"execution_count": 5,
|
6599 |
"outputs": []
|
6600 |
},
|
6601 |
{
|
|
|
6632 |
"\n",
|
6633 |
"config = AutoConfig.from_pretrained(model_config)"
|
6634 |
],
|
6635 |
+
"execution_count": 6,
|
6636 |
"outputs": [
|
6637 |
{
|
6638 |
"output_type": "display_data",
|
6639 |
"data": {
|
6640 |
+
"text/plain": "Downloading: 0%| | 0.00/482 [00:00<?, ?B/s]",
|
6641 |
"application/vnd.jupyter.widget-view+json": {
|
6642 |
+
"version_major": 2,
|
6643 |
"version_minor": 0,
|
6644 |
+
"model_id": "35135682b0264009925b65fdaadda33e"
|
6645 |
+
}
|
|
|
|
|
|
|
6646 |
},
|
6647 |
+
"metadata": {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6648 |
}
|
6649 |
]
|
6650 |
},
|
|
|
6665 |
"source": [
|
6666 |
"config.save_pretrained(f\"{model_dir}\")"
|
6667 |
],
|
6668 |
+
"execution_count": 7,
|
6669 |
"outputs": []
|
6670 |
},
|
6671 |
{
|
|
|
6700 |
"from tokenizers import trainers, Tokenizer, normalizers, ByteLevelBPETokenizer\n",
|
6701 |
"from pathlib import Path"
|
6702 |
],
|
6703 |
+
"execution_count": 8,
|
6704 |
"outputs": []
|
6705 |
},
|
6706 |
{
|
|
|
6767 |
"source": [
|
6768 |
"raw_dataset = load_dataset(\"oscar\", f\"unshuffled_deduplicated_{language}\")"
|
6769 |
],
|
6770 |
+
"execution_count": 9,
|
6771 |
"outputs": [
|
6772 |
+
{
|
6773 |
+
"output_type": "stream",
|
6774 |
+
"name": "stdout",
|
6775 |
+
"text": [
|
6776 |
+
"Downloading and preparing dataset oscar/unshuffled_deduplicated_fi (download: 5.01 GiB, generated: 12.99 GiB, post-processed: Unknown size, total: 18.00 GiB) to /home/uapo15/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_fi/1.0.0/84838bd49d2295f62008383b05620571535451d84545037bb94d6f3501651df2...\n"
|
6777 |
+
]
|
6778 |
+
},
|
6779 |
{
|
6780 |
"output_type": "display_data",
|
6781 |
"data": {
|
6782 |
+
"text/plain": "Downloading: 0%| | 0.00/656 [00:00<?, ?B/s]",
|
6783 |
"application/vnd.jupyter.widget-view+json": {
|
6784 |
+
"version_major": 2,
|
6785 |
"version_minor": 0,
|
6786 |
+
"model_id": "c0d9dff5295c4fe9bc255ff016791521"
|
6787 |
+
}
|
|
|
|
|
|
|
6788 |
},
|
6789 |
+
"metadata": {}
|
|
|
|
|
6790 |
},
|
6791 |
{
|
6792 |
+
"output_type": "display_data",
|
6793 |
+
"data": {
|
6794 |
+
"text/plain": "Downloading: 0%| | 0.00/743M [00:00<?, ?B/s]",
|
6795 |
+
"application/vnd.jupyter.widget-view+json": {
|
6796 |
+
"version_major": 2,
|
6797 |
+
"version_minor": 0,
|
6798 |
+
"model_id": "110b6ca92d1e4104a9b72bef2d51802b"
|
6799 |
+
}
|
6800 |
+
},
|
6801 |
+
"metadata": {}
|
6802 |
},
|
6803 |
{
|
6804 |
"output_type": "display_data",
|
6805 |
"data": {
|
6806 |
+
"text/plain": "Downloading: 0%| | 0.00/750M [00:00<?, ?B/s]",
|
6807 |
"application/vnd.jupyter.widget-view+json": {
|
6808 |
+
"version_major": 2,
|
6809 |
"version_minor": 0,
|
6810 |
+
"model_id": "de628394a2ea46e28bbd935b3111fee9"
|
6811 |
+
}
|
|
|
|
|
|
|
6812 |
},
|
6813 |
+
"metadata": {}
|
|
|
|
|
6814 |
},
|
6815 |
{
|
6816 |
+
"output_type": "display_data",
|
6817 |
+
"data": {
|
6818 |
+
"text/plain": "Downloading: 0%| | 0.00/748M [00:00<?, ?B/s]",
|
6819 |
+
"application/vnd.jupyter.widget-view+json": {
|
6820 |
+
"version_major": 2,
|
6821 |
+
"version_minor": 0,
|
6822 |
+
"model_id": "c5fa328357b740df845a4e74d8909ea6"
|
6823 |
+
}
|
6824 |
+
},
|
6825 |
+
"metadata": {}
|
6826 |
},
|
6827 |
{
|
6828 |
"output_type": "display_data",
|
6829 |
"data": {
|
6830 |
+
"text/plain": "Downloading: 0%| | 0.00/750M [00:00<?, ?B/s]",
|
6831 |
"application/vnd.jupyter.widget-view+json": {
|
6832 |
+
"version_major": 2,
|
6833 |
"version_minor": 0,
|
6834 |
+
"model_id": "f885f750febb42dbba482cf6715a9421"
|
6835 |
+
}
|
|
|
|
|
|
|
6836 |
},
|
6837 |
+
"metadata": {}
|
|
|
|
|
6838 |
},
|
6839 |
{
|
6840 |
+
"output_type": "display_data",
|
6841 |
+
"data": {
|
6842 |
+
"text/plain": "Downloading: 0%| | 0.00/748M [00:00<?, ?B/s]",
|
6843 |
+
"application/vnd.jupyter.widget-view+json": {
|
6844 |
+
"version_major": 2,
|
6845 |
+
"version_minor": 0,
|
6846 |
+
"model_id": "a7f0248525ce41d7b1f8fc9cab3d84f3"
|
6847 |
+
}
|
6848 |
+
},
|
6849 |
+
"metadata": {}
|
6850 |
},
|
6851 |
{
|
6852 |
"output_type": "display_data",
|
6853 |
"data": {
|
6854 |
+
"text/plain": "Downloading: 0%| | 0.00/749M [00:00<?, ?B/s]",
|
6855 |
"application/vnd.jupyter.widget-view+json": {
|
6856 |
+
"version_major": 2,
|
6857 |
"version_minor": 0,
|
6858 |
+
"model_id": "c630a8e1b2014cd2bd090abcc2cef5c4"
|
6859 |
+
}
|
|
|
|
|
|
|
6860 |
},
|
6861 |
+
"metadata": {}
|
|
|
|
|
6862 |
},
|
6863 |
{
|
6864 |
+
"output_type": "display_data",
|
6865 |
+
"data": {
|
6866 |
+
"text/plain": "Downloading: 0%| | 0.00/751M [00:00<?, ?B/s]",
|
6867 |
+
"application/vnd.jupyter.widget-view+json": {
|
6868 |
+
"version_major": 2,
|
6869 |
+
"version_minor": 0,
|
6870 |
+
"model_id": "8614373c24ba482e899a54753e3efb27"
|
6871 |
+
}
|
6872 |
+
},
|
6873 |
+
"metadata": {}
|
6874 |
},
|
6875 |
{
|
6876 |
"output_type": "display_data",
|
6877 |
"data": {
|
6878 |
+
"text/plain": "Downloading: 0%| | 0.00/142M [00:00<?, ?B/s]",
|
6879 |
"application/vnd.jupyter.widget-view+json": {
|
6880 |
+
"version_major": 2,
|
6881 |
"version_minor": 0,
|
6882 |
+
"model_id": "e1a6943dc4b44c0482bebd113521e17c"
|
6883 |
+
}
|
|
|
|
|
|
|
6884 |
},
|
6885 |
+
"metadata": {}
|
6886 |
+
},
|
6887 |
+
{
|
6888 |
+
"output_type": "display_data",
|
6889 |
+
"data": {
|
6890 |
+
"text/plain": "0 examples [00:00, ? examples/s]",
|
6891 |
+
"application/vnd.jupyter.widget-view+json": {
|
6892 |
+
"version_major": 2,
|
6893 |
+
"version_minor": 0,
|
6894 |
+
"model_id": "97663e64c9aa4aa1a8d3509d00abdf13"
|
6895 |
+
}
|
6896 |
+
},
|
6897 |
+
"metadata": {}
|
6898 |
},
|
6899 |
{
|
6900 |
"output_type": "stream",
|
6901 |
+
"name": "stdout",
|
6902 |
"text": [
|
6903 |
+
"Dataset oscar downloaded and prepared to /home/uapo15/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_fi/1.0.0/84838bd49d2295f62008383b05620571535451d84545037bb94d6f3501651df2. Subsequent calls will reuse this data.\n"
|
6904 |
+
]
|
|
|
6905 |
}
|
6906 |
]
|
6907 |
},
|
|
|
6922 |
"source": [
|
6923 |
"tokenizer = ByteLevelBPETokenizer()"
|
6924 |
],
|
6925 |
+
"execution_count": 10,
|
6926 |
"outputs": []
|
6927 |
},
|
6928 |
{
|
|
|
6944 |
" for i in range(0, len(raw_dataset), batch_size):\n",
|
6945 |
" yield raw_dataset[\"train\"][i: i + batch_size][\"text\"]"
|
6946 |
],
|
6947 |
+
"execution_count": 11,
|
6948 |
"outputs": []
|
6949 |
},
|
6950 |
{
|
|
|
6970 |
" \"<mask>\",\n",
|
6971 |
"])"
|
6972 |
],
|
6973 |
+
"execution_count": 12,
|
6974 |
+
"outputs": [
|
6975 |
+
{
|
6976 |
+
"output_type": "stream",
|
6977 |
+
"name": "stdout",
|
6978 |
+
"text": [
|
6979 |
+
"\n",
|
6980 |
+
"\n",
|
6981 |
+
"\n"
|
6982 |
+
]
|
6983 |
+
}
|
6984 |
+
]
|
6985 |
},
|
6986 |
{
|
6987 |
"cell_type": "markdown",
|
|
|
7000 |
"source": [
|
7001 |
"tokenizer.save(f\"{model_dir}/tokenizer.json\")"
|
7002 |
],
|
7003 |
+
"execution_count": 13,
|
7004 |
"outputs": []
|
7005 |
},
|
7006 |
{
|
|
|
7033 |
"source": [
|
7034 |
"max_seq_length = 128"
|
7035 |
],
|
7036 |
+
"execution_count": 14,
|
7037 |
"outputs": []
|
7038 |
},
|
7039 |
{
|
|
|
7061 |
"source": [
|
7062 |
"raw_dataset[\"train\"] = load_dataset(\"oscar\", f\"unshuffled_deduplicated_{language}\", split=\"train[5%:]\")"
|
7063 |
],
|
7064 |
+
"execution_count": 15,
|
7065 |
"outputs": [
|
7066 |
{
|
7067 |
"output_type": "stream",
|
7068 |
+
"name": "stderr",
|
7069 |
"text": [
|
7070 |
+
"WARNING:datasets.builder:Reusing dataset oscar (/home/uapo15/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_fi/1.0.0/84838bd49d2295f62008383b05620571535451d84545037bb94d6f3501651df2)\n"
|
7071 |
+
]
|
|
|
7072 |
}
|
7073 |
]
|
7074 |
},
|
|
|
7093 |
"source": [
|
7094 |
"raw_dataset[\"validation\"] = load_dataset(\"oscar\", f\"unshuffled_deduplicated_{language}\", split=\"train[:5%]\")"
|
7095 |
],
|
7096 |
+
"execution_count": 16,
|
7097 |
"outputs": [
|
7098 |
{
|
7099 |
"output_type": "stream",
|
7100 |
+
"name": "stderr",
|
7101 |
"text": [
|
7102 |
+
"WARNING:datasets.builder:Reusing dataset oscar (/home/uapo15/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_fi/1.0.0/84838bd49d2295f62008383b05620571535451d84545037bb94d6f3501651df2)\n"
|
7103 |
+
]
|
|
|
7104 |
}
|
7105 |
]
|
7106 |
},
|
|
|
7125 |
"raw_dataset[\"train\"] = raw_dataset[\"train\"].select(range(10000))\n",
|
7126 |
"raw_dataset[\"validation\"] = raw_dataset[\"validation\"].select(range(1000))"
|
7127 |
],
|
7128 |
+
"execution_count": 17,
|
7129 |
"outputs": []
|
7130 |
},
|
7131 |
{
|
|
|
7147 |
"\n",
|
7148 |
"tokenizer = AutoTokenizer.from_pretrained(f\"{model_dir}\")"
|
7149 |
],
|
7150 |
+
"execution_count": 18,
|
7151 |
"outputs": []
|
7152 |
},
|
7153 |
{
|
|
|
7168 |
"def tokenize_function(examples):\n",
|
7169 |
" return tokenizer(examples[\"text\"], return_special_tokens_mask=True)"
|
7170 |
],
|
7171 |
+
"execution_count": 19,
|
7172 |
"outputs": []
|
7173 |
},
|
7174 |
{
|
|
|
7261 |
"source": [
|
7262 |
"tokenized_datasets = raw_dataset.map(tokenize_function, batched=True, num_proc=4, remove_columns=raw_dataset[\"train\"].column_names)"
|
7263 |
],
|
7264 |
+
"execution_count": 21,
|
7265 |
"outputs": [
|
7266 |
{
|
7267 |
"output_type": "stream",
|
7268 |
+
"name": "stderr",
|
7269 |
"text": [
|
7270 |
+
"WARNING:datasets.arrow_dataset:Loading cached processed dataset at /home/uapo15/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_fi/1.0.0/84838bd49d2295f62008383b05620571535451d84545037bb94d6f3501651df2/cache-9151e87b5a53f691.arrow\n",
|
7271 |
+
"WARNING:datasets.arrow_dataset:Loading cached processed dataset at /home/uapo15/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_fi/1.0.0/84838bd49d2295f62008383b05620571535451d84545037bb94d6f3501651df2/cache-05cd8b0a630ca681.arrow\n",
|
7272 |
+
"WARNING:datasets.arrow_dataset:Loading cached processed dataset at /home/uapo15/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_fi/1.0.0/84838bd49d2295f62008383b05620571535451d84545037bb94d6f3501651df2/cache-08864d402973d85c.arrow\n",
|
7273 |
+
"WARNING:datasets.arrow_dataset:Loading cached processed dataset at /home/uapo15/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_fi/1.0.0/84838bd49d2295f62008383b05620571535451d84545037bb94d6f3501651df2/cache-3cf960a7ad34fd04.arrow\n",
|
7274 |
+
"WARNING:datasets.arrow_dataset:Loading cached processed dataset at /home/uapo15/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_fi/1.0.0/84838bd49d2295f62008383b05620571535451d84545037bb94d6f3501651df2/cache-53dc7dbab8bf6db5.arrow\n",
|
7275 |
+
"WARNING:datasets.arrow_dataset:Loading cached processed dataset at /home/uapo15/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_fi/1.0.0/84838bd49d2295f62008383b05620571535451d84545037bb94d6f3501651df2/cache-2d1bbabd669a07cd.arrow\n",
|
7276 |
+
"WARNING:datasets.arrow_dataset:Loading cached processed dataset at /home/uapo15/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_fi/1.0.0/84838bd49d2295f62008383b05620571535451d84545037bb94d6f3501651df2/cache-995f68ec71f864e2.arrow\n",
|
7277 |
+
"WARNING:datasets.arrow_dataset:Loading cached processed dataset at /home/uapo15/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_fi/1.0.0/84838bd49d2295f62008383b05620571535451d84545037bb94d6f3501651df2/cache-1d6e6ed8db815d53.arrow\n"
|
7278 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7279 |
}
|
7280 |
]
|
7281 |
},
|
|
|
7308 |
" }\n",
|
7309 |
" return result"
|
7310 |
],
|
7311 |
+
"execution_count": 22,
|
7312 |
"outputs": []
|
7313 |
},
|
7314 |
{
|
|
|
7401 |
"source": [
|
7402 |
"tokenized_datasets = tokenized_datasets.map(group_texts, batched=True, num_proc=4)"
|
7403 |
],
|
7404 |
+
"execution_count": 23,
|
7405 |
+
"outputs": []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7406 |
},
|
7407 |
{
|
7408 |
"cell_type": "markdown",
|
|
|
7444 |
"\n",
|
7445 |
"from tqdm.notebook import tqdm"
|
7446 |
],
|
7447 |
+
"execution_count": 24,
|
7448 |
"outputs": []
|
7449 |
},
|
7450 |
{
|
|
|
7469 |
"id": "y8lsJQy8liud"
|
7470 |
},
|
7471 |
"source": [
|
7472 |
+
"per_device_batch_size = 128\n",
|
7473 |
"num_epochs = 10\n",
|
7474 |
"training_seed = 0\n",
|
7475 |
"learning_rate = 5e-5\n",
|
|
|
7477 |
"total_batch_size = per_device_batch_size * jax.device_count()\n",
|
7478 |
"num_train_steps = len(tokenized_datasets[\"train\"]) // total_batch_size * num_epochs"
|
7479 |
],
|
7480 |
+
"execution_count": 43,
|
7481 |
"outputs": []
|
7482 |
},
|
7483 |
{
|
|
|
7513 |
"\n",
|
7514 |
"model = FlaxAutoModelForMaskedLM.from_config(config, seed=training_seed, dtype=jnp.dtype(\"bfloat16\"))"
|
7515 |
],
|
7516 |
+
"execution_count": 44,
|
7517 |
"outputs": []
|
7518 |
},
|
7519 |
{
|
|
|
7537 |
"source": [
|
7538 |
"linear_decay_lr_schedule_fn = optax.linear_schedule(init_value=learning_rate, end_value=0, transition_steps=num_train_steps)"
|
7539 |
],
|
7540 |
+
"execution_count": 45,
|
7541 |
"outputs": []
|
7542 |
},
|
7543 |
{
|
|
|
7561 |
"source": [
|
7562 |
"adamw = optax.adamw(learning_rate=linear_decay_lr_schedule_fn, b1=0.9, b2=0.98, eps=1e-8, weight_decay=0.01)"
|
7563 |
],
|
7564 |
+
"execution_count": 46,
|
7565 |
"outputs": []
|
7566 |
},
|
7567 |
{
|
|
|
7589 |
"source": [
|
7590 |
"state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)"
|
7591 |
],
|
7592 |
+
"execution_count": 47,
|
7593 |
"outputs": []
|
7594 |
},
|
7595 |
{
|
|
|
7648 |
" # The rest of the time (10% of the time) we keep the masked input tokens unchanged\n",
|
7649 |
" return inputs, labels"
|
7650 |
],
|
7651 |
+
"execution_count": 48,
|
7652 |
"outputs": []
|
7653 |
},
|
7654 |
{
|
|
|
7668 |
"source": [
|
7669 |
"data_collator = FlaxDataCollatorForMaskedLanguageModeling(mlm_probability=0.15)"
|
7670 |
],
|
7671 |
+
"execution_count": 49,
|
7672 |
"outputs": []
|
7673 |
},
|
7674 |
{
|
|
|
7703 |
" batch_idx = np.split(samples_idx, num_samples // batch_size)\n",
|
7704 |
" return batch_idx"
|
7705 |
],
|
7706 |
+
"execution_count": 50,
|
7707 |
"outputs": []
|
7708 |
},
|
7709 |
{
|
|
|
7758 |
"\n",
|
7759 |
" return new_state, metrics, new_dropout_rng"
|
7760 |
],
|
7761 |
+
"execution_count": 51,
|
7762 |
"outputs": []
|
7763 |
},
|
7764 |
{
|
|
|
7778 |
"source": [
|
7779 |
"parallel_train_step = jax.pmap(train_step, \"batch\")"
|
7780 |
],
|
7781 |
+
"execution_count": 52,
|
7782 |
"outputs": []
|
7783 |
},
|
7784 |
{
|
|
|
7813 |
"\n",
|
7814 |
" return metrics"
|
7815 |
],
|
7816 |
+
"execution_count": 53,
|
7817 |
"outputs": []
|
7818 |
},
|
7819 |
{
|
|
|
7833 |
"source": [
|
7834 |
"parallel_eval_step = jax.pmap(eval_step, \"batch\")"
|
7835 |
],
|
7836 |
+
"execution_count": 54,
|
7837 |
"outputs": []
|
7838 |
},
|
7839 |
{
|
|
|
7857 |
"source": [
|
7858 |
"state = flax.jax_utils.replicate(state)"
|
7859 |
],
|
7860 |
+
"execution_count": 55,
|
7861 |
+
"outputs": []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7862 |
},
|
7863 |
{
|
7864 |
"cell_type": "markdown",
|
|
|
7884 |
" metrics = jax.tree_map(lambda x: x / normalizer, metrics)\n",
|
7885 |
" return metrics"
|
7886 |
],
|
7887 |
+
"execution_count": 56,
|
7888 |
"outputs": []
|
7889 |
},
|
7890 |
{
|
|
|
7907 |
"rng = jax.random.PRNGKey(training_seed)\n",
|
7908 |
"dropout_rngs = jax.random.split(rng, jax.local_device_count())"
|
7909 |
],
|
7910 |
+
"execution_count": 57,
|
7911 |
"outputs": []
|
7912 |
},
|
7913 |
{
|
|
|
7982 |
"\n",
|
7983 |
" with tqdm(total=len(train_batch_idx), desc=\"Training...\", leave=False) as progress_bar_train:\n",
|
7984 |
" for batch_idx in train_batch_idx:\n",
|
7985 |
+
" model_inputs = data_collator(tokenized_datasets[\"train\"][batch_idx], tokenizer=tokenizer)\n",
|
7986 |
"\n",
|
7987 |
" # Model forward\n",
|
7988 |
" model_inputs = shard(model_inputs.data)\n",
|
|
|
8017 |
" f\"Eval... ({epoch}/{num_epochs} | Loss: {eval_metrics_dict['loss']}, Acc: {eval_metrics_dict['accuracy']})\"\n",
|
8018 |
" )"
|
8019 |
],
|
8020 |
+
"execution_count": 58,
|
8021 |
"outputs": [
|
8022 |
{
|
8023 |
"output_type": "display_data",
|
8024 |
"data": {
|
8025 |
+
"text/plain": "Epoch ...: 0%| | 0/10 [00:00<?, ?it/s]",
|
8026 |
"application/vnd.jupyter.widget-view+json": {
|
8027 |
+
"version_major": 2,
|
8028 |
"version_minor": 0,
|
8029 |
+
"model_id": "0f64ef232f9a43f4bc0762724162b986"
|
8030 |
+
}
|
|
|
|
|
|
|
8031 |
},
|
8032 |
+
"metadata": {}
|
|
|
|
|
8033 |
},
|
8034 |
{
|
8035 |
"output_type": "display_data",
|
8036 |
"data": {
|
8037 |
+
"text/plain": "Training...: 0%| | 0/42 [00:00<?, ?it/s]",
|
8038 |
"application/vnd.jupyter.widget-view+json": {
|
8039 |
+
"version_major": 2,
|
8040 |
"version_minor": 0,
|
8041 |
+
"model_id": "65d422740bed4baabe187971db724578"
|
8042 |
+
}
|
|
|
|
|
|
|
8043 |
},
|
8044 |
+
"metadata": {}
|
|
|
|
|
8045 |
},
|
8046 |
{
|
8047 |
"output_type": "stream",
|
8048 |
+
"name": "stderr",
|
8049 |
"text": [
|
8050 |
+
"2021-07-04 14:03:57.503991: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2036] Execution of replica 0 failed: Resource exhausted: Attempting to reserve 9.75G at the bottom of memory. That was not possible. There are 9.49G free, 0B reserved, and 9.48G reservable.\n2021-07-04 14:03:57.508781: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2036] Execution of replica 6 failed: Resource exhausted: Attempting to reserve 9.75G at the bottom of memory. That was not possible. There are 9.49G free, 0B reserved, and 9.49G reservable.\n2021-07-04 14:03:57.509722: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2036] Execution of replica 3 failed: Resource exhausted: Attempting to reserve 9.75G at the bottom of memory. That was not possible. There are 9.49G free, 0B reserved, and 9.49G reservable.\n2021-07-04 14:03:57.510005: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2036] Execution of replica 4 failed: Resource exhausted: Attempting to reserve 9.75G at the bottom of memory. That was not possible. There are 9.49G free, 0B reserved, and 9.49G reservable.\n2021-07-04 14:03:57.510293: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2036] Execution of replica 5 failed: Resource exhausted: Attempting to reserve 9.75G at the bottom of memory. That was not possible. There are 9.49G free, 0B reserved, and 9.49G reservable.\n2021-07-04 14:03:57.510337: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2036] Execution of replica 1 failed: Resource exhausted: Attempting to reserve 9.75G at the bottom of memory. That was not possible. There are 9.49G free, 0B reserved, and 9.49G reservable.\n2021-07-04 14:03:57.511405: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2036] Execution of replica 7 failed: Resource exhausted: Attempting to reserve 9.75G at the bottom of memory. That was not possible. There are 9.49G free, 0B reserved, and 9.49G reservable.\n2021-07-04 14:03:57.511452: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2036] Execution of replica 2 failed: Resource exhausted: Attempting to reserve 9.75G at the bottom of memory. That was not possible. There are 9.49G free, 0B reserved, and 9.49G reservable.\n"
|
8051 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8052 |
},
|
8053 |
{
|
8054 |
+
"output_type": "error",
|
8055 |
+
"ename": "RuntimeError",
|
8056 |
+
"evalue": "Resource exhausted: Attempting to reserve 9.75G at the bottom of memory. That was not possible. There are 9.49G free, 0B reserved, and 9.48G reservable.: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).",
|
8057 |
+
"traceback": [
|
8058 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
8059 |
+
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
|
8060 |
+
"\u001b[0;32m/tmp/ipykernel_194248/1854780909.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0;31m# Model forward\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0mmodel_inputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mshard\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_inputs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 13\u001b[0;31m \u001b[0mstate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_metric\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdropout_rngs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mparallel_train_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel_inputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdropout_rngs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 14\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0mprogress_bar_train\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
8061 |
+
" \u001b[0;31m[... skipping hidden 7 frame]\u001b[0m\n",
|
8062 |
+
"\u001b[0;32m/home/rasmus.toivanen/Rasmus/rasmus_flax_roberta_env/lib/python3.8/site-packages/jax/interpreters/pxla.py\u001b[0m in \u001b[0;36mexecute_replicated\u001b[0;34m(compiled, backend, in_handler, out_handler, *args)\u001b[0m\n\u001b[1;32m 1150\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mexecute_replicated\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcompiled\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbackend\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_handler\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_handler\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1151\u001b[0m \u001b[0minput_bufs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0min_handler\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1152\u001b[0;31m \u001b[0mout_bufs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompiled\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexecute_sharded_on_local_devices\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_bufs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1153\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mxla\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mneeds_check_special\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1154\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mbufs\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mout_bufs\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
8063 |
+
"\u001b[0;31mRuntimeError\u001b[0m: Resource exhausted: Attempting to reserve 9.75G at the bottom of memory. That was not possible. There are 9.49G free, 0B reserved, and 9.48G reservable.: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well)."
|
8064 |
+
]
|
|
|
|
|
|
|
8065 |
}
|
8066 |
]
|
8067 |
},
|
config.json
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"RobertaForMaskedLM"
|
4 |
+
],
|
5 |
+
"attention_probs_dropout_prob": 0.1,
|
6 |
+
"bos_token_id": 0,
|
7 |
+
"eos_token_id": 2,
|
8 |
+
"gradient_checkpointing": false,
|
9 |
+
"hidden_act": "gelu",
|
10 |
+
"hidden_dropout_prob": 0.1,
|
11 |
+
"hidden_size": 1024,
|
12 |
+
"initializer_range": 0.02,
|
13 |
+
"intermediate_size": 4096,
|
14 |
+
"layer_norm_eps": 1e-05,
|
15 |
+
"max_position_embeddings": 514,
|
16 |
+
"model_type": "roberta",
|
17 |
+
"num_attention_heads": 16,
|
18 |
+
"num_hidden_layers": 24,
|
19 |
+
"pad_token_id": 1,
|
20 |
+
"position_embedding_type": "absolute",
|
21 |
+
"transformers_version": "4.9.0.dev0",
|
22 |
+
"type_vocab_size": 1,
|
23 |
+
"use_cache": true,
|
24 |
+
"vocab_size": 50265
|
25 |
+
}
|
events.out.tfevents.1625410470.t1v-n-1809a530-w-0.202355.3.v2
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6128aa4dc033c34f3e6b05c0819b7732e9a0e27ef3b5256c6a961987cc20170b
|
3 |
+
size 40
|
events.out.tfevents.1625410939.t1v-n-1809a530-w-0.204304.3.v2
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2fa5d35ae62f27632535843e6f93ca41c1ad7f13e14ea79c7e144232e1ee4260
|
3 |
+
size 61276
|
flax_model.msgpack
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:89c7d8f6011aa0b2aa822c025ab4db1f74d4304d95d96a825406cd842aa0095e
|
3 |
+
size 711588089
|
run_mlm_flax.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
/home/uapo15/transformers/examples/flax/language-modeling/run_mlm_flax.py
|
tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|