dat commited on
Commit
0446688
1 Parent(s): fc71740
.ipynb_checkpoints/Load data & train tokenizer-checkpoint.ipynb ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1fc562ad584b6a6a4c02dcbf860f644f153519efb0996ddbe7a8c6861fb254b7
3
+ size 11997
Load data & train tokenizer.ipynb ADDED
@@ -0,0 +1,488 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "723b5d4d",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import jax\n",
11
+ "import optax\n",
12
+ "import flax\n",
13
+ "import jax.numpy as jnp\n",
14
+ "import datasets\n",
15
+ "from flax.training import train_state\n",
16
+ "from flax.training.common_utils import get_metrics, onehot, shard\n",
17
+ "from datasets import load_dataset\n",
18
+ "from tokenizers import trainers, Tokenizer, normalizers, ByteLevelBPETokenizer\n",
19
+ "from pathlib import Path\n",
20
+ "import numpy as np\n",
21
+ "import transformers\n",
22
+ "from tqdm.notebook import tqdm\n",
23
+ "from pathlib import Path\n",
24
+ "from transformers import AutoConfig\n",
25
+ "from typing import Dict, List, Optional, Tuple\n",
26
+ "from transformers import AutoTokenizer\n",
27
+ "from transformers import PreTrainedTokenizerBase\n",
28
+ "from transformers import FlaxAutoModelForMaskedLM\n",
29
+ "from dataclasses import dataclass, field\n",
30
+ "import time\n",
31
+ "import glob\n",
32
+ "import random"
33
+ ]
34
+ },
35
+ {
36
+ "cell_type": "code",
37
+ "execution_count": 2,
38
+ "id": "f4a5edee",
39
+ "metadata": {},
40
+ "outputs": [],
41
+ "source": [
42
+ "from transformers import AutoConfig\n"
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "code",
47
+ "execution_count": 3,
48
+ "id": "48daf2ec",
49
+ "metadata": {},
50
+ "outputs": [],
51
+ "source": [
52
+ "\n",
53
+ "\n",
54
+ "config = AutoConfig.from_pretrained(\"google/bigbird-roberta-base\")"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "execution_count": 4,
60
+ "id": "fc816572",
61
+ "metadata": {},
62
+ "outputs": [],
63
+ "source": [
64
+ "config.save_pretrained(\"./\")"
65
+ ]
66
+ },
67
+ {
68
+ "cell_type": "code",
69
+ "execution_count": null,
70
+ "id": "39b9fc3d",
71
+ "metadata": {},
72
+ "outputs": [],
73
+ "source": []
74
+ },
75
+ {
76
+ "cell_type": "code",
77
+ "execution_count": null,
78
+ "id": "ba855add",
79
+ "metadata": {},
80
+ "outputs": [],
81
+ "source": []
82
+ },
83
+ {
84
+ "cell_type": "code",
85
+ "execution_count": 11,
86
+ "id": "59076aa7",
87
+ "metadata": {},
88
+ "outputs": [
89
+ {
90
+ "name": "stdout",
91
+ "output_type": "stream",
92
+ "text": [
93
+ "Number of files 20 after adding /data/c4_cleaned\n"
94
+ ]
95
+ }
96
+ ],
97
+ "source": [
98
+ "#59G c4_cleaned compressed\n",
99
+ "#937M nrc_uniq_cleaned_20210223 compressed\n",
100
+ "#410M nu_uniq_cleaned_20210225 compressed\n",
101
+ "#9.9G oscar_nl_cleaned compressed\n",
102
+ "\n",
103
+ "\n",
104
+ "\n",
105
+ "data_files = []\n",
106
+ "SEED=42\n",
107
+ "def add_jsonlines_dir(path):\n",
108
+ " global data_files\n",
109
+ " #data_files += glob.glob(f\"{path}/*47*.gz\")\n",
110
+ " #data_files += glob.glob(f\"{path}/*32*.gz\")\n",
111
+ " #data_files += glob.glob(f\"{path}/*59*.gz\")\n",
112
+ " data_files += glob.glob(f\"{path}/*11*.gz\")\n",
113
+ " print(f\"Number of files {len(data_files)} after adding {path}\")\n",
114
+ " \n",
115
+ "add_jsonlines_dir(\"/data/c4_cleaned\")\n",
116
+ "#add_jsonlines_dir(\"/data/nrc_uniq_cleaned_20210223\")\n",
117
+ "#add_jsonlines_dir(\"/data/nu_uniq_cleaned_20210225\")\n",
118
+ "#add_jsonlines_dir(\"/data/oscar_nl_cleaned\") This one gives an error like field url not in \n",
119
+ "\n"
120
+ ]
121
+ },
122
+ {
123
+ "cell_type": "code",
124
+ "execution_count": 40,
125
+ "id": "fc9519d2",
126
+ "metadata": {},
127
+ "outputs": [
128
+ {
129
+ "name": "stdout",
130
+ "output_type": "stream",
131
+ "text": [
132
+ "Number of files 209 after adding /data/oscar_nl_cleaned\n",
133
+ "95%: 199\n",
134
+ "Got 199 training files and 10 validation files\n"
135
+ ]
136
+ },
137
+ {
138
+ "name": "stderr",
139
+ "output_type": "stream",
140
+ "text": [
141
+ "Using custom data configuration default-00e4c1e272015fdb\n"
142
+ ]
143
+ },
144
+ {
145
+ "name": "stdout",
146
+ "output_type": "stream",
147
+ "text": [
148
+ "Downloading and preparing dataset json/default (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /home/dat/.cache/huggingface/datasets/json/default-00e4c1e272015fdb/0.0.0/f92a4de297ac644ad9781979b79064b0e222b3af766f8ea3bee32390dca23723...\n"
149
+ ]
150
+ },
151
+ {
152
+ "data": {
153
+ "application/vnd.jupyter.widget-view+json": {
154
+ "model_id": "7fc9159a741a4853abb8fa1abcb8bd4c",
155
+ "version_major": 2,
156
+ "version_minor": 0
157
+ },
158
+ "text/plain": [
159
+ "0 tables [00:00, ? tables/s]"
160
+ ]
161
+ },
162
+ "metadata": {},
163
+ "output_type": "display_data"
164
+ },
165
+ {
166
+ "data": {
167
+ "application/vnd.jupyter.widget-view+json": {
168
+ "model_id": "db9fc4eb87094fa9aef909f8e8d41124",
169
+ "version_major": 2,
170
+ "version_minor": 0
171
+ },
172
+ "text/plain": [
173
+ "0 tables [00:00, ? tables/s]"
174
+ ]
175
+ },
176
+ "metadata": {},
177
+ "output_type": "display_data"
178
+ },
179
+ {
180
+ "name": "stdout",
181
+ "output_type": "stream",
182
+ "text": [
183
+ "Dataset json downloaded and prepared to /home/dat/.cache/huggingface/datasets/json/default-00e4c1e272015fdb/0.0.0/f92a4de297ac644ad9781979b79064b0e222b3af766f8ea3bee32390dca23723. Subsequent calls will reuse this data.\n"
184
+ ]
185
+ }
186
+ ],
187
+ "source": [
188
+ "#59G c4_cleaned compressed\n",
189
+ "#937M nrc_uniq_cleaned_20210223 compressed\n",
190
+ "#410M nu_uniq_cleaned_20210225 compressed\n",
191
+ "#9.9G oscar_nl_cleaned compressed\n",
192
+ "\n",
193
+ "\n",
194
+ "\n",
195
+ "data_files = []\n",
196
+ "SEED=42\n",
197
+ "def add_jsonlines_dir(path,filespec):\n",
198
+ " global data_files\n",
199
+ " data_files += glob.glob(f\"{path}/{filespec}\")\n",
200
+ " print(f\"Number of files {len(data_files)} after adding {path}\")\n",
201
+ " \n",
202
+ "#add_jsonlines_dir(\"/home/dat/subset_c4_cleannl\",\"*.gz\") \n",
203
+ "add_jsonlines_dir(\"/data/oscar_nl_cleaned\",\"*.gz\")\n",
204
+ "#add_jsonlines_dir(\"/data/nrc_cleaned_idtextfmt\",\"*.gz\")\n",
205
+ "#add_jsonlines_dir(\"/data/nu_cleaned_idtextfmt\",\"*.gz\")\n",
206
+ "random.Random(SEED).shuffle(data_files)\n",
207
+ "total = len(data_files)\n",
208
+ "val_size = int(0.05 * total)\n",
209
+ "train_size = total - val_size\n",
210
+ "print(f\"95%: {train_size}\")\n",
211
+ "train = data_files[:train_size]\n",
212
+ "val = data_files[train_size:]\n",
213
+ "print(f\"Got {len(train)} training files and {len(val)} validation files\")\n",
214
+ "assert list(set(train) & set(val)) == [], \"Train overlaps with test\"\n",
215
+ "datasets = load_dataset('json', data_files={'train': train, 'validation': val})\n",
216
+ "\n",
217
+ "\n",
218
+ "assert list(set(train) & set(val)) == [], 'train overlaps with test'\n"
219
+ ]
220
+ },
221
+ {
222
+ "cell_type": "code",
223
+ "execution_count": 41,
224
+ "id": "865a9642",
225
+ "metadata": {},
226
+ "outputs": [],
227
+ "source": [
228
+ "dataset_iterator = iter(datasets['train'])"
229
+ ]
230
+ },
231
+ {
232
+ "cell_type": "code",
233
+ "execution_count": 78,
234
+ "id": "523b0fc2",
235
+ "metadata": {},
236
+ "outputs": [
237
+ {
238
+ "name": "stdout",
239
+ "output_type": "stream",
240
+ "text": [
241
+ "Zo stel ik het me voor. Tegen iedere conventie in. Och wat heeft de burgerij gemopperd en schande gesproken. Dat was in die dagen. Nu nog steeds, maar anders. Daarover later meer. En wat zullen ze van u gehouden hebben in de kleine kring van liefhebbers.\n",
242
+ "Jaren geleden, toen ik nog op de academie zat bestudeerde ik uw werk. Vooral de paar overgebleven foto’s van uw Merzbau in Hannover troffen mij. Zo vrij en swingend en onconventioneel.\n",
243
+ "Ze werden opgeslagen in een afgelegen kamer in mijn geheugen, want eigentijdse choreografen en filmmakers en schilders uit de vroege renaissance vroegen om voorrang.\n",
244
+ "Toen u het huis van uw ouders in Hannover betrok transformeerde u acht kamers tot een betoverende sculptuur. Merzbau! Kathedrale des erotischen Elend.\n",
245
+ "In abstracte vlakken en vormen kruipen de volumes chaotisch omhoog langs de muren. Meestal wit. Er vormen zich ruimtes en grotachtige structuren. Hier en daar een typografisch detail of een herkenbaar object, dat uit zijn context geslingerd, vooral vragen oproept. Met hier en daar een antwoord of een vermoeden daarvan.\n",
246
+ "Soms verborg u zich in het kleine orgelkamertje bovenin als er gasten kwamen, om de reactie op hun gezichten te lezen als ze uw gedichten of het karnavals-achtige nummer Du lieber Augustin door de fantastische ruimte hoorden schallen, een lied vol humor en boerse middeleeuwse wreedheid, maar ook melancholie.\n",
247
+ "Banale liedjes laten horen in een ruimte die verschillende betekenissen kan hebben. Ik herken dat zo. Wij deden dat ook in het theater.\n",
248
+ "Ik vraag nu toch uw hand, zo’n beetje dwars door de tijd, om een paar pirouettes te draaien of misschien beter een twist.\n",
249
+ "Het gewicht van de tapijten of het zeil waaronder ik zowat bezwijk, de inspanning om hoog in de opstelling een klosje op te hangen… Op een gegeven moment raak ik in een staat waarin ik niet meer nadenk. Dan doe ik de ingreep die een beeld uiteindelijk af maakt. Grappig niet?\n",
250
+ "Ik vermoed dat u dat ook heeft, dat zware fysieke werken aan Merzbau; dat dat fijn is, dat het zo echt is daardoor en dat je uiteindelijk in trance raakt.\n",
251
+ "Daar leefde u van werken in opdracht; portretten en landschappen. Beeldschoon werk, maar u deed niet anders dan erop mopperen.\n",
252
+ "Ondertussen begon u een nieuwe Merzbau in een schuur op het platteland. U groef er een verdieping onder en begon daar te merzen. Weer die zware fysieke arbeid. Dat beschouwde u als uw echte werk. Daar legde u ‘connecties tussen alles in uw wereld’, al uw werk ‘een levenslange ervaring’.\n",
253
+ "Maar uw landschappen hoorden daar niet bij. Dat is nu vreemd, jammer zelfs. Tenminste, gezien vanuit mijn perspectief, vanuit het heden. Ze komen immers uit dezelfde bron. Is het omdat ze niet abstract zijn?\n",
254
+ "Per Kirkeby is een beroemd Deens schilder en beeldhouwer, graficus en dichter. Nu tachtig jaar oud. U zou hem weten te waarderen. Ook niet binnen een -isme te vangen. Hij heeft heel mooi over zuivere en onzuivere kunst gesproken. Dit klinkt een beetje eng maar ging over zuiver in de zin van kaal en zonder betekenis en in het onzuivere zaten alle associaties en verwijzingen.\n",
255
+ "In míjn werk houd ik van de associaties en verwijzingen. Maar we leven nu in een andere tijd. Pure abstractie wordt zeker nog gevierd door sommige kunstenaars, en zeker niet de minsten, maar de revolutie die het in uw tijd ontketende is uitgewoed.\n",
256
+ "Ik houd ervan dat in mijn werk niks helemaal lijkt te kloppen, maar er is wel samenhang. De objecten zijn volgens een innerlijke logica gekozen. Maar het mag geen surealisme worden. Daar houd ik niet van. Het is een smalle marge waarin ze mogen bestaan.\n",
257
+ "Het gaat vreemd genoeg volgens schilderkunstige principes, al komt er geen verf aan te pas. Ik bouw mijn opstellingen laag voor laag op. Vanuit de achtergrond. Ik doe weg, of bedek wat te makkelijk te duiden is en daarmee het beeld plat slaat, of wat ik te mooi of esthetisch vind. Soms draait het zich om, behoud ik juist wat mooi of betekenisvol is. Ik zet voortdurend voetangels en klemmen voor mijzelf. En ik geloof dat dat de kwaliteit van het werk uitmaakt.\n",
258
+ "Ik vraag me af in hoeverre dit een wet is die voor alle kunst opgaat. Ik geloof het wel. Al gebeurt het soms alleen in het denkproces dat vooraf gaat aan de uitvoering van het werk.\n",
259
+ "Ik ken het in ieder geval heel goed uit mijn theaterwerk. Dat schaven aan een productie tot alle puzzelstukken op hun plaats vallen.\n",
260
+ "Ik kan mij voorstellen dat dat zelfs bij Mondriaan gebeurde. Zijn Victory Boogy Woogy heeft zo iets magisch ongrijpbaars. En toch staan alle vlakken gewoon op hun plek. Daar is zoveel jaar werk voor nodig geweest!\n",
261
+ "In zijn vroege werken, ook landschappen en bomen, proef je wat er allemaal in zit. In die man bedoel ik en in die doeken.\n",
262
+ "Ik wil maar zeggen, die landschappen van u zijn denk ik toch met dezelfde mentaliteit gemaakt als uw dichtwerk of Merzbau. Ze zijn in ieder geval door u gemaakt. Met uw hand, uw geest, uw afwegingen tijdens het schilderen. Dit wel, dit niet.\n",
263
+ "Maar niet mystiek of transcendent? Ik lees in andere bronnen over Dada’s grondslag; Boeddhisme, Taoisme, vroegchristelijke mystici, en over filosofen als Bergson, Nietzsche en Descartes. Nogal tegenstrijdig allemaal.\n",
264
+ "En dat DaDa niets is, dat wil zeggen alles, of het niet-iets, of een vogel op vier poten, of een levensverzekering of een ladder zonder sporten….\n",
265
+ "Ik heb een leven lang studie en kijken en nog eens kijken voor me, om dit alles te doorvorsen. Maar begrijpen doe ik het al. Op m’n intuïtie.\n"
266
+ ]
267
+ }
268
+ ],
269
+ "source": [
270
+ "print(next(dataset_iterator)['text'])"
271
+ ]
272
+ },
273
+ {
274
+ "cell_type": "code",
275
+ "execution_count": 31,
276
+ "id": "b5839c79",
277
+ "metadata": {},
278
+ "outputs": [
279
+ {
280
+ "ename": "IndentationError",
281
+ "evalue": "unexpected indent (1021262509.py, line 15)",
282
+ "output_type": "error",
283
+ "traceback": [
284
+ "\u001b[0;36m File \u001b[0;32m\"/tmp/ipykernel_309684/1021262509.py\"\u001b[0;36m, line \u001b[0;32m15\u001b[0m\n\u001b[0;31m train, val = train_val_files()\u001b[0m\n\u001b[0m ^\u001b[0m\n\u001b[0;31mIndentationError\u001b[0m\u001b[0;31m:\u001b[0m unexpected indent\n"
285
+ ]
286
+ }
287
+ ],
288
+ "source": [
289
+ "\n",
290
+ " add_jsonlines_dir(\"/home/dat/subset_c4_cleannl\") \n",
291
+ " add_jsonlines_dir(\"/data/oscar_nl_cleaned\")\n",
292
+ " add_jsonlines_dir(\"/data/nrc_cleaned_idtextfmt\")\n",
293
+ " add_jsonlines_dir(\"/data/nu_cleaned_idtextfmt\")\n",
294
+ " random.Random(SEED).shuffle(data_files)\n",
295
+ " total = len(data_files)\n",
296
+ " val_size = int(0.05 * total)\n",
297
+ " train_size = total - val_size\n",
298
+ " print(f\"95%: {train_size}\")\n",
299
+ " train = data_files\n",
300
+ " val = data_files\n",
301
+ " print(f\"Got {len(train)} training files and {len(val)} validation files\")\n",
302
+ " assert list(set(train) & set(val)) == [], \"Train overlaps with test\"\n",
303
+ " return train, val\n",
304
+ " train, val = train_val_files()\n",
305
+ " datasets = load_dataset('json', data_files={'train': train, 'validation': val})"
306
+ ]
307
+ },
308
+ {
309
+ "cell_type": "code",
310
+ "execution_count": 4,
311
+ "id": "6685589f",
312
+ "metadata": {},
313
+ "outputs": [
314
+ {
315
+ "name": "stdout",
316
+ "output_type": "stream",
317
+ "text": [
318
+ "\n",
319
+ "\n",
320
+ "\n"
321
+ ]
322
+ }
323
+ ],
324
+ "source": [
325
+ "from tokenizers import ByteLevelBPETokenizer\n",
326
+ "tokenizer = ByteLevelBPETokenizer()\n",
327
+ "\n",
328
+ "def batch_iterator(batch_size=1000):\n",
329
+ " for i in range(0, len(datasets), batch_size):\n",
330
+ " yield datasets[\"train\"][i: i + batch_size][\"text\"]\n",
331
+ "\n",
332
+ "tokenizer.train_from_iterator(batch_iterator(), vocab_size=50358, min_frequency=2, special_tokens=[\n",
333
+ " \"<s>\",\n",
334
+ " \"<pad>\",\n",
335
+ " \"</s>\",\n",
336
+ " \"<unk>\",\n",
337
+ " \"<mask>\",\n",
338
+ "])"
339
+ ]
340
+ },
341
+ {
342
+ "cell_type": "code",
343
+ "execution_count": 5,
344
+ "id": "5fed49b4",
345
+ "metadata": {},
346
+ "outputs": [
347
+ {
348
+ "data": {
349
+ "text/plain": [
350
+ "39503"
351
+ ]
352
+ },
353
+ "execution_count": 5,
354
+ "metadata": {},
355
+ "output_type": "execute_result"
356
+ }
357
+ ],
358
+ "source": [
359
+ "tokenizer.get_vocab_size()"
360
+ ]
361
+ },
362
+ {
363
+ "cell_type": "code",
364
+ "execution_count": 6,
365
+ "id": "69401680",
366
+ "metadata": {},
367
+ "outputs": [
368
+ {
369
+ "name": "stdout",
370
+ "output_type": "stream",
371
+ "text": [
372
+ "/home/dat/pino-roberta-base\n"
373
+ ]
374
+ }
375
+ ],
376
+ "source": [
377
+ "cd ~/pino-roberta-base"
378
+ ]
379
+ },
380
+ {
381
+ "cell_type": "code",
382
+ "execution_count": 7,
383
+ "id": "7a98d754",
384
+ "metadata": {},
385
+ "outputs": [],
386
+ "source": [
387
+ "tokenizer.save(\"tokenizer.json\")"
388
+ ]
389
+ },
390
+ {
391
+ "cell_type": "code",
392
+ "execution_count": null,
393
+ "id": "e686b9c8",
394
+ "metadata": {},
395
+ "outputs": [
396
+ {
397
+ "name": "stderr",
398
+ "output_type": "stream",
399
+ "text": [
400
+ "Using custom data configuration nl-lang=nl\n"
401
+ ]
402
+ },
403
+ {
404
+ "name": "stdout",
405
+ "output_type": "stream",
406
+ "text": [
407
+ "Downloading and preparing dataset cc100/nl (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /home/dat/.cache/huggingface/datasets/cc100/nl-lang=nl/0.0.0/b583dd47b0dd43a3c3773075abd993be12d0eee93dbd2cfe15a0e4e94d481e80...\n"
408
+ ]
409
+ },
410
+ {
411
+ "data": {
412
+ "application/vnd.jupyter.widget-view+json": {
413
+ "model_id": "8bb6155775084c42841d5a786a3f014c",
414
+ "version_major": 2,
415
+ "version_minor": 0
416
+ },
417
+ "text/plain": [
418
+ "Downloading: 0%| | 0.00/8.42G [00:00<?, ?B/s]"
419
+ ]
420
+ },
421
+ "metadata": {},
422
+ "output_type": "display_data"
423
+ }
424
+ ],
425
+ "source": [
426
+ "dataset1 = load_dataset(\"mc4\", \"nl\", streaming=True)\n",
427
+ "dataset2 = load_dataset(\"oscar\", \"unshuffled_deduplicated_nl\",streaming=True)\n",
428
+ "dataset3 = load_dataset(\"cc100\", lang=\"nl\")\n",
429
+ "\n"
430
+ ]
431
+ },
432
+ {
433
+ "cell_type": "code",
434
+ "execution_count": 14,
435
+ "id": "1e1498d1",
436
+ "metadata": {},
437
+ "outputs": [
438
+ {
439
+ "name": "stderr",
440
+ "output_type": "stream",
441
+ "text": [
442
+ "INFO:absl:Starting the local TPU driver.\n",
443
+ "INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://\n",
444
+ "INFO:absl:Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: \"cuda\". Available platform names are: Interpreter Host TPU\n",
445
+ "Some weights of FlaxBigBirdModel were not initialized from the model checkpoint at flax-community/pino-roberta-base and are newly initialized: {('pooler', 'kernel'), ('pooler', 'bias')}\n",
446
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
447
+ ]
448
+ }
449
+ ],
450
+ "source": [
451
+ "from transformers import AutoTokenizer, RobertaModel\n",
452
+ "from transformers import BigBirdForSequenceClassification,FlaxBigBirdModel,FlaxBigBirdForMaskedLM\n",
453
+ "\n",
454
+ "model = FlaxBigBirdModel.from_pretrained(\"flax-community/pino-roberta-base\")\n",
455
+ "model.save_pretrained('exported_pytorch_model')"
456
+ ]
457
+ },
458
+ {
459
+ "cell_type": "code",
460
+ "execution_count": null,
461
+ "id": "82f2a9b7",
462
+ "metadata": {},
463
+ "outputs": [],
464
+ "source": []
465
+ }
466
+ ],
467
+ "metadata": {
468
+ "kernelspec": {
469
+ "display_name": "Python 3 (ipykernel)",
470
+ "language": "python",
471
+ "name": "python3"
472
+ },
473
+ "language_info": {
474
+ "codemirror_mode": {
475
+ "name": "ipython",
476
+ "version": 3
477
+ },
478
+ "file_extension": ".py",
479
+ "mimetype": "text/x-python",
480
+ "name": "python",
481
+ "nbconvert_exporter": "python",
482
+ "pygments_lexer": "ipython3",
483
+ "version": "3.8.10"
484
+ }
485
+ },
486
+ "nbformat": 4,
487
+ "nbformat_minor": 5
488
+ }
config.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "BigBirdForPreTraining"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "attention_type": "block_sparse",
7
+ "block_size": 128,
8
+ "bos_token_id": 1,
9
+ "eos_token_id": 2,
10
+ "gradient_checkpointing": false,
11
+ "hidden_act": "gelu_new",
12
+ "hidden_dropout_prob": 0.1,
13
+ "hidden_size": 768,
14
+ "initializer_range": 0.02,
15
+ "intermediate_size": 3072,
16
+ "layer_norm_eps": 1e-12,
17
+ "max_position_embeddings": 4096,
18
+ "model_type": "big_bird",
19
+ "num_attention_heads": 12,
20
+ "num_hidden_layers": 12,
21
+ "num_random_blocks": 3,
22
+ "pad_token_id": 0,
23
+ "position_embedding_type": "absolute",
24
+ "rescale_embeddings": false,
25
+ "sep_token_id": 66,
26
+ "transformers_version": "4.9.0.dev0",
27
+ "type_vocab_size": 2,
28
+ "use_bias": true,
29
+ "use_cache": true,
30
+ "vocab_size": 50358
31
+ }
run.sh ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+
3
+ export TOKENIZERS_PARALLELISM=0
4
+
5
+ python ./run_mlm_flax.py \
6
+ --push_to_hub \
7
+ --output_dir="./" \
8
+ --model_type="big_bird" \
9
+ --config_name="./" \
10
+ --tokenizer_name="./" \
11
+ --max_seq_length="4096" \
12
+ --weight_decay="0.0095" \
13
+ --warmup_steps="5000" \
14
+ --overwrite_output_dir \
15
+ --adam_beta1="0.9" \
16
+ --adam_beta2="0.98" \
17
+ --logging_steps="500" \
18
+ --eval_steps="92768" \
19
+ --num_train_epochs="5" \
20
+ --preprocessing_num_workers="64" \
21
+ --save_steps="20000" \
22
+ --adafactor \
23
+ --learning_rate="5e-5" \
24
+ --per_device_train_batch_size="2" \
25
+ --per_device_eval_batch_size="2" \
26
+ --save_total_limit="5"\
27
+ --dtype="bfloat16" \
28
+ #--resume_from_checkpoint="./"\
29
+ #--gradient_accumulation_steps="4" \
30
+
run_mlm_flax.py ADDED
@@ -0,0 +1,787 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) with whole word masking on a
18
+ text file or a dataset.
19
+
20
+ Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
21
+ https://huggingface.co/models?filter=masked-lm
22
+ """
23
+ import shutil
24
+ import logging
25
+ import os
26
+ import sys
27
+ import time
28
+ from dataclasses import dataclass, field
29
+ from ast import Str
30
+
31
+ # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
32
+ from pathlib import Path
33
+ from typing import Dict, List, Optional, Tuple
34
+
35
+ import numpy as np
36
+ from datasets import load_dataset
37
+ from tqdm import tqdm
38
+
39
+ import flax
40
+ import jax
41
+ import jax.numpy as jnp
42
+ import optax
43
+ from flax import jax_utils, traverse_util
44
+ from flax.training import train_state
45
+ from flax.training.common_utils import get_metrics, onehot, shard
46
+ from transformers import (
47
+ CONFIG_MAPPING,
48
+ FLAX_MODEL_FOR_MASKED_LM_MAPPING,
49
+ AutoConfig,
50
+ AutoTokenizer,
51
+ FlaxAutoModelForMaskedLM,
52
+ HfArgumentParser,
53
+ PreTrainedTokenizerBase,
54
+ TensorType,
55
+ TrainingArguments,
56
+ is_tensorboard_available,
57
+ set_seed,
58
+ )
59
+ from transformers.testing_utils import CaptureLogger
60
+ from flax.serialization import to_bytes, from_bytes
61
+ from importlib.util import find_spec
62
+ from flax.training import checkpoints
63
+ from flax.jax_utils import unreplicate
64
+ from flax.training.checkpoints import save_checkpoint, restore_checkpoint
65
+ import json
66
+
67
+ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
68
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
69
+
70
+
71
+ @dataclass
72
+ class ModelArguments:
73
+ """
74
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
75
+ """
76
+
77
+ model_name_or_path: Optional[str] = field(
78
+ default=None,
79
+ metadata={
80
+ "help": "The model checkpoint for weights initialization."
81
+ "Don't set if you want to train a model from scratch."
82
+ },
83
+ )
84
+ model_type: Optional[str] = field(
85
+ default=None,
86
+ metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
87
+ )
88
+ config_name: Optional[str] = field(
89
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
90
+ )
91
+ tokenizer_name: Optional[str] = field(
92
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
93
+ )
94
+ cache_dir: Optional[str] = field(
95
+ default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
96
+ )
97
+ use_fast_tokenizer: bool = field(
98
+ default=True,
99
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
100
+ )
101
+ dtype: Optional[str] = field(
102
+ default="float32",
103
+ metadata={
104
+ "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
105
+ },
106
+ )
107
+
108
+
109
+
110
+
111
+ @dataclass
112
+ class DataTrainingArguments:
113
+ """
114
+ Arguments pertaining to what data we are going to input our model for training and eval.
115
+ """
116
+
117
+ dataset_name: Optional[str] = field(
118
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
119
+ )
120
+ dataset_config_name: Optional[str] = field(
121
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
122
+ )
123
+ train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
124
+ validation_file: Optional[str] = field(
125
+ default=None,
126
+ metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
127
+ )
128
+ train_ref_file: Optional[str] = field(
129
+ default=None,
130
+ metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
131
+ )
132
+ validation_ref_file: Optional[str] = field(
133
+ default=None,
134
+ metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."},
135
+ )
136
+ overwrite_cache: bool = field(
137
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
138
+ )
139
+ validation_split_percentage: Optional[int] = field(
140
+ default=5,
141
+ metadata={
142
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
143
+ },
144
+ )
145
+ max_seq_length: Optional[int] = field(
146
+ default=None,
147
+ metadata={
148
+ "help": "The maximum total input sequence length after tokenization. Sequences longer "
149
+ "than this will be truncated. Default to the max input length of the model."
150
+ },
151
+ )
152
+ preprocessing_num_workers: Optional[int] = field(
153
+ default=None,
154
+ metadata={"help": "The number of processes to use for the preprocessing."},
155
+ )
156
+ mlm_probability: float = field(
157
+ default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
158
+ )
159
+ pad_to_max_length: bool = field(
160
+ default=False,
161
+ metadata={
162
+ "help": "Whether to pad all samples to `max_seq_length`. "
163
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
164
+ },
165
+ )
166
+ line_by_line: bool = field(
167
+ default=False,
168
+ metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
169
+ )
170
+
171
+
172
+ @flax.struct.dataclass
173
+ class FlaxDataCollatorForLanguageModeling:
174
+ """
175
+ Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
176
+ are not all of the same length.
177
+
178
+ Args:
179
+ tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
180
+ The tokenizer used for encoding the data.
181
+ mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
182
+ The probability with which to (randomly) mask tokens in the input.
183
+
184
+ .. note::
185
+
186
+ For best performance, this data collator should be used with a dataset having items that are dictionaries or
187
+ BatchEncoding, with the :obj:`"special_tokens_mask"` key, as returned by a
188
+ :class:`~transformers.PreTrainedTokenizer` or a :class:`~transformers.PreTrainedTokenizerFast` with the
189
+ argument :obj:`return_special_tokens_mask=True`.
190
+ """
191
+
192
+ tokenizer: PreTrainedTokenizerBase
193
+ mlm_probability: float = 0.15
194
+
195
+ def __post_init__(self):
196
+ if self.tokenizer.mask_token is None:
197
+ raise ValueError(
198
+ "This tokenizer does not have a mask token which is necessary for masked language modeling. "
199
+ "You should pass `mlm=False` to train on causal language modeling instead."
200
+ )
201
+
202
+ def __call__(self, examples: List[Dict[str, np.ndarray]], pad_to_multiple_of: int) -> Dict[str, np.ndarray]:
203
+ # Handle dict or lists with proper padding and conversion to tensor.
204
+ batch = self.tokenizer.pad(examples, pad_to_multiple_of=pad_to_multiple_of, return_tensors=TensorType.NUMPY)
205
+
206
+ # If special token mask has been preprocessed, pop it from the dict.
207
+ special_tokens_mask = batch.pop("special_tokens_mask", None)
208
+
209
+ batch["input_ids"], batch["labels"] = self.mask_tokens(
210
+ batch["input_ids"], special_tokens_mask=special_tokens_mask
211
+ )
212
+ return batch
213
+
214
+ def mask_tokens(
215
+ self, inputs: np.ndarray, special_tokens_mask: Optional[np.ndarray]
216
+ ) -> Tuple[jnp.ndarray, jnp.ndarray]:
217
+ """
218
+ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
219
+ """
220
+ labels = inputs.copy()
221
+ # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
222
+ probability_matrix = np.full(labels.shape, self.mlm_probability)
223
+ special_tokens_mask = special_tokens_mask.astype("bool")
224
+
225
+ probability_matrix[special_tokens_mask] = 0.0
226
+ masked_indices = np.random.binomial(1, probability_matrix).astype("bool")
227
+ labels[~masked_indices] = -100 # We only compute loss on masked tokens
228
+
229
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
230
+ indices_replaced = np.random.binomial(1, np.full(labels.shape, 0.8)).astype("bool") & masked_indices
231
+ inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
232
+
233
+ # 10% of the time, we replace masked input tokens with random word
234
+ indices_random = np.random.binomial(1, np.full(labels.shape, 0.5)).astype("bool")
235
+ indices_random &= masked_indices & ~indices_replaced
236
+
237
+ random_words = np.random.randint(self.tokenizer.vocab_size, size=labels.shape, dtype="i4")
238
+ inputs[indices_random] = random_words[indices_random]
239
+
240
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
241
+ return inputs, labels
242
+
243
+
244
+ def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
245
+ num_samples = len(samples_idx)
246
+ samples_to_remove = num_samples % batch_size
247
+
248
+ if samples_to_remove != 0:
249
+ samples_idx = samples_idx[:-samples_to_remove]
250
+ sections_split = num_samples // batch_size
251
+ batch_idx = np.split(samples_idx, sections_split)
252
+ return batch_idx
253
+
254
+
255
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
256
+ summary_writer.scalar("train_time", train_time, step)
257
+
258
+ train_metrics = get_metrics(train_metrics)
259
+ for key, vals in train_metrics.items():
260
+ tag = f"train_{key}"
261
+ for i, val in enumerate(vals):
262
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
263
+
264
+
265
+ def write_eval_metric(summary_writer, eval_metrics, step):
266
+ for metric_name, value in eval_metrics.items():
267
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
268
+
269
+ def mb_item(x):
270
+ return x.item() if hasattr(x, "item") else x
271
+
272
+ #checkpoint functions
273
+
274
+
275
+
276
+
277
+
278
+ def rotate_checkpoints(ckpt_dir: str, save_total_limit: int):
279
+ "Removes older checkpoints so that `save_total_limit` checkpoints are kept"
280
+ # TODO: what to remove is decided using step number only, we might want to improve that
281
+ ckpts = [str(x) for x in Path(ckpt_dir).glob("ckpt-*")]
282
+ # sort checkpoints by step
283
+ ckpts_sorted = sorted(ckpts, key=lambda x: int(x.split('-')[-1]))
284
+ ckpts_to_delete = ckpts_sorted[:-save_total_limit]
285
+ for ckpt in ckpts_to_delete:
286
+ logger.info(f"Deleting older checkpoint [{ckpt}] due to save_total_limit ({save_total_limit})")
287
+ shutil.rmtree(ckpt)
288
+
289
+
290
+
291
+
292
+
293
+
294
+
295
+ if __name__ == "__main__":
296
+ # See all possible arguments in src/transformers/training_args.py
297
+ # or by passing the --help flag to this script.
298
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
299
+
300
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
301
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
302
+ # If we pass only one argument to the script and it's the path to a json file,
303
+ # let's parse it to get our arguments.
304
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
305
+ else:
306
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
307
+
308
+ if (
309
+ os.path.exists(training_args.output_dir)
310
+ and os.listdir(training_args.output_dir)
311
+ and training_args.do_train
312
+ and not training_args.overwrite_output_dir
313
+ ):
314
+ raise ValueError(
315
+ f"Output directory ({training_args.output_dir}) already exists and is not empty."
316
+ "Use --overwrite_output_dir to overcome."
317
+ )
318
+
319
+ # Setup logging
320
+ logging.basicConfig(
321
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
322
+ level="NOTSET",
323
+ datefmt="[%X]",
324
+ )
325
+
326
+ # Log on each process the small summary:
327
+ logger = logging.getLogger(__name__)
328
+
329
+ # Set the verbosity to info of the Transformers logger (on main process only):
330
+ logger.info(f"Training/evaluation parameters {training_args}")
331
+
332
+ # Set seed before initializing model.
333
+ set_seed(training_args.seed)
334
+
335
+ # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
336
+ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
337
+ # (the dataset will be downloaded automatically from the datasets Hub).
338
+ #
339
+ # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
340
+ # 'text' is found. You can easily tweak this behavior (see below).
341
+ #
342
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
343
+ # download the dataset.
344
+ if data_args.dataset_name is not None:
345
+ # Downloading and loading a dataset from the hub.
346
+ datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
347
+
348
+ if "validation" not in datasets.keys():
349
+ datasets["validation"] = load_dataset(
350
+ data_args.dataset_name,
351
+ data_args.dataset_config_name,
352
+ split=f"train[:{data_args.validation_split_percentage}%]",
353
+ cache_dir=model_args.cache_dir,
354
+ )
355
+ datasets["train"] = load_dataset(
356
+ data_args.dataset_name,
357
+ data_args.dataset_config_name,
358
+ split=f"train[{data_args.validation_split_percentage}%:]",
359
+ cache_dir=model_args.cache_dir,
360
+ )
361
+ else:
362
+ #data_files = {}
363
+ #if data_args.train_file is not None:
364
+ # data_files["train"] = data_args.train_file
365
+ #if data_args.validation_file is not None:
366
+ # data_files["validation"] = data_args.validation_file
367
+ #extension = data_args.train_file.split(".")[-1]
368
+ #if extension == "txt":
369
+ # extension = "text"
370
+ #datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
371
+
372
+ #data_dir = "/home/yeb"
373
+ # data_dir = "/home/yeb/Developer/data"
374
+ data_files = []
375
+ def train_val_files():
376
+ import glob
377
+ import random
378
+ SEED = 42
379
+ def add_jsonlines_dir(path):
380
+ global data_files
381
+ data_files += glob.glob(f"{path}/*.gz")
382
+
383
+ add_jsonlines_dir("/home/dat/subset_c4_cleannl")
384
+ add_jsonlines_dir("/data/oscar_nl_cleaned")
385
+ add_jsonlines_dir("/data/nrc_cleaned_idtextfmt")
386
+ add_jsonlines_dir("/data/nu_cleaned_idtextfmt")
387
+ random.Random(SEED).shuffle(data_files)
388
+ total = len(data_files)
389
+ val_size = int(0.05 * total)
390
+ train_size = total - val_size
391
+ print(f"95%: {train_size}")
392
+ train = data_files[:train_size]
393
+ val = data_files[train_size:]
394
+ print(f"Got {len(train)} training files and {len(val)} validation files")
395
+ assert list(set(train) & set(val)) == [], "Train overlaps with test"
396
+ return train, val
397
+ train, val = train_val_files()
398
+ datasets = load_dataset('json', data_files={'train': train, 'validation': val})
399
+ datasets["train"] = datasets["train"].select(range(int(0.8*len(datasets["train"]))))
400
+ datasets["validation"] = datasets["validation"].select(range(int(0.8*len(datasets["validation"]))))
401
+
402
+
403
+
404
+
405
+ if model_args.config_name:
406
+ config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
407
+ elif model_args.model_name_or_path:
408
+ config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
409
+ else:
410
+ config = CONFIG_MAPPING[model_args.model_type]()
411
+ logger.warning("You are instantiating a new config instance from scratch.")
412
+
413
+ if model_args.tokenizer_name:
414
+ tokenizer = AutoTokenizer.from_pretrained(
415
+ model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
416
+ )
417
+ elif model_args.model_name_or_path:
418
+ tokenizer = AutoTokenizer.from_pretrained(
419
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
420
+ )
421
+ else:
422
+ raise ValueError(
423
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script."
424
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
425
+ )
426
+
427
+ # Preprocessing the datasets.
428
+ # First we tokenize all the texts.
429
+ if training_args.do_train:
430
+ column_names = datasets["train"].column_names
431
+ else:
432
+ column_names = datasets["validation"].column_names
433
+ text_column_name = "text" if "text" in column_names else column_names[0]
434
+
435
+ max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
436
+
437
+
438
+ if data_args.line_by_line:
439
+ # When using line_by_line, we just tokenize each nonempty line.
440
+ padding = "max_length" if data_args.pad_to_max_length else False
441
+
442
+ def tokenize_function(examples):
443
+ # Remove empty lines
444
+ examples = [line for line in examples if len(line) > 0 and not line.isspace()]
445
+ return tokenizer(
446
+ examples,
447
+ return_special_tokens_mask=True,
448
+ padding=padding,
449
+ truncation=True,
450
+ max_length=max_seq_length,
451
+ )
452
+
453
+ tokenized_datasets = datasets.map(
454
+ tokenize_function,
455
+ input_columns=[text_column_name],
456
+ batched=True,
457
+ num_proc=data_args.preprocessing_num_workers,
458
+ remove_columns=column_names,
459
+ load_from_cache_file=not data_args.overwrite_cache,
460
+ )
461
+
462
+ else:
463
+ # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
464
+ # We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
465
+ # efficient when it receives the `special_tokens_mask`.
466
+ def tokenize_function(examples):
467
+ return tokenizer(examples[text_column_name], return_special_tokens_mask=True)
468
+
469
+ tokenized_datasets = datasets.map(
470
+ tokenize_function,
471
+ batched=True,
472
+ num_proc=data_args.preprocessing_num_workers,
473
+ remove_columns=column_names,
474
+ load_from_cache_file=not data_args.overwrite_cache,
475
+ )
476
+
477
+ # Main data processing function that will concatenate all texts from our dataset and generate chunks of
478
+ # max_seq_length.
479
+ def group_texts(examples):
480
+ # Concatenate all texts.
481
+ concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
482
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
483
+ # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
484
+ # customize this part to your needs.
485
+ if total_length >= max_seq_length:
486
+ total_length = (total_length // max_seq_length) * max_seq_length
487
+ # Split by chunks of max_len.
488
+ result = {
489
+ k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
490
+ for k, t in concatenated_examples.items()
491
+ }
492
+ return result
493
+
494
+ # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
495
+ # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
496
+ # might be slower to preprocess.
497
+ #
498
+ # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
499
+ # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
500
+ lm_datasets = tokenized_datasets.map(
501
+ group_texts,
502
+ batched=True,
503
+ batch_size=100,
504
+ num_proc=data_args.preprocessing_num_workers,
505
+ load_from_cache_file=not data_args.overwrite_cache,
506
+ )
507
+ train_dataset = lm_datasets["train"]
508
+ eval_dataset = lm_datasets["validation"]
509
+
510
+
511
+
512
+
513
+ # Enable tensorboard only on the master node
514
+ has_tensorboard = is_tensorboard_available()
515
+ if has_tensorboard and jax.process_index() == 0:
516
+ try:
517
+ from flax.metrics.tensorboard import SummaryWriter
518
+
519
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
520
+ except ImportError as ie:
521
+ has_tensorboard = False
522
+ logger.warning(
523
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
524
+ )
525
+ else:
526
+ logger.warning(
527
+ "Unable to display metrics through TensorBoard because the package is not installed: "
528
+ "Please run pip install tensorboard to enable."
529
+ )
530
+ # enable wandb tracking
531
+ has_wandb = find_spec("wandb") is not None
532
+ if jax.process_index() == 0 and has_wandb and ("wandb" in training_args.report_to):
533
+ try:
534
+ import wandb
535
+ wandb.init(
536
+ entity="wandb",
537
+ project="hf-flax-pino-roberta",
538
+ sync_tensorboard=True
539
+ )
540
+ wandb.config.update(training_args)
541
+ wandb.config.update(model_args)
542
+ wandb.config.update(data_args)
543
+ except ImportError as e:
544
+ print(e)
545
+ has_wandb = False
546
+
547
+ # Data collator
548
+ # This one will take care of randomly masking the tokens.
549
+ data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
550
+
551
+ # Initialize our training
552
+ rng = jax.random.PRNGKey(training_args.seed)
553
+ dropout_rngs = jax.random.split(rng, jax.local_device_count())
554
+
555
+ if model_args.model_name_or_path:
556
+ model = FlaxAutoModelForMaskedLM.from_pretrained(
557
+ model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
558
+ )
559
+ else:
560
+ model = FlaxAutoModelForMaskedLM.from_config(
561
+ config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
562
+ )
563
+
564
+ # Store some constant
565
+ num_epochs = int(training_args.num_train_epochs)
566
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() * training_args.gradient_accumulation_steps
567
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
568
+
569
+ num_train_steps = len(train_dataset) // train_batch_size * num_epochs
570
+
571
+ # Create learning rate schedule
572
+ warmup_fn = optax.linear_schedule(
573
+ init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
574
+ )
575
+ decay_fn = optax.linear_schedule(
576
+ init_value=training_args.learning_rate,
577
+ end_value=0,
578
+ transition_steps=num_train_steps - training_args.warmup_steps,
579
+ )
580
+ linear_decay_lr_schedule_fn = optax.join_schedules(
581
+ schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
582
+ )
583
+
584
+ # We use Optax's "masking" functionality to not apply weight decay
585
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
586
+ # mask boolean with the same structure as the parameters.
587
+ # The mask is True for parameters that should be decayed.
588
+ # Note that this mask is specifically adapted for FlaxBERT-like models.
589
+ # For other models, one should correct the layer norm parameter naming
590
+ # accordingly.
591
+ def decay_mask_fn(params):
592
+ flat_params = traverse_util.flatten_dict(params)
593
+ flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
594
+ return traverse_util.unflatten_dict(flat_mask)
595
+
596
+ # create adam optimizer
597
+ if training_args.adafactor:
598
+ # We use the default parameters here to initialize adafactor,
599
+ # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
600
+ optimizer = optax.adafactor(
601
+ learning_rate=linear_decay_lr_schedule_fn,
602
+ )
603
+ else:
604
+ optimizer = optax.adamw(
605
+ learning_rate=linear_decay_lr_schedule_fn,
606
+ b1=training_args.adam_beta1,
607
+ b2=training_args.adam_beta2,
608
+ eps=training_args.adam_epsilon,
609
+ weight_decay=training_args.weight_decay,
610
+ mask=decay_mask_fn,
611
+ )
612
+
613
+ if training_args.gradient_accumulation_steps > 1:
614
+ optimizer = optax.MultiSteps(optimizer, training_args.gradient_accumulation_steps)
615
+ grad_accum_steps = training_args.gradient_accumulation_steps
616
+
617
+ # Setup train state
618
+
619
+
620
+ state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)
621
+
622
+ if training_args.resume_from_checkpoint:
623
+ state = restore_checkpoint(training_args.resume_from_checkpoint, state)
624
+ resume_step = mb_item(state.step.item())
625
+ else:
626
+ resume_step = 0
627
+
628
+
629
+ # Define gradient update step fn
630
+ def train_step(state, batch, dropout_rng):
631
+ dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
632
+
633
+ def loss_fn(params):
634
+ labels = batch.pop("labels")
635
+
636
+ logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
637
+
638
+ # compute loss, ignore padded input tokens
639
+ label_mask = jnp.where(labels > 0, 1.0, 0.0)
640
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
641
+
642
+ # take average
643
+ loss = loss.sum() / label_mask.sum()
644
+
645
+ return loss
646
+
647
+ grad_fn = jax.value_and_grad(loss_fn)
648
+ loss, grad = grad_fn(state.params)
649
+ grad = jax.lax.pmean(grad, "batch")
650
+ new_state = state.apply_gradients(grads=grad)
651
+
652
+ metrics = jax.lax.pmean(
653
+ {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step // grad_accum_steps)}, axis_name="batch"
654
+ )
655
+
656
+ return new_state, metrics, new_dropout_rng
657
+
658
+ # Create parallel version of the train step
659
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
660
+
661
+ # Define eval fn
662
+ def eval_step(params, batch):
663
+ labels = batch.pop("labels")
664
+
665
+ logits = model(**batch, params=params, train=False)[0]
666
+
667
+ # compute loss, ignore padded input tokens
668
+ label_mask = jnp.where(labels > 0, 1.0, 0.0)
669
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
670
+
671
+ # compute accuracy
672
+ accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) * label_mask
673
+
674
+ # summarize metrics
675
+ metrics = {"loss": loss.sum(), "accuracy": accuracy.sum(), "normalizer": label_mask.sum()}
676
+ metrics = jax.lax.psum(metrics, axis_name="batch")
677
+
678
+ return metrics
679
+
680
+ p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))
681
+
682
+ # Replicate the train state on each device
683
+ state = jax_utils.replicate(state)
684
+
685
+ train_time = 0
686
+ epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
687
+ for epoch in epochs:
688
+ # ======================== Training ================================
689
+ train_start = time.time()
690
+ train_metrics = []
691
+
692
+ # Create sampling rng
693
+ rng, input_rng = jax.random.split(rng)
694
+ steps_per_epoch = len(train_dataset) // train_batch_size
695
+
696
+ # Generate an epoch by shuffling sampling indices from the train dataset
697
+ num_train_samples = len(train_dataset)
698
+ train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
699
+ train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size // grad_accum_steps)
700
+
701
+ # Gather the indexes for creating the batch and do a training step
702
+ for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1,initial=resume_step)):
703
+ samples = [train_dataset[int(idx)] for idx in batch_idx]
704
+ model_inputs = data_collator(samples, pad_to_multiple_of=16)
705
+
706
+
707
+ # Model forward
708
+ model_inputs = shard(model_inputs.data)
709
+ state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
710
+ train_metrics.append(train_metric)
711
+
712
+ cur_step = epoch * (num_train_samples // train_batch_size) + step
713
+ if cur_step < resume_step:
714
+ continue
715
+
716
+ if (cur_step % training_args.logging_steps * grad_accum_steps) == 0 and cur_step > 0:
717
+ # Save metrics
718
+ train_metric = jax_utils.unreplicate(train_metric)
719
+ train_time += time.time() - train_start
720
+ if has_tensorboard and jax.process_index() == 0:
721
+ write_train_metric(summary_writer, train_metrics, train_time, cur_step)
722
+ if has_wandb and jax.process_index() == 0 and ("wandb" in training_args.report_to):
723
+ # TODO: add accumulation of metrics
724
+ _metrics = {k if k=="learning_rate" else f"train_{k}":mb_item(v.mean()) for k, v in train_metric.items()}
725
+ wandb.log({"training_step":cur_step, **_metrics}, commit=True)
726
+
727
+ epochs.write(
728
+ f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
729
+ )
730
+
731
+ train_metrics = []
732
+
733
+ if cur_step % (training_args.eval_steps * grad_accum_steps) == 0 and cur_step > 0:
734
+ # ======================== Evaluating ==============================
735
+ num_eval_samples = len(eval_dataset)
736
+ eval_samples_idx = jnp.arange(num_eval_samples)
737
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
738
+
739
+ eval_metrics = []
740
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
741
+ samples = [eval_dataset[int(idx)] for idx in batch_idx]
742
+ model_inputs = data_collator(samples, pad_to_multiple_of=16)
743
+
744
+ # Model forward
745
+ model_inputs = shard(model_inputs.data)
746
+ metrics = p_eval_step(state.params, model_inputs)
747
+ eval_metrics.append(metrics)
748
+
749
+ # normalize eval metrics
750
+ eval_metrics = get_metrics(eval_metrics)
751
+ eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
752
+ eval_normalizer = eval_metrics.pop("normalizer")
753
+ eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
754
+
755
+ # Update progress bar
756
+ epochs.desc = f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
757
+
758
+ # Save metrics
759
+ if has_tensorboard and jax.process_index() == 0:
760
+ write_eval_metric(summary_writer, eval_metrics, cur_step)
761
+
762
+ if has_wandb and jax.process_index() == 0 and ("wandb" in training_args.report_to):
763
+ _metrics = {f"eval_{k}":mb_item(v) for k, v in eval_metrics.items()}
764
+ wandb.log({"eval_step":cur_step, **_metrics})
765
+
766
+ if (cur_step % training_args.save_steps == 0 * grad_accum_steps) and cur_step > 0:
767
+ # save checkpoint after each epoch and push checkpoint to the hub
768
+ if jax.process_index() == 0:
769
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
770
+ model.save_pretrained(
771
+ training_args.output_dir,
772
+ params=params,
773
+ push_to_hub=training_args.push_to_hub,
774
+ commit_message=f"Saving weights and logs of step {cur_step}",
775
+ )
776
+ save_checkpoint(training_args.output_dir, jax_utils.unreplicate(state), cur_step, keep=training_args.save_total_limit, overwrite=True)
777
+ if training_args.save_total_limit is not None:
778
+ rotate_checkpoints(training_args.output_dir, training_args.save_total_limit)
779
+
780
+ if jax.process_index() == 0:
781
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
782
+ model.save_pretrained(
783
+ training_args.output_dir,
784
+ params=params,
785
+ push_to_hub=training_args.push_to_hub,
786
+ commit_message=f"Saving weights and logs of step {cur_step}",
787
+ )
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff