dat commited on
Commit
f291f93
1 Parent(s): f6e0bf7

Saving weights and logs at step 1252

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. Load data & train tokenizer.ipynb +0 -0
  2. checkpoint_60000 +3 -0
  3. events.out.tfevents.1626173264.t1v-n-f5c06ea1-w-0.340852.3.v2 +3 -0
  4. events.out.tfevents.1626174131.t1v-n-f5c06ea1-w-0.343920.3.v2 +3 -0
  5. events.out.tfevents.1626174670.t1v-n-f5c06ea1-w-0.346512.3.v2 +3 -0
  6. events.out.tfevents.1626175237.t1v-n-f5c06ea1-w-0.349243.3.v2 +3 -0
  7. events.out.tfevents.1626176074.t1v-n-f5c06ea1-w-0.351681.3.v2 +3 -0
  8. events.out.tfevents.1626180467.t1v-n-f5c06ea1-w-0.354027.3.v2 +3 -0
  9. events.out.tfevents.1626180750.t1v-n-f5c06ea1-w-0.355855.3.v2 +3 -0
  10. events.out.tfevents.1626181600.t1v-n-f5c06ea1-w-0.357816.3.v2 +3 -0
  11. events.out.tfevents.1626181889.t1v-n-f5c06ea1-w-0.360037.3.v2 +3 -0
  12. events.out.tfevents.1626182175.t1v-n-f5c06ea1-w-0.362298.3.v2 +3 -0
  13. events.out.tfevents.1626182874.t1v-n-f5c06ea1-w-0.365284.3.v2 +3 -0
  14. events.out.tfevents.1626184460.t1v-n-f5c06ea1-w-0.369028.3.v2 +3 -0
  15. events.out.tfevents.1626242600.t1v-n-f5c06ea1-w-0.491835.3.v2 +3 -0
  16. events.out.tfevents.1626285315.t1v-n-f5c06ea1-w-0.533662.3.v2 +3 -0
  17. events.out.tfevents.1626286793.t1v-n-f5c06ea1-w-0.547087.3.v2 +3 -0
  18. events.out.tfevents.1626287584.t1v-n-f5c06ea1-w-0.550207.3.v2 +3 -0
  19. events.out.tfevents.1626288936.t1v-n-f5c06ea1-w-0.553832.3.v2 +3 -0
  20. events.out.tfevents.1626290714.t1v-n-f5c06ea1-w-0.557554.3.v2 +3 -0
  21. events.out.tfevents.1626292080.t1v-n-f5c06ea1-w-0.560928.3.v2 +3 -0
  22. events.out.tfevents.1626292866.t1v-n-f5c06ea1-w-0.563390.3.v2 +3 -0
  23. events.out.tfevents.1626293250.t1v-n-f5c06ea1-w-0.565261.3.v2 +3 -0
  24. events.out.tfevents.1626294676.t1v-n-f5c06ea1-w-0.568447.3.v2 +3 -0
  25. events.out.tfevents.1626295212.t1v-n-f5c06ea1-w-0.570637.3.v2 +3 -0
  26. events.out.tfevents.1626296457.t1v-n-f5c06ea1-w-0.573688.3.v2 +3 -0
  27. events.out.tfevents.1626296630.t1v-n-f5c06ea1-w-0.575437.3.v2 +3 -0
  28. flax_model.msgpack +2 -2
  29. run.sh +11 -9
  30. run_mlm_flax.py +270 -218
  31. run_mlm_flax_no_accum.py +776 -0
  32. save_tokenized_data.py +484 -0
  33. train_tokenizer.py +43 -0
  34. wandb/debug-internal.log +1 -1
  35. wandb/debug.log +1 -1
  36. wandb/latest-run +1 -1
  37. wandb/run-20210713_010630-14xhiyhf/files/output.log +9 -0
  38. wandb/run-20210713_010630-14xhiyhf/logs/debug-internal.log +24 -0
  39. wandb/run-20210713_010630-14xhiyhf/logs/debug.log +2 -0
  40. wandb/run-20210713_010630-14xhiyhf/run-14xhiyhf.wandb +0 -0
  41. wandb/run-20210713_104745-1rl2j7or/files/config.yaml +304 -0
  42. wandb/run-20210713_104745-1rl2j7or/files/output.log +57 -0
  43. wandb/run-20210713_104745-1rl2j7or/files/requirements.txt +92 -0
  44. wandb/run-20210713_104745-1rl2j7or/files/wandb-metadata.json +44 -0
  45. wandb/run-20210713_104745-1rl2j7or/files/wandb-summary.json +1 -0
  46. wandb/run-20210713_104745-1rl2j7or/logs/debug-internal.log +181 -0
  47. wandb/run-20210713_104745-1rl2j7or/logs/debug.log +27 -0
  48. wandb/run-20210713_104745-1rl2j7or/run-1rl2j7or.wandb +0 -0
  49. wandb/run-20210713_110212-594z6oo0/files/config.yaml +307 -0
  50. wandb/run-20210713_110212-594z6oo0/files/output.log +39 -0
Load data & train tokenizer.ipynb CHANGED
The diff for this file is too large to render. See raw diff
checkpoint_60000 ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:73e6d7222b2cee297be0891db385dcce6e0cbff6ec3697c08118513955f8aaf7
3
+ size 769729450
events.out.tfevents.1626173264.t1v-n-f5c06ea1-w-0.340852.3.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:73fdfc3eb9d8111b1e3460227717a3942adfe9263bca08b7fd2bfab9af98d9a1
3
+ size 38186
events.out.tfevents.1626174131.t1v-n-f5c06ea1-w-0.343920.3.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dfc6f0b5b354bd4d8d13834613ece71ac9d948186313bc3fde5e2e132a1c9cab
3
+ size 40
events.out.tfevents.1626174670.t1v-n-f5c06ea1-w-0.346512.3.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f74cf77c0a672ad1201614ba6642a4f3a27b9cf021d0e88eb362c7f38ee86304
3
+ size 40
events.out.tfevents.1626175237.t1v-n-f5c06ea1-w-0.349243.3.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:be5c2acf821fd2ce776ff5e434706cb933a0fa323f0bb1a82dadd832f1f589d4
3
+ size 40
events.out.tfevents.1626176074.t1v-n-f5c06ea1-w-0.351681.3.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b085d5029d052defe00b26c54b6357e9d05cbc5ad38cdd2f12537ed0b90008d2
3
+ size 441341
events.out.tfevents.1626180467.t1v-n-f5c06ea1-w-0.354027.3.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:973eec9b2b17e54f3ee35dc0c4b85a4a3ecf5488cb59f5619d7c635641bfe7b6
3
+ size 40
events.out.tfevents.1626180750.t1v-n-f5c06ea1-w-0.355855.3.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:013fc500b7fdd46262ee2b2ed5a3624249adef426d0b134944080ccf90d363ed
3
+ size 40
events.out.tfevents.1626181600.t1v-n-f5c06ea1-w-0.357816.3.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a3d4a519b8f1c293258e292768822980b487ef0e02bbfe9d6a3132b8c2fdd791
3
+ size 40
events.out.tfevents.1626181889.t1v-n-f5c06ea1-w-0.360037.3.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7c1ed9142ba98f2f7197e2a44361331a8c112af5dba98d7fc9f0bcab6228ae8c
3
+ size 40
events.out.tfevents.1626182175.t1v-n-f5c06ea1-w-0.362298.3.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:29cc2c143c306c4619802094513459dbb71c4730d3cdfb879e7224923ddfe7ea
3
+ size 40
events.out.tfevents.1626182874.t1v-n-f5c06ea1-w-0.365284.3.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:24aa4302db5d02121389fc7f8944025588034aedd21f772c2b71224e3a0b0d13
3
+ size 220634
events.out.tfevents.1626184460.t1v-n-f5c06ea1-w-0.369028.3.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6e5631bf443386a4e37d77053e55ba4517153d5f6d7f77b616258d9c78e6901f
3
+ size 367772
events.out.tfevents.1626242600.t1v-n-f5c06ea1-w-0.491835.3.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1f94f6c2d80b0e0d6247997634649101caefa3ad8ab4f408b529ad38f86c8770
3
+ size 40
events.out.tfevents.1626285315.t1v-n-f5c06ea1-w-0.533662.3.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:29b681f16c441caf85381c9def58d19f4479a2460146d2cfb68991f8327f01fe
3
+ size 40
events.out.tfevents.1626286793.t1v-n-f5c06ea1-w-0.547087.3.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:53d63b11450875138751afac48c611f4da76fadc0affb0ec98896b35dbad9728
3
+ size 40
events.out.tfevents.1626287584.t1v-n-f5c06ea1-w-0.550207.3.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:62cc6dc4bf215d99f8685629bf632f82d65fc7f1127d876ded332b31b5432064
3
+ size 40
events.out.tfevents.1626288936.t1v-n-f5c06ea1-w-0.553832.3.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1fccf6070edac76c190b8bb8de4e37b889dd1b18835777203f9d16ac658aaf71
3
+ size 40
events.out.tfevents.1626290714.t1v-n-f5c06ea1-w-0.557554.3.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d46028802a38f383ce27081e90ff848e3da863ac08c341f101eed1b20a39556c
3
+ size 40
events.out.tfevents.1626292080.t1v-n-f5c06ea1-w-0.560928.3.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b2e89d0090ae1228c609a140c2a20fbdfb208480a0dd16aced968756947a93f0
3
+ size 147065
events.out.tfevents.1626292866.t1v-n-f5c06ea1-w-0.563390.3.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2b5607707732c41fb3bac9b56702cf2a006ba526d98638e0352ba54e809c6eff
3
+ size 40
events.out.tfevents.1626293250.t1v-n-f5c06ea1-w-0.565261.3.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:83bed69057844c7af14e165d87c9678d28135297ab5bd374d1e0d80ebd31966f
3
+ size 221057
events.out.tfevents.1626294676.t1v-n-f5c06ea1-w-0.568447.3.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:050b6dc69ea5a9946fc01c76d67ea00913117399f1a37e0f24db39f39c52e76f
3
+ size 73565
events.out.tfevents.1626295212.t1v-n-f5c06ea1-w-0.570637.3.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2818b40b384ff7f5a57fe1c4994ebbd02140f7221904f527cfc0a9a115334a79
3
+ size 184532
events.out.tfevents.1626296457.t1v-n-f5c06ea1-w-0.573688.3.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:df3d8a6aa5b0177a3c337963bad77cc5cea9ed722032941dbac474d03b5a3261
3
+ size 40
events.out.tfevents.1626296630.t1v-n-f5c06ea1-w-0.575437.3.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:932b70a150d991f6939f853c7b54516d5309f2d6c19761fa96a50999bf2199e7
3
+ size 147993
flax_model.msgpack CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:19dddbba6ad2a0aa9c5c22f1b9750b90fcd0b7c8f3007cbd6af9a17d447fa417
3
- size 256576390
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:422812fccdda54c02543ac5e994b33b54e510e0474439fbe9360d5190787d38e
3
+ size 510090043
run.sh CHANGED
@@ -1,6 +1,6 @@
1
  #!/usr/bin/env bash
2
 
3
- export TOKENIZERS_PARALLELISM=0
4
 
5
  python ./run_mlm_flax.py \
6
  --push_to_hub \
@@ -14,18 +14,20 @@ python ./run_mlm_flax.py \
14
  --overwrite_output_dir \
15
  --adam_beta1="0.9" \
16
  --adam_beta2="0.98" \
17
- --logging_steps="500" \
18
- --eval_steps="92768" \
19
- --num_train_epochs="5" \
20
- --preprocessing_num_workers="64" \
21
- --save_steps="20000" \
22
- --learning_rate="5e-5" \
23
  --per_device_train_batch_size="2" \
24
  --per_device_eval_batch_size="2" \
25
  --save_total_limit="5"\
26
- --gradient_accumulation_steps="2" \
 
 
 
27
  #--adafactor \
28
  #--dtype="bfloat16" \
29
- #--resume_from_checkpoint="./"\
30
 
31
 
1
  #!/usr/bin/env bash
2
 
3
+ #export TOKENIZERS_PARALLELISM=0
4
 
5
  python ./run_mlm_flax.py \
6
  --push_to_hub \
14
  --overwrite_output_dir \
15
  --adam_beta1="0.9" \
16
  --adam_beta2="0.98" \
17
+ --logging_steps="250" \
18
+ --eval_steps="500" \
19
+ --num_train_epochs="3" \
20
+ --preprocessing_num_workers="96" \
21
+ --save_steps="1250" \
22
+ --learning_rate="1e-4" \
23
  --per_device_train_batch_size="2" \
24
  --per_device_eval_batch_size="2" \
25
  --save_total_limit="5"\
26
+ --max_eval_samples="500"\
27
+ --overwrite_cache False \
28
+ --gradient_accumulation_steps="4" \
29
+ #--resume_from_checkpoint="./"\
30
  #--adafactor \
31
  #--dtype="bfloat16" \
 
32
 
33
 
run_mlm_flax.py CHANGED
@@ -20,20 +20,18 @@ text file or a dataset.
20
  Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
21
  https://huggingface.co/models?filter=masked-lm
22
  """
23
- import shutil
24
  import logging
25
  import os
26
  import sys
27
  import time
28
  from dataclasses import dataclass, field
29
- from ast import Str
30
 
31
  # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
32
  from pathlib import Path
33
  from typing import Dict, List, Optional, Tuple
34
 
35
  import numpy as np
36
- from datasets import load_dataset
37
  from tqdm import tqdm
38
 
39
  import flax
@@ -56,13 +54,12 @@ from transformers import (
56
  is_tensorboard_available,
57
  set_seed,
58
  )
59
- from transformers.testing_utils import CaptureLogger
60
- from flax.serialization import to_bytes, from_bytes
61
- from importlib.util import find_spec
62
  from flax.training import checkpoints
63
  from flax.jax_utils import unreplicate
64
  from flax.training.checkpoints import save_checkpoint, restore_checkpoint
65
- import json
 
66
 
67
  MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
68
  MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
@@ -104,8 +101,10 @@ class ModelArguments:
104
  "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
105
  },
106
  )
107
-
108
-
 
 
109
 
110
 
111
  @dataclass
@@ -120,11 +119,6 @@ class DataTrainingArguments:
120
  dataset_config_name: Optional[str] = field(
121
  default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
122
  )
123
- train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
124
- validation_file: Optional[str] = field(
125
- default=None,
126
- metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
127
- )
128
  train_ref_file: Optional[str] = field(
129
  default=None,
130
  metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
@@ -136,6 +130,9 @@ class DataTrainingArguments:
136
  overwrite_cache: bool = field(
137
  default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
138
  )
 
 
 
139
  validation_split_percentage: Optional[int] = field(
140
  default=5,
141
  metadata={
@@ -167,6 +164,17 @@ class DataTrainingArguments:
167
  default=False,
168
  metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
169
  )
 
 
 
 
 
 
 
 
 
 
 
170
 
171
 
172
  @flax.struct.dataclass
@@ -266,33 +274,73 @@ def write_eval_metric(summary_writer, eval_metrics, step):
266
  for metric_name, value in eval_metrics.items():
267
  summary_writer.scalar(f"eval_{metric_name}", value, step)
268
 
269
- def mb_item(x):
270
- return x.item() if hasattr(x, "item") else x
271
-
272
- #checkpoint functions
273
-
274
-
275
-
276
-
277
-
278
- def rotate_checkpoints(ckpt_dir: str, save_total_limit: int):
279
- "Removes older checkpoints so that `save_total_limit` checkpoints are kept"
280
- # TODO: what to remove is decided using step number only, we might want to improve that
281
- ckpts = [str(x) for x in Path(ckpt_dir).glob("ckpt-*")]
282
- # sort checkpoints by step
283
- ckpts_sorted = sorted(ckpts, key=lambda x: int(x.split('-')[-1]))
284
- ckpts_to_delete = ckpts_sorted[:-save_total_limit]
285
- for ckpt in ckpts_to_delete:
286
- logger.info(f"Deleting older checkpoint [{ckpt}] due to save_total_limit ({save_total_limit})")
287
- shutil.rmtree(ckpt)
288
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289
 
290
-
291
- class TrainState(train_state.TrainState):
292
- grad_accum: jnp.ndarray
293
 
294
 
295
-
296
  if __name__ == "__main__":
297
  # See all possible arguments in src/transformers/training_args.py
298
  # or by passing the --help flag to this script.
@@ -360,52 +408,70 @@ if __name__ == "__main__":
360
  cache_dir=model_args.cache_dir,
361
  )
362
  else:
363
- #data_files = {}
364
- #if data_args.train_file is not None:
365
- # data_files["train"] = data_args.train_file
366
- #if data_args.validation_file is not None:
367
- # data_files["validation"] = data_args.validation_file
368
- #extension = data_args.train_file.split(".")[-1]
369
- #if extension == "txt":
370
- # extension = "text"
371
- #datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
372
-
373
- #data_dir = "/home/yeb"
374
- # data_dir = "/home/yeb/Developer/data"
375
  data_files = []
376
- def train_val_files():
377
- import glob
378
- import random
379
- SEED = 42
380
- def add_jsonlines_dir(path):
381
- global data_files
382
- data_files += glob.glob(f"{path}/*.gz")
383
-
384
- add_jsonlines_dir("/home/dat/subset_c4_cleannl")
385
- add_jsonlines_dir("/data/oscar_nl_cleaned")
386
- add_jsonlines_dir("/data/nrc_cleaned_idtextfmt")
387
- add_jsonlines_dir("/data/nu_cleaned_idtextfmt")
388
- random.Random(SEED).shuffle(data_files)
389
- total = len(data_files)
390
- val_size = int(0.05 * total)
391
- train_size = total - val_size
392
- print(f"95%: {train_size}")
393
- train = data_files[:train_size]
394
- val = data_files[train_size:]
395
- print(f"Got {len(train)} training files and {len(val)} validation files")
396
- assert list(set(train) & set(val)) == [], "Train overlaps with test"
397
- return train, val
398
- train, val = train_val_files()
399
- datasets = load_dataset('json', data_files={'train': train, 'validation': val})
400
- datasets["train"] = datasets["train"].select(range(int(0.8*len(datasets["train"]))))
401
- datasets["validation"] = datasets["validation"].select(range(int(0.8*len(datasets["validation"]))))
402
- #datasets["train"] = datasets["train"].select(range(10000))
403
- #datasets["validation"] = datasets["validation"].select(range(10000))
404
 
 
 
405
 
406
 
 
 
407
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
408
 
 
 
 
 
 
 
 
 
 
409
  if model_args.config_name:
410
  config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
411
  elif model_args.model_name_or_path:
@@ -430,90 +496,97 @@ if __name__ == "__main__":
430
 
431
  # Preprocessing the datasets.
432
  # First we tokenize all the texts.
433
- if training_args.do_train:
434
- column_names = datasets["train"].column_names
435
- else:
436
- column_names = datasets["validation"].column_names
437
- text_column_name = "text" if "text" in column_names else column_names[0]
438
-
439
- max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
440
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
441
 
442
- if data_args.line_by_line:
443
- # When using line_by_line, we just tokenize each nonempty line.
444
- padding = "max_length" if data_args.pad_to_max_length else False
445
-
446
- def tokenize_function(examples):
447
- # Remove empty lines
448
- examples = [line for line in examples if len(line) > 0 and not line.isspace()]
449
- return tokenizer(
450
- examples,
451
- return_special_tokens_mask=True,
452
- padding=padding,
453
- truncation=True,
454
- max_length=max_seq_length,
455
  )
456
 
457
- tokenized_datasets = datasets.map(
458
- tokenize_function,
459
- input_columns=[text_column_name],
460
- batched=True,
461
- num_proc=data_args.preprocessing_num_workers,
462
- remove_columns=column_names,
463
- load_from_cache_file=not data_args.overwrite_cache,
464
- )
465
-
466
- else:
467
- # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
468
- # We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
469
- # efficient when it receives the `special_tokens_mask`.
470
- def tokenize_function(examples):
471
- return tokenizer(examples[text_column_name], return_special_tokens_mask=True)
472
-
473
- tokenized_datasets = datasets.map(
474
- tokenize_function,
475
- batched=True,
476
- num_proc=data_args.preprocessing_num_workers,
477
- remove_columns=column_names,
478
- load_from_cache_file=not data_args.overwrite_cache,
479
- )
480
-
481
- # Main data processing function that will concatenate all texts from our dataset and generate chunks of
482
- # max_seq_length.
483
- def group_texts(examples):
484
- # Concatenate all texts.
485
- concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
486
- total_length = len(concatenated_examples[list(examples.keys())[0]])
487
- # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
488
- # customize this part to your needs.
489
- if total_length >= max_seq_length:
490
- total_length = (total_length // max_seq_length) * max_seq_length
491
- # Split by chunks of max_len.
492
- result = {
493
- k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
494
- for k, t in concatenated_examples.items()
495
- }
496
- return result
497
-
498
- # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
499
- # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
500
- # might be slower to preprocess.
501
- #
502
- # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
503
- # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
504
- lm_datasets = tokenized_datasets.map(
505
- group_texts,
506
- batched=True,
507
- batch_size=100,
508
- num_proc=data_args.preprocessing_num_workers,
509
- load_from_cache_file=not data_args.overwrite_cache,
510
- )
511
- train_dataset = lm_datasets["train"]
512
- eval_dataset = lm_datasets["validation"]
513
-
514
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
515
 
 
 
516
 
 
517
  # Enable tensorboard only on the master node
518
  has_tensorboard = is_tensorboard_available()
519
  if has_tensorboard and jax.process_index() == 0:
@@ -531,7 +604,6 @@ if __name__ == "__main__":
531
  "Unable to display metrics through TensorBoard because the package is not installed: "
532
  "Please run pip install tensorboard to enable."
533
  )
534
- # enable wandb tracking
535
  has_wandb = find_spec("wandb") is not None
536
  if jax.process_index() == 0 and has_wandb and ("wandb" in training_args.report_to):
537
  try:
@@ -547,7 +619,6 @@ if __name__ == "__main__":
547
  except ImportError as e:
548
  print(e)
549
  has_wandb = False
550
-
551
  # Data collator
552
  # This one will take care of randomly masking the tokens.
553
  data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
@@ -567,10 +638,10 @@ if __name__ == "__main__":
567
 
568
  # Store some constant
569
  num_epochs = int(training_args.num_train_epochs)
570
- train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
571
  eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
572
 
573
- num_train_steps = len(train_dataset) // train_batch_size * num_epochs
574
 
575
  # Create learning rate schedule
576
  warmup_fn = optax.linear_schedule(
@@ -605,6 +676,7 @@ if __name__ == "__main__":
605
  learning_rate=linear_decay_lr_schedule_fn,
606
  )
607
  else:
 
608
  optimizer = optax.adamw(
609
  learning_rate=linear_decay_lr_schedule_fn,
610
  b1=training_args.adam_beta1,
@@ -613,22 +685,26 @@ if __name__ == "__main__":
613
  weight_decay=training_args.weight_decay,
614
  mask=decay_mask_fn,
615
  )
 
 
 
 
616
 
617
- #if training_args.gradient_accumulation_steps > 1:
618
- # optimizer = optax.MultiSteps(optimizer, training_args.gradient_accumulation_steps)
619
- #grad_accum_steps = training_args.gradient_accumulation_steps
620
 
621
  # Setup train state
622
-
623
-
624
- state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer,grad_accum=jax.tree_map(jnp.zeros_like, model.params))
625
-
626
  if training_args.resume_from_checkpoint:
627
- state = restore_checkpoint(training_args.resume_from_checkpoint, state)
628
- resume_step = mb_item(state.step.item())
 
 
629
  else:
630
  resume_step = 0
631
-
632
 
633
  # Define gradient update step fn
634
  def train_step(state, batch, dropout_rng):
@@ -646,30 +722,17 @@ if __name__ == "__main__":
646
  # take average
647
  loss = loss.sum() / label_mask.sum()
648
 
649
- return loss / training_args.gradient_accumulation_steps
650
 
651
  grad_fn = jax.value_and_grad(loss_fn)
652
- loss, grads = grad_fn(state.params)
653
- grad_accum = jax.tree_multimap(lambda x, y: x + y, grads, state.grad_accum)
654
-
655
- def update_fn():
656
- grads = jax.tree_map(lambda x: x / training_args.gradient_accumulation_steps, grad_accum)
657
- grads = jax.lax.pmean(grad_accum, "batch")
658
- new_state = state.apply_gradients(grads=grads,grad_accum=jax.tree_map(jnp.zeros_like, grads))
659
- return new_state
660
-
661
- new_state = jax.lax.cond(
662
- state.step % training_args.gradient_accumulation_steps == 0,
663
- lambda _: update_fn(),
664
- lambda _: state.replace(grad_accum=grad_accum, step=state.step + 1),
665
- None,
666
- )
667
-
668
  metrics = jax.lax.pmean(
669
- {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch" #
670
  )
671
 
672
- #return new_state.replace(new_dropout_rng=new_dropout_rng), metrics
673
  return new_state, metrics, new_dropout_rng
674
 
675
  # Create parallel version of the train step
@@ -700,7 +763,10 @@ if __name__ == "__main__":
700
  state = jax_utils.replicate(state)
701
 
702
  train_time = 0
703
- epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
 
 
 
704
  for epoch in epochs:
705
  # ======================== Training ================================
706
  train_start = time.time()
@@ -708,54 +774,53 @@ if __name__ == "__main__":
708
 
709
  # Create sampling rng
710
  rng, input_rng = jax.random.split(rng)
711
- steps_per_epoch = len(train_dataset) // train_batch_size
712
 
713
  # Generate an epoch by shuffling sampling indices from the train dataset
714
- num_train_samples = len(train_dataset)
715
  train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
716
- train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size) #// grad_accum_steps
717
 
718
  # Gather the indexes for creating the batch and do a training step
719
- for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1,initial=resume_step)): #grad_accum
720
- samples = [train_dataset[int(idx)] for idx in batch_idx]
721
  model_inputs = data_collator(samples, pad_to_multiple_of=16)
722
-
723
 
724
  # Model forward
725
  model_inputs = shard(model_inputs.data)
726
  state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
727
  train_metrics.append(train_metric)
728
 
729
- cur_step = epoch * (num_train_samples // train_batch_size) + step
730
  if cur_step < resume_step:
731
  continue
732
 
733
- if (cur_step % training_args.logging_steps) == 0 and cur_step > 0: # * grad_accum_steps
734
  # Save metrics
735
  train_metric = jax_utils.unreplicate(train_metric)
736
  train_time += time.time() - train_start
737
  if has_tensorboard and jax.process_index() == 0:
738
  write_train_metric(summary_writer, train_metrics, train_time, cur_step)
 
739
  if has_wandb and jax.process_index() == 0 and ("wandb" in training_args.report_to):
740
  # TODO: add accumulation of metrics
741
  _metrics = {k if k=="learning_rate" else f"train_{k}":mb_item(v.mean()) for k, v in train_metric.items()}
742
  wandb.log({"training_step":cur_step, **_metrics}, commit=True)
743
-
744
  epochs.write(
745
  f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
746
  )
747
 
748
  train_metrics = []
749
 
750
- if cur_step % (training_args.eval_steps) == 0 and cur_step > 0: #* grad_accum_steps
751
  # ======================== Evaluating ==============================
752
- num_eval_samples = len(eval_dataset)
753
  eval_samples_idx = jnp.arange(num_eval_samples)
754
  eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
755
 
756
  eval_metrics = []
757
  for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
758
- samples = [eval_dataset[int(idx)] for idx in batch_idx]
759
  model_inputs = data_collator(samples, pad_to_multiple_of=16)
760
 
761
  # Model forward
@@ -775,30 +840,17 @@ if __name__ == "__main__":
775
  # Save metrics
776
  if has_tensorboard and jax.process_index() == 0:
777
  write_eval_metric(summary_writer, eval_metrics, cur_step)
778
-
779
  if has_wandb and jax.process_index() == 0 and ("wandb" in training_args.report_to):
780
  _metrics = {f"eval_{k}":mb_item(v) for k, v in eval_metrics.items()}
781
  wandb.log({"eval_step":cur_step, **_metrics})
782
 
783
- if (cur_step % training_args.save_steps == 0 ) and cur_step > 0: #
784
  # save checkpoint after each epoch and push checkpoint to the hub
785
  if jax.process_index() == 0:
786
- params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
787
- model.save_pretrained(
788
- training_args.output_dir,
789
- params=params,
790
- push_to_hub=training_args.push_to_hub,
791
- commit_message=f"Saving weights and logs of step {cur_step}",
792
- )
793
- save_checkpoint(training_args.output_dir, jax_utils.unreplicate(state), cur_step, keep=training_args.save_total_limit, overwrite=True)
794
  if training_args.save_total_limit is not None:
795
  rotate_checkpoints(training_args.output_dir, training_args.save_total_limit)
796
-
797
  if jax.process_index() == 0:
798
- params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
799
- model.save_pretrained(
800
- training_args.output_dir,
801
- params=params,
802
- push_to_hub=training_args.push_to_hub,
803
- commit_message=f"Saving weights and logs of step {cur_step}",
804
- )
20
  Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
21
  https://huggingface.co/models?filter=masked-lm
22
  """
 
23
  import logging
24
  import os
25
  import sys
26
  import time
27
  from dataclasses import dataclass, field
 
28
 
29
  # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
30
  from pathlib import Path
31
  from typing import Dict, List, Optional, Tuple
32
 
33
  import numpy as np
34
+ from datasets import load_dataset, DatasetDict
35
  from tqdm import tqdm
36
 
37
  import flax
54
  is_tensorboard_available,
55
  set_seed,
56
  )
57
+ import json
 
 
58
  from flax.training import checkpoints
59
  from flax.jax_utils import unreplicate
60
  from flax.training.checkpoints import save_checkpoint, restore_checkpoint
61
+ from importlib.util import find_spec
62
+
63
 
64
  MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
65
  MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
101
  "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
102
  },
103
  )
104
+ save_optimizer: Optional[bool] = field(
105
+ default=True,
106
+ metadata={"help": "Whether to store full train state including optimizer."},
107
+ )
108
 
109
 
110
  @dataclass
119
  dataset_config_name: Optional[str] = field(
120
  default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
121
  )
 
 
 
 
 
122
  train_ref_file: Optional[str] = field(
123
  default=None,
124
  metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
130
  overwrite_cache: bool = field(
131
  default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
132
  )
133
+
134
+
135
+
136
  validation_split_percentage: Optional[int] = field(
137
  default=5,
138
  metadata={
164
  default=False,
165
  metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
166
  )
167
+ max_eval_samples: Optional[int] = field(
168
+ default=None,
169
+ metadata={
170
+ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
171
+ "value if set."
172
+ },
173
+ )
174
+
175
+
176
+
177
+
178
 
179
 
180
  @flax.struct.dataclass
274
  for metric_name, value in eval_metrics.items():
275
  summary_writer.scalar(f"eval_{metric_name}", value, step)
276
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
 
278
+ def _zeros_tree_like(inp_tree):
279
+ return jax.tree_map(jnp.zeros_like, inp_tree)
280
+
281
+ def fake_update(state):
282
+ fake_updates = _zeros_tree_like(state.params)
283
+ _, new_inner_opt_state = state.tx.inner_opt.update(fake_updates, state.opt_state.inner_opt_state, state.params)
284
+ opt_state = state.opt_state
285
+ new_opt_state = optax.MultiStepsState(mini_step=opt_state.mini_step,
286
+ gradient_step=opt_state.gradient_step,
287
+ inner_opt_state=new_inner_opt_state,
288
+ acc_grads=opt_state.acc_grads)
289
+ return state.replace(opt_state=new_opt_state)
290
+
291
+ def reinstantiate_states(opt_state):
292
+ new_state = []
293
+ for state in opt_state:
294
+ cls = getattr(optax, type(state).__name__)
295
+ new_state.append(cls(**{k:getattr(state, k) for k in state._fields}))
296
+ return new_state
297
+
298
+ def restore_model_checkpoint(save_dir, state):
299
+ logger.info(f"RESTORING CHECKPOINT FROM {save_dir}...")
300
+ with open(os.path.join(save_dir, "flax_model.msgpack"), "rb") as f:
301
+ params = from_bytes(state.params, f.read())
302
+
303
+ with open(os.path.join(save_dir, "opt_state.msgpack"), "rb") as f:
304
+ opt_state = from_bytes(state.opt_state, f.read())
305
+
306
+ with open(os.path.join(save_dir, "training_state.json"), "r") as f:
307
+ training_state = json.load(f)
308
+ step = training_state["step"]
309
+
310
+ logger.info("checkpoint restored")
311
+ # reinstantiate inner opt state to avoid type conflict
312
+ if hasattr(opt_state, "inner_opt_state"):
313
+ print("restoring state of multisteps optimizer")
314
+ inner_opt_state = reinstantiate_states(opt_state.inner_opt_state)
315
+ ms_state_dict = {k:getattr(state.opt_state, k) for k in state.opt_state._fields}
316
+ ms_state_dict["inner_opt_state"] = inner_opt_state
317
+ opt_state = optax.MultiStepsState(**ms_state_dict)
318
+
319
+ return state.replace(step=step, params=params, opt_state=opt_state)
320
+
321
+ def save_model_checkpoint(model, save_dir, state, with_opt:bool=True, push_to_hub:bool=False):
322
+ """
323
+ If `push_to_hub` is True, will save to `save_dir`. Otherwise will save to `save_dir/ckpt-{step}`.
324
+ """
325
+ state = jax_utils.unreplicate(state)
326
+ logger.info(f"SAVING CHECKPOINT IN {save_dir}...")
327
+ if not push_to_hub:
328
+ save_dir = f"{save_dir}/ckpt-{mb_item(state.step)-1}"
329
+ model.save_pretrained(
330
+ save_dir,
331
+ params=state.params,
332
+ push_to_hub=push_to_hub,
333
+ commit_message=f"Saving weights and logs at step {mb_item(state.step)-1}",
334
+ )
335
+ if with_opt:
336
+ with open(os.path.join(save_dir, "opt_state.msgpack"), "wb") as f:
337
+ f.write(to_bytes(state.opt_state))
338
+ with open(os.path.join(save_dir, "training_state.json"), "w") as f:
339
+ json.dump({"step": state.step.item()}, f)
340
+ logger.info("checkpoint saved")
341
 
 
 
 
342
 
343
 
 
344
  if __name__ == "__main__":
345
  # See all possible arguments in src/transformers/training_args.py
346
  # or by passing the --help flag to this script.
408
  cache_dir=model_args.cache_dir,
409
  )
410
  else:
411
+ import glob
412
+ import random
 
 
 
 
 
 
 
 
 
 
413
  data_files = []
414
+ def add_jsonlines_dir(path, filespec):
415
+ global data_files
416
+ data_files += glob.glob(f"{path}/{filespec}")
417
+ data_files = list(set(data_files))
418
+ print(f"Number of files {len(data_files)} after adding {path} glob {filespec}")
419
+ add_jsonlines_dir(f"/data/c4_cleaned2", "*.gz")
420
+ add_jsonlines_dir(f"/data/nrc_uniq_cleaned_20210223", "*.gz")
421
+ add_jsonlines_dir(f"/data/nu_uniq_cleaned_20210225", "*.gz")
422
+ random.Random(42).shuffle(data_files)
423
+ total = len(data_files)
424
+ print(total)
425
+ perc = 0.05
426
+ val_size = int(perc * total)
427
+ train_size = total - val_size
428
+ train = data_files[:train_size]
429
+ val = data_files[train_size:]
430
+ print(f"Got {len(train)} training files and {perc * 100} % {len(val)} validation files")
431
+ assert list(set(train) & set(val)) == [], "Train overlaps with test"
432
+ load_grouped = True
433
+ if not load_grouped:
434
+ datasets = load_dataset('json', data_files={'train': train, 'validation': val})
435
+
436
+ #from datasets import Dataset
 
 
 
 
 
437
 
438
+ #dataset = Dataset.from_file("/home/dat/.cache/huggingface/datasets/json/default-9add402b38836560/0.0.0/f92a4de297ac644ad9781979b79064b0e222b3af766f8ea3bee32390dca23723/json-train.arrow")
439
+ #dataset = Dataset.from_file("/home/dat/.cache/huggingface/datasets/json/default-9add402b38836560/0.0.0/f92a4de297ac644ad9781979b79064b0e222b3af766f8ea3bee32390dca23723/json-validation.arrow")
440
 
441
 
442
+ def mb_item(x):
443
+ return x.item() if hasattr(x, "item") else x
444
 
445
+ def save_model_checkpoint(model, save_dir, state, with_opt:bool=True, push_to_hub:bool=False):
446
+ """
447
+ If `push_to_hub` is True, will save to `save_dir`. Otherwise will save to `save_dir/ckpt-{step}`.
448
+ """
449
+ state = jax_utils.unreplicate(state)
450
+ logger.info(f"SAVING CHECKPOINT IN {save_dir}...")
451
+ if not push_to_hub:
452
+ save_dir = f"{save_dir}/ckpt-{mb_item(state.step)-1}"
453
+ model.save_pretrained(
454
+ save_dir,
455
+ params=state.params,
456
+ push_to_hub=push_to_hub,
457
+ commit_message=f"Saving weights and logs at step {mb_item(state.step)-1}",
458
+ )
459
+ if with_opt:
460
+ with open(os.path.join(save_dir, "opt_state.msgpack"), "wb") as f:
461
+ f.write(to_bytes(state.opt_state))
462
+ with open(os.path.join(save_dir, "training_state.json"), "w") as f:
463
+ json.dump({"step": state.step.item()}, f)
464
+ logger.info("checkpoint saved")
465
 
466
+
467
+ # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
468
+ # https://huggingface.co/docs/datasets/loading_datasets.html.
469
+
470
+ # Load pretrained model and tokenizer
471
+
472
+ # Distributed training:
473
+ # The .from_pretrained methods guarantee that only one local process can concurrently
474
+ # download model & vocab.
475
  if model_args.config_name:
476
  config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
477
  elif model_args.model_name_or_path:
496
 
497
  # Preprocessing the datasets.
498
  # First we tokenize all the texts.
 
 
 
 
 
 
 
499
 
500
+ if load_grouped:
501
+ logger.info("Loading tokenized and grouped dataset")
502
+ tokenized_datasets = DatasetDict.load_from_disk("/data/tokenized_data")
503
+ logger.info("Setting max validation examples to ")
504
+ print(f"Number of validation examples {data_args.max_eval_samples}")
505
+ tokenized_datasets["train"]= tokenized_datasets["train"].select(range(20000))
506
+ if data_args.max_eval_samples is not None:
507
+ tokenized_datasets["validation"] = tokenized_datasets["validation"].select(range(data_args.max_eval_samples))
508
+ else:
509
+ if training_args.do_train:
510
+ column_names = datasets["train"].column_names
511
+ else:
512
+ column_names = datasets["validation"].column_names
513
+ text_column_name = "text" if "text" in column_names else column_names[0]
514
+
515
+ max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
516
+
517
+ if data_args.line_by_line:
518
+ # When using line_by_line, we just tokenize each nonempty line.
519
+ padding = "max_length" if data_args.pad_to_max_length else False
520
+
521
+ def tokenize_function(examples):
522
+ # Remove empty lines
523
+ examples = [line for line in examples if len(line) > 0 and not line.isspace()]
524
+ return tokenizer(
525
+ examples,
526
+ return_special_tokens_mask=True,
527
+ padding=padding,
528
+ truncation=True,
529
+ max_length=max_seq_length,
530
+ )
531
 
532
+ tokenized_datasets = datasets.map(
533
+ tokenize_function,
534
+ input_columns=[text_column_name],
535
+ batched=True,
536
+ num_proc=data_args.preprocessing_num_workers,
537
+ remove_columns=column_names,
538
+ load_from_cache_file=not data_args.overwrite_cache,
 
 
 
 
 
 
539
  )
540
 
541
+ else:
542
+ # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
543
+ # We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
544
+ # efficient when it receives the `special_tokens_mask`.
545
+ def tokenize_function(examples):
546
+ return tokenizer(examples[text_column_name], return_special_tokens_mask=True)
547
+
548
+ tokenized_datasets = datasets.map(
549
+ tokenize_function,
550
+ batched=True,
551
+ num_proc=data_args.preprocessing_num_workers,
552
+ remove_columns=column_names,
553
+ load_from_cache_file=not data_args.overwrite_cache,
554
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
555
 
556
+ # Main data processing function that will concatenate all texts from our dataset and generate chunks of
557
+ # max_seq_length.
558
+ def group_texts(examples):
559
+ # Concatenate all texts.
560
+ concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
561
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
562
+ # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
563
+ # customize this part to your needs.
564
+ if total_length >= max_seq_length:
565
+ total_length = (total_length // max_seq_length) * max_seq_length
566
+ # Split by chunks of max_len.
567
+ result = {
568
+ k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
569
+ for k, t in concatenated_examples.items()
570
+ }
571
+ return result
572
+
573
+ # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
574
+ # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
575
+ # might be slower to preprocess.
576
+ #
577
+ # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
578
+ # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
579
+ tokenized_datasets = tokenized_datasets.map(
580
+ group_texts,
581
+ batched=True,
582
+ num_proc=data_args.preprocessing_num_workers,
583
+ load_from_cache_file=not data_args.overwrite_cache,
584
+ )
585
 
586
+ #tokenized_datasets.save_to_disk("/data/tokenized_data")
587
+ #print ("tokenized_datasets saved to disk")
588
 
589
+
590
  # Enable tensorboard only on the master node
591
  has_tensorboard = is_tensorboard_available()
592
  if has_tensorboard and jax.process_index() == 0:
604
  "Unable to display metrics through TensorBoard because the package is not installed: "
605
  "Please run pip install tensorboard to enable."
606
  )
 
607
  has_wandb = find_spec("wandb") is not None
608
  if jax.process_index() == 0 and has_wandb and ("wandb" in training_args.report_to):
609
  try:
619
  except ImportError as e:
620
  print(e)
621
  has_wandb = False
 
622
  # Data collator
623
  # This one will take care of randomly masking the tokens.
624
  data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
638
 
639
  # Store some constant
640
  num_epochs = int(training_args.num_train_epochs)
641
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() * training_args.gradient_accumulation_steps
642
  eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
643
 
644
+ num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
645
 
646
  # Create learning rate schedule
647
  warmup_fn = optax.linear_schedule(
676
  learning_rate=linear_decay_lr_schedule_fn,
677
  )
678
  else:
679
+ from optax import clip_by_global_norm
680
  optimizer = optax.adamw(
681
  learning_rate=linear_decay_lr_schedule_fn,
682
  b1=training_args.adam_beta1,
685
  weight_decay=training_args.weight_decay,
686
  mask=decay_mask_fn,
687
  )
688
+ optimizer = optax.chain(
689
+ optax.clip_by_global_norm(1.),
690
+ optimizer
691
+ )
692
 
693
+ if training_args.gradient_accumulation_steps > 1:
694
+ optimizer = optax.MultiSteps(optimizer, training_args.gradient_accumulation_steps)
695
+ grad_accum_steps = training_args.gradient_accumulation_steps
696
 
697
  # Setup train state
698
+ state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)
699
+
 
 
700
  if training_args.resume_from_checkpoint:
701
+ state = restore_model_checkpoint(training_args.resume_from_checkpoint, state)
702
+ resume_step = mb_item(state.step)
703
+ if training_args.adafactor:
704
+ state = fake_update(state)
705
  else:
706
  resume_step = 0
707
+
708
 
709
  # Define gradient update step fn
710
  def train_step(state, batch, dropout_rng):
722
  # take average
723
  loss = loss.sum() / label_mask.sum()
724
 
725
+ return loss
726
 
727
  grad_fn = jax.value_and_grad(loss_fn)
728
+ loss, grad = grad_fn(state.params)
729
+ grad = jax.lax.pmean(grad, "batch")
730
+ new_state = state.apply_gradients(grads=grad)
731
+
 
 
 
 
 
 
 
 
 
 
 
 
732
  metrics = jax.lax.pmean(
733
+ {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step // grad_accum_steps)}, axis_name="batch"
734
  )
735
 
 
736
  return new_state, metrics, new_dropout_rng
737
 
738
  # Create parallel version of the train step
763
  state = jax_utils.replicate(state)
764
 
765
  train_time = 0
766
+ steps_per_epoch = len(tokenized_datasets["train"]) // train_batch_size
767
+ resume_epoch = resume_step // (steps_per_epoch * grad_accum_steps)
768
+ epochs = tqdm(range(num_epochs), desc=f"Epoch ... ({resume_epoch+1}/{num_epochs})", position=0)
769
+ logger.info(f"Skipping to epoch {resume_epoch} step {resume_step // grad_accum_steps}")
770
  for epoch in epochs:
771
  # ======================== Training ================================
772
  train_start = time.time()
774
 
775
  # Create sampling rng
776
  rng, input_rng = jax.random.split(rng)
 
777
 
778
  # Generate an epoch by shuffling sampling indices from the train dataset
779
+ num_train_samples = len(tokenized_datasets["train"])
780
  train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
781
+ train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size // grad_accum_steps)
782
 
783
  # Gather the indexes for creating the batch and do a training step
784
+ for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1,initial=resume_step // grad_accum_steps)):
785
+ samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
786
  model_inputs = data_collator(samples, pad_to_multiple_of=16)
 
787
 
788
  # Model forward
789
  model_inputs = shard(model_inputs.data)
790
  state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
791
  train_metrics.append(train_metric)
792
 
793
+ cur_step = epoch * (num_train_samples // train_batch_size * grad_accum_steps) + step
794
  if cur_step < resume_step:
795
  continue
796
 
797
+ if cur_step % training_args.logging_steps * grad_accum_steps == 0 and cur_step > 0:
798
  # Save metrics
799
  train_metric = jax_utils.unreplicate(train_metric)
800
  train_time += time.time() - train_start
801
  if has_tensorboard and jax.process_index() == 0:
802
  write_train_metric(summary_writer, train_metrics, train_time, cur_step)
803
+
804
  if has_wandb and jax.process_index() == 0 and ("wandb" in training_args.report_to):
805
  # TODO: add accumulation of metrics
806
  _metrics = {k if k=="learning_rate" else f"train_{k}":mb_item(v.mean()) for k, v in train_metric.items()}
807
  wandb.log({"training_step":cur_step, **_metrics}, commit=True)
808
+
809
  epochs.write(
810
  f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
811
  )
812
 
813
  train_metrics = []
814
 
815
+ if cur_step % training_args.eval_steps * grad_accum_steps == 0 and cur_step > 0:
816
  # ======================== Evaluating ==============================
817
+ num_eval_samples = len(tokenized_datasets["validation"])
818
  eval_samples_idx = jnp.arange(num_eval_samples)
819
  eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
820
 
821
  eval_metrics = []
822
  for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
823
+ samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
824
  model_inputs = data_collator(samples, pad_to_multiple_of=16)
825
 
826
  # Model forward
840
  # Save metrics
841
  if has_tensorboard and jax.process_index() == 0:
842
  write_eval_metric(summary_writer, eval_metrics, cur_step)
 
843
  if has_wandb and jax.process_index() == 0 and ("wandb" in training_args.report_to):
844
  _metrics = {f"eval_{k}":mb_item(v) for k, v in eval_metrics.items()}
845
  wandb.log({"eval_step":cur_step, **_metrics})
846
 
847
+ if cur_step % training_args.save_steps == 0 * grad_accum_steps and cur_step > 0:
848
  # save checkpoint after each epoch and push checkpoint to the hub
849
  if jax.process_index() == 0:
850
+ save_model_checkpoint(model, training_args.output_dir, state, with_opt=model_args.save_optimizer,
851
+ push_to_hub=training_args.push_to_hub)
 
 
 
 
 
 
852
  if training_args.save_total_limit is not None:
853
  rotate_checkpoints(training_args.output_dir, training_args.save_total_limit)
854
+
855
  if jax.process_index() == 0:
856
+ save_model_checkpoint(model, training_args.output_dir, state, with_opt=model_args.save_optimizer, push_to_hub=training_args.push_to_hub)
 
 
 
 
 
 
run_mlm_flax_no_accum.py ADDED
@@ -0,0 +1,776 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) with whole word masking on a
18
+ text file or a dataset.
19
+
20
+ Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
21
+ https://huggingface.co/models?filter=masked-lm
22
+ """
23
+ import logging
24
+ import os
25
+ import sys
26
+ import time
27
+ from dataclasses import dataclass, field
28
+
29
+ # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
30
+ from pathlib import Path
31
+ from typing import Dict, List, Optional, Tuple
32
+
33
+ import numpy as np
34
+ from datasets import load_dataset, DatasetDict
35
+ from tqdm import tqdm
36
+
37
+ import flax
38
+ import jax
39
+ import jax.numpy as jnp
40
+ import optax
41
+ from flax import jax_utils, traverse_util
42
+ from flax.training import train_state
43
+ from flax.training.common_utils import get_metrics, onehot, shard
44
+ from transformers import (
45
+ CONFIG_MAPPING,
46
+ FLAX_MODEL_FOR_MASKED_LM_MAPPING,
47
+ AutoConfig,
48
+ AutoTokenizer,
49
+ FlaxAutoModelForMaskedLM,
50
+ HfArgumentParser,
51
+ PreTrainedTokenizerBase,
52
+ TensorType,
53
+ TrainingArguments,
54
+ is_tensorboard_available,
55
+ set_seed,
56
+ )
57
+ import json
58
+ from flax.training import checkpoints
59
+ from flax.jax_utils import unreplicate
60
+ from flax.training.checkpoints import save_checkpoint, restore_checkpoint
61
+ from importlib.util import find_spec
62
+
63
+
64
+ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
65
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
66
+
67
+
68
+ @dataclass
69
+ class ModelArguments:
70
+ """
71
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
72
+ """
73
+
74
+ model_name_or_path: Optional[str] = field(
75
+ default=None,
76
+ metadata={
77
+ "help": "The model checkpoint for weights initialization."
78
+ "Don't set if you want to train a model from scratch."
79
+ },
80
+ )
81
+ model_type: Optional[str] = field(
82
+ default=None,
83
+ metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
84
+ )
85
+ config_name: Optional[str] = field(
86
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
87
+ )
88
+ tokenizer_name: Optional[str] = field(
89
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
90
+ )
91
+ cache_dir: Optional[str] = field(
92
+ default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
93
+ )
94
+ use_fast_tokenizer: bool = field(
95
+ default=True,
96
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
97
+ )
98
+ dtype: Optional[str] = field(
99
+ default="float32",
100
+ metadata={
101
+ "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
102
+ },
103
+ )
104
+
105
+
106
+ @dataclass
107
+ class DataTrainingArguments:
108
+ """
109
+ Arguments pertaining to what data we are going to input our model for training and eval.
110
+ """
111
+
112
+ dataset_name: Optional[str] = field(
113
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
114
+ )
115
+ dataset_config_name: Optional[str] = field(
116
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
117
+ )
118
+ train_ref_file: Optional[str] = field(
119
+ default=None,
120
+ metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
121
+ )
122
+ validation_ref_file: Optional[str] = field(
123
+ default=None,
124
+ metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."},
125
+ )
126
+ overwrite_cache: bool = field(
127
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
128
+ )
129
+
130
+
131
+
132
+ validation_split_percentage: Optional[int] = field(
133
+ default=5,
134
+ metadata={
135
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
136
+ },
137
+ )
138
+ max_seq_length: Optional[int] = field(
139
+ default=None,
140
+ metadata={
141
+ "help": "The maximum total input sequence length after tokenization. Sequences longer "
142
+ "than this will be truncated. Default to the max input length of the model."
143
+ },
144
+ )
145
+ preprocessing_num_workers: Optional[int] = field(
146
+ default=None,
147
+ metadata={"help": "The number of processes to use for the preprocessing."},
148
+ )
149
+ mlm_probability: float = field(
150
+ default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
151
+ )
152
+ pad_to_max_length: bool = field(
153
+ default=False,
154
+ metadata={
155
+ "help": "Whether to pad all samples to `max_seq_length`. "
156
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
157
+ },
158
+ )
159
+ line_by_line: bool = field(
160
+ default=False,
161
+ metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
162
+ )
163
+ max_eval_samples: Optional[int] = field(
164
+ default=None,
165
+ metadata={
166
+ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
167
+ "value if set."
168
+ },
169
+ )
170
+
171
+
172
+
173
+
174
+
175
+
176
+ @flax.struct.dataclass
177
+ class FlaxDataCollatorForLanguageModeling:
178
+ """
179
+ Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
180
+ are not all of the same length.
181
+
182
+ Args:
183
+ tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
184
+ The tokenizer used for encoding the data.
185
+ mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
186
+ The probability with which to (randomly) mask tokens in the input.
187
+
188
+ .. note::
189
+
190
+ For best performance, this data collator should be used with a dataset having items that are dictionaries or
191
+ BatchEncoding, with the :obj:`"special_tokens_mask"` key, as returned by a
192
+ :class:`~transformers.PreTrainedTokenizer` or a :class:`~transformers.PreTrainedTokenizerFast` with the
193
+ argument :obj:`return_special_tokens_mask=True`.
194
+ """
195
+
196
+ tokenizer: PreTrainedTokenizerBase
197
+ mlm_probability: float = 0.15
198
+
199
+ def __post_init__(self):
200
+ if self.tokenizer.mask_token is None:
201
+ raise ValueError(
202
+ "This tokenizer does not have a mask token which is necessary for masked language modeling. "
203
+ "You should pass `mlm=False` to train on causal language modeling instead."
204
+ )
205
+
206
+ def __call__(self, examples: List[Dict[str, np.ndarray]], pad_to_multiple_of: int) -> Dict[str, np.ndarray]:
207
+ # Handle dict or lists with proper padding and conversion to tensor.
208
+ batch = self.tokenizer.pad(examples, pad_to_multiple_of=pad_to_multiple_of, return_tensors=TensorType.NUMPY)
209
+
210
+ # If special token mask has been preprocessed, pop it from the dict.
211
+ special_tokens_mask = batch.pop("special_tokens_mask", None)
212
+
213
+ batch["input_ids"], batch["labels"] = self.mask_tokens(
214
+ batch["input_ids"], special_tokens_mask=special_tokens_mask
215
+ )
216
+ return batch
217
+
218
+ def mask_tokens(
219
+ self, inputs: np.ndarray, special_tokens_mask: Optional[np.ndarray]
220
+ ) -> Tuple[jnp.ndarray, jnp.ndarray]:
221
+ """
222
+ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
223
+ """
224
+ labels = inputs.copy()
225
+ # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
226
+ probability_matrix = np.full(labels.shape, self.mlm_probability)
227
+ special_tokens_mask = special_tokens_mask.astype("bool")
228
+
229
+ probability_matrix[special_tokens_mask] = 0.0
230
+ masked_indices = np.random.binomial(1, probability_matrix).astype("bool")
231
+ labels[~masked_indices] = -100 # We only compute loss on masked tokens
232
+
233
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
234
+ indices_replaced = np.random.binomial(1, np.full(labels.shape, 0.8)).astype("bool") & masked_indices
235
+ inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
236
+
237
+ # 10% of the time, we replace masked input tokens with random word
238
+ indices_random = np.random.binomial(1, np.full(labels.shape, 0.5)).astype("bool")
239
+ indices_random &= masked_indices & ~indices_replaced
240
+
241
+ random_words = np.random.randint(self.tokenizer.vocab_size, size=labels.shape, dtype="i4")
242
+ inputs[indices_random] = random_words[indices_random]
243
+
244
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
245
+ return inputs, labels
246
+
247
+
248
+ def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
249
+ num_samples = len(samples_idx)
250
+ samples_to_remove = num_samples % batch_size
251
+
252
+ if samples_to_remove != 0:
253
+ samples_idx = samples_idx[:-samples_to_remove]
254
+ sections_split = num_samples // batch_size
255
+ batch_idx = np.split(samples_idx, sections_split)
256
+ return batch_idx
257
+
258
+
259
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
260
+ summary_writer.scalar("train_time", train_time, step)
261
+
262
+ train_metrics = get_metrics(train_metrics)
263
+ for key, vals in train_metrics.items():
264
+ tag = f"train_{key}"
265
+ for i, val in enumerate(vals):
266
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
267
+
268
+
269
+ def write_eval_metric(summary_writer, eval_metrics, step):
270
+ for metric_name, value in eval_metrics.items():
271
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
272
+
273
+ def rotate_checkpoints(ckpt_dir:str, save_total_limit:int):
274
+ "Removes older checkpoints so that `save_total_limit` checkpoints are kept"
275
+ # TODO: what to remove is decided using step number only, we might want to improve that
276
+ ckpts = [str(x) for x in Path(ckpt_dir).glob("ckpt-*")]
277
+ # sort checkpoints by step
278
+ ckpts_sorted = sorted(ckpts, key=lambda x: int(x.split('-')[-1]))
279
+ ckpts_to_delete = ckpts_sorted[:-save_total_limit]
280
+ for ckpt in ckpts_to_delete:
281
+ logger.info(f"Deleting older checkpoint [{ckpt}] due to save_total_limit ({save_total_limit})")
282
+ shutil.rmtree(ckpt)
283
+
284
+
285
+ if __name__ == "__main__":
286
+ # See all possible arguments in src/transformers/training_args.py
287
+ # or by passing the --help flag to this script.
288
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
289
+
290
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
291
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
292
+ # If we pass only one argument to the script and it's the path to a json file,
293
+ # let's parse it to get our arguments.
294
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
295
+ else:
296
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
297
+
298
+ if (
299
+ os.path.exists(training_args.output_dir)
300
+ and os.listdir(training_args.output_dir)
301
+ and training_args.do_train
302
+ and not training_args.overwrite_output_dir
303
+ ):
304
+ raise ValueError(
305
+ f"Output directory ({training_args.output_dir}) already exists and is not empty."
306
+ "Use --overwrite_output_dir to overcome."
307
+ )
308
+
309
+ # Setup logging
310
+ logging.basicConfig(
311
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
312
+ level="NOTSET",
313
+ datefmt="[%X]",
314
+ )
315
+
316
+ # Log on each process the small summary:
317
+ logger = logging.getLogger(__name__)
318
+
319
+ # Set the verbosity to info of the Transformers logger (on main process only):
320
+ logger.info(f"Training/evaluation parameters {training_args}")
321
+
322
+ # Set seed before initializing model.
323
+ set_seed(training_args.seed)
324
+
325
+ # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
326
+ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
327
+ # (the dataset will be downloaded automatically from the datasets Hub).
328
+ #
329
+ # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
330
+ # 'text' is found. You can easily tweak this behavior (see below).
331
+ #
332
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
333
+ # download the dataset.
334
+ if data_args.dataset_name is not None:
335
+ # Downloading and loading a dataset from the hub.
336
+ datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
337
+
338
+ if "validation" not in datasets.keys():
339
+ datasets["validation"] = load_dataset(
340
+ data_args.dataset_name,
341
+ data_args.dataset_config_name,
342
+ split=f"train[:{data_args.validation_split_percentage}%]",
343
+ cache_dir=model_args.cache_dir,
344
+ )
345
+ datasets["train"] = load_dataset(
346
+ data_args.dataset_name,
347
+ data_args.dataset_config_name,
348
+ split=f"train[{data_args.validation_split_percentage}%:]",
349
+ cache_dir=model_args.cache_dir,
350
+ )
351
+ else:
352
+ import glob
353
+ import random
354
+ data_files = []
355
+ def add_jsonlines_dir(path, filespec):
356
+ global data_files
357
+ data_files += glob.glob(f"{path}/{filespec}")
358
+ data_files = list(set(data_files))
359
+ print(f"Number of files {len(data_files)} after adding {path} glob {filespec}")
360
+ add_jsonlines_dir(f"/data/c4_cleaned2", "*.gz")
361
+ add_jsonlines_dir(f"/data/nrc_uniq_cleaned_20210223", "*.gz")
362
+ add_jsonlines_dir(f"/data/nu_uniq_cleaned_20210225", "*.gz")
363
+ random.Random(42).shuffle(data_files)
364
+ total = len(data_files)
365
+ print(total)
366
+ perc = 0.05
367
+ val_size = int(perc * total)
368
+ train_size = total - val_size
369
+ train = data_files[:train_size]
370
+ val = data_files[train_size:]
371
+ print(f"Got {len(train)} training files and {perc * 100} % {len(val)} validation files")
372
+ assert list(set(train) & set(val)) == [], "Train overlaps with test"
373
+ load_grouped = True
374
+ if not load_grouped:
375
+ datasets = load_dataset('json', data_files={'train': train, 'validation': val})
376
+
377
+ #from datasets import Dataset
378
+
379
+ #dataset = Dataset.from_file("/home/dat/.cache/huggingface/datasets/json/default-9add402b38836560/0.0.0/f92a4de297ac644ad9781979b79064b0e222b3af766f8ea3bee32390dca23723/json-train.arrow")
380
+ #dataset = Dataset.from_file("/home/dat/.cache/huggingface/datasets/json/default-9add402b38836560/0.0.0/f92a4de297ac644ad9781979b79064b0e222b3af766f8ea3bee32390dca23723/json-validation.arrow")
381
+
382
+
383
+ def mb_item(x):
384
+ return x.item() if hasattr(x, "item") else x
385
+
386
+ # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
387
+ # https://huggingface.co/docs/datasets/loading_datasets.html.
388
+
389
+ # Load pretrained model and tokenizer
390
+
391
+ # Distributed training:
392
+ # The .from_pretrained methods guarantee that only one local process can concurrently
393
+ # download model & vocab.
394
+ if model_args.config_name:
395
+ config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
396
+ elif model_args.model_name_or_path:
397
+ config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
398
+ else:
399
+ config = CONFIG_MAPPING[model_args.model_type]()
400
+ logger.warning("You are instantiating a new config instance from scratch.")
401
+
402
+ if model_args.tokenizer_name:
403
+ tokenizer = AutoTokenizer.from_pretrained(
404
+ model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
405
+ )
406
+ elif model_args.model_name_or_path:
407
+ tokenizer = AutoTokenizer.from_pretrained(
408
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
409
+ )
410
+ else:
411
+ raise ValueError(
412
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script."
413
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
414
+ )
415
+
416
+ # Preprocessing the datasets.
417
+ # First we tokenize all the texts.
418
+
419
+ if load_grouped:
420
+ logger.info("Loading tokenized and grouped dataset")
421
+ tokenized_datasets = DatasetDict.load_from_disk("/data/tokenized_data")
422
+ logger.info("Setting max validation examples to ")
423
+ print(f"Number of validation examples {data_args.max_eval_samples}")
424
+ tokenized_datasets["train"]= tokenized_datasets["train"].select(range(20000))
425
+ if data_args.max_eval_samples is not None:
426
+ tokenized_datasets["validation"] = tokenized_datasets["validation"].select(range(data_args.max_eval_samples))
427
+ else:
428
+ if training_args.do_train:
429
+ column_names = datasets["train"].column_names
430
+ else:
431
+ column_names = datasets["validation"].column_names
432
+ text_column_name = "text" if "text" in column_names else column_names[0]
433
+
434
+ max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
435
+
436
+ if data_args.line_by_line:
437
+ # When using line_by_line, we just tokenize each nonempty line.
438
+ padding = "max_length" if data_args.pad_to_max_length else False
439
+
440
+ def tokenize_function(examples):
441
+ # Remove empty lines
442
+ examples = [line for line in examples if len(line) > 0 and not line.isspace()]
443
+ return tokenizer(
444
+ examples,
445
+ return_special_tokens_mask=True,
446
+ padding=padding,
447
+ truncation=True,
448
+ max_length=max_seq_length,
449
+ )
450
+
451
+ tokenized_datasets = datasets.map(
452
+ tokenize_function,
453
+ input_columns=[text_column_name],
454
+ batched=True,
455
+ num_proc=data_args.preprocessing_num_workers,
456
+ remove_columns=column_names,
457
+ load_from_cache_file=not data_args.overwrite_cache,
458
+ )
459
+
460
+ else:
461
+ # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
462
+ # We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
463
+ # efficient when it receives the `special_tokens_mask`.
464
+ def tokenize_function(examples):
465
+ return tokenizer(examples[text_column_name], return_special_tokens_mask=True)
466
+
467
+ tokenized_datasets = datasets.map(
468
+ tokenize_function,
469
+ batched=True,
470
+ num_proc=data_args.preprocessing_num_workers,
471
+ remove_columns=column_names,
472
+ load_from_cache_file=not data_args.overwrite_cache,
473
+ )
474
+
475
+ # Main data processing function that will concatenate all texts from our dataset and generate chunks of
476
+ # max_seq_length.
477
+ def group_texts(examples):
478
+ # Concatenate all texts.
479
+ concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
480
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
481
+ # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
482
+ # customize this part to your needs.
483
+ if total_length >= max_seq_length:
484
+ total_length = (total_length // max_seq_length) * max_seq_length
485
+ # Split by chunks of max_len.
486
+ result = {
487
+ k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
488
+ for k, t in concatenated_examples.items()
489
+ }
490
+ return result
491
+
492
+ # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
493
+ # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
494
+ # might be slower to preprocess.
495
+ #
496
+ # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
497
+ # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
498
+ tokenized_datasets = tokenized_datasets.map(
499
+ group_texts,
500
+ batched=True,
501
+ num_proc=data_args.preprocessing_num_workers,
502
+ load_from_cache_file=not data_args.overwrite_cache,
503
+ )
504
+
505
+ #tokenized_datasets.save_to_disk("/data/tokenized_data")
506
+ #print ("tokenized_datasets saved to disk")
507
+
508
+
509
+ # Enable tensorboard only on the master node
510
+ has_tensorboard = is_tensorboard_available()
511
+ if has_tensorboard and jax.process_index() == 0:
512
+ try:
513
+ from flax.metrics.tensorboard import SummaryWriter
514
+
515
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
516
+ except ImportError as ie:
517
+ has_tensorboard = False
518
+ logger.warning(
519
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
520
+ )
521
+ else:
522
+ logger.warning(
523
+ "Unable to display metrics through TensorBoard because the package is not installed: "
524
+ "Please run pip install tensorboard to enable."
525
+ )
526
+ has_wandb = find_spec("wandb") is not None
527
+ if jax.process_index() == 0 and has_wandb and ("wandb" in training_args.report_to):
528
+ try:
529
+ import wandb
530
+ wandb.init(
531
+ entity="wandb",
532
+ project="hf-flax-pino-roberta",
533
+ sync_tensorboard=True
534
+ )
535
+ wandb.config.update(training_args)
536
+ wandb.config.update(model_args)
537
+ wandb.config.update(data_args)
538
+ except ImportError as e:
539
+ print(e)
540
+ has_wandb = False
541
+ # Data collator
542
+ # This one will take care of randomly masking the tokens.
543
+ data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
544
+
545
+ # Initialize our training
546
+ rng = jax.random.PRNGKey(training_args.seed)
547
+ dropout_rngs = jax.random.split(rng, jax.local_device_count())
548
+
549
+ if model_args.model_name_or_path:
550
+ model = FlaxAutoModelForMaskedLM.from_pretrained(
551
+ model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
552
+ )
553
+ else:
554
+ model = FlaxAutoModelForMaskedLM.from_config(
555
+ config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
556
+ )
557
+
558
+ # Store some constant
559
+ num_epochs = int(training_args.num_train_epochs)
560
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
561
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
562
+
563
+ num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
564
+
565
+ # Create learning rate schedule
566
+ warmup_fn = optax.linear_schedule(
567
+ init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
568
+ )
569
+ decay_fn = optax.linear_schedule(
570
+ init_value=training_args.learning_rate,
571
+ end_value=0,
572
+ transition_steps=num_train_steps - training_args.warmup_steps,
573
+ )
574
+ linear_decay_lr_schedule_fn = optax.join_schedules(
575
+ schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
576
+ )
577
+
578
+ # We use Optax's "masking" functionality to not apply weight decay
579
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
580
+ # mask boolean with the same structure as the parameters.
581
+ # The mask is True for parameters that should be decayed.
582
+ # Note that this mask is specifically adapted for FlaxBERT-like models.
583
+ # For other models, one should correct the layer norm parameter naming
584
+ # accordingly.
585
+ def decay_mask_fn(params):
586
+ flat_params = traverse_util.flatten_dict(params)
587
+ flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
588
+ return traverse_util.unflatten_dict(flat_mask)
589
+
590
+ # create adam optimizer
591
+ if training_args.adafactor:
592
+ # We use the default parameters here to initialize adafactor,
593
+ # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
594
+ optimizer = optax.adafactor(
595
+ learning_rate=linear_decay_lr_schedule_fn,
596
+ )
597
+ else:
598
+ optimizer = optax.adamw(
599
+ learning_rate=linear_decay_lr_schedule_fn,
600
+ b1=training_args.adam_beta1,
601
+ b2=training_args.adam_beta2,
602
+ eps=training_args.adam_epsilon,
603
+ weight_decay=training_args.weight_decay,
604
+ mask=decay_mask_fn,
605
+ )
606
+ optimizer = optax.chain(
607
+ optax.clip_grad_by_global_norm(1.),
608
+ optimizer
609
+ )
610
+
611
+ # Setup train state
612
+ state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)
613
+
614
+ if training_args.resume_from_checkpoint:
615
+ state = restore_checkpoint(training_args.resume_from_checkpoint, state)
616
+ resume_step = mb_item(state.step.item())
617
+ else:
618
+ resume_step = 0
619
+
620
+
621
+ # Define gradient update step fn
622
+ def train_step(state, batch, dropout_rng):
623
+ dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
624
+
625
+ def loss_fn(params):
626
+ labels = batch.pop("labels")
627
+
628
+ logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
629
+
630
+ # compute loss, ignore padded input tokens
631
+ label_mask = jnp.where(labels > 0, 1.0, 0.0)
632
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
633
+
634
+ # take average
635
+ loss = loss.sum() / label_mask.sum()
636
+
637
+ return loss
638
+
639
+ grad_fn = jax.value_and_grad(loss_fn)
640
+ loss, grad = grad_fn(state.params)
641
+ grad = jax.lax.pmean(grad, "batch")
642
+ new_state = state.apply_gradients(grads=grad)
643
+
644
+ metrics = jax.lax.pmean(
645
+ {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
646
+ )
647
+
648
+ return new_state, metrics, new_dropout_rng
649
+
650
+ # Create parallel version of the train step
651
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
652
+
653
+ # Define eval fn
654
+ def eval_step(params, batch):
655
+ labels = batch.pop("labels")
656
+
657
+ logits = model(**batch, params=params, train=False)[0]
658
+
659
+ # compute loss, ignore padded input tokens
660
+ label_mask = jnp.where(labels > 0, 1.0, 0.0)
661
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
662
+
663
+ # compute accuracy
664
+ accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) * label_mask
665
+
666
+ # summarize metrics
667
+ metrics = {"loss": loss.sum(), "accuracy": accuracy.sum(), "normalizer": label_mask.sum()}
668
+ metrics = jax.lax.psum(metrics, axis_name="batch")
669
+
670
+ return metrics
671
+
672
+ p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))
673
+
674
+ # Replicate the train state on each device
675
+ state = jax_utils.replicate(state)
676
+
677
+ train_time = 0
678
+ epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
679
+ for epoch in epochs:
680
+ # ======================== Training ================================
681
+ train_start = time.time()
682
+ train_metrics = []
683
+
684
+ # Create sampling rng
685
+ rng, input_rng = jax.random.split(rng)
686
+
687
+ # Generate an epoch by shuffling sampling indices from the train dataset
688
+ num_train_samples = len(tokenized_datasets["train"])
689
+ train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
690
+ train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
691
+
692
+ # Gather the indexes for creating the batch and do a training step
693
+ for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1,initial=resume_step)):
694
+ samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
695
+ model_inputs = data_collator(samples, pad_to_multiple_of=16)
696
+
697
+ # Model forward
698
+ model_inputs = shard(model_inputs.data)
699
+ state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
700
+ train_metrics.append(train_metric)
701
+
702
+ cur_step = epoch * (num_train_samples // train_batch_size) + step
703
+ if cur_step < resume_step:
704
+ continue
705
+
706
+ if cur_step % training_args.logging_steps == 0 and cur_step > 0:
707
+ # Save metrics
708
+ train_metric = jax_utils.unreplicate(train_metric)
709
+ train_time += time.time() - train_start
710
+ if has_tensorboard and jax.process_index() == 0:
711
+ write_train_metric(summary_writer, train_metrics, train_time, cur_step)
712
+
713
+ if has_wandb and jax.process_index() == 0 and ("wandb" in training_args.report_to):
714
+ # TODO: add accumulation of metrics
715
+ _metrics = {k if k=="learning_rate" else f"train_{k}":mb_item(v.mean()) for k, v in train_metric.items()}
716
+ wandb.log({"training_step":cur_step, **_metrics}, commit=True)
717
+
718
+ epochs.write(
719
+ f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
720
+ )
721
+
722
+ train_metrics = []
723
+
724
+ if cur_step % training_args.eval_steps == 0 and cur_step > 0:
725
+ # ======================== Evaluating ==============================
726
+ num_eval_samples = len(tokenized_datasets["validation"])
727
+ eval_samples_idx = jnp.arange(num_eval_samples)
728
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
729
+
730
+ eval_metrics = []
731
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
732
+ samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
733
+ model_inputs = data_collator(samples, pad_to_multiple_of=16)
734
+
735
+ # Model forward
736
+ model_inputs = shard(model_inputs.data)
737
+ metrics = p_eval_step(state.params, model_inputs)
738
+ eval_metrics.append(metrics)
739
+
740
+ # normalize eval metrics
741
+ eval_metrics = get_metrics(eval_metrics)
742
+ eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
743
+ eval_normalizer = eval_metrics.pop("normalizer")
744
+ eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
745
+
746
+ # Update progress bar
747
+ epochs.desc = f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
748
+
749
+ # Save metrics
750
+ if has_tensorboard and jax.process_index() == 0:
751
+ write_eval_metric(summary_writer, eval_metrics, cur_step)
752
+ if has_wandb and jax.process_index() == 0 and ("wandb" in training_args.report_to):
753
+ _metrics = {f"eval_{k}":mb_item(v) for k, v in eval_metrics.items()}
754
+ wandb.log({"eval_step":cur_step, **_metrics})
755
+
756
+ if cur_step % training_args.save_steps == 0 and cur_step > 0:
757
+ # save checkpoint after each epoch and push checkpoint to the hub
758
+ if jax.process_index() == 0:
759
+ save_checkpoint(training_args.output_dir, jax_utils.unreplicate(state), cur_step, keep=training_args.save_total_limit, overwrite=True)
760
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
761
+ model.save_pretrained(
762
+ training_args.output_dir,
763
+ params=params,
764
+ push_to_hub=training_args.push_to_hub,
765
+ commit_message=f"Saving weights and logs of step {cur_step}",
766
+ )
767
+ if training_args.save_total_limit is not None:
768
+ rotate_checkpoints(training_args.output_dir, training_args.save_total_limit)
769
+ if jax.process_index() == 0:
770
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
771
+ model.save_pretrained(
772
+ training_args.output_dir,
773
+ params=params,
774
+ push_to_hub=training_args.push_to_hub,
775
+ commit_message=f"Saving weights and logs of step {cur_step}",
776
+ )
save_tokenized_data.py ADDED
@@ -0,0 +1,484 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) with whole word masking on a
18
+ text file or a dataset.
19
+
20
+ Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
21
+ https://huggingface.co/models?filter=masked-lm
22
+ """
23
+ import logging
24
+ import os
25
+ import sys
26
+ import time
27
+ from dataclasses import dataclass, field
28
+
29
+ # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
30
+ from pathlib import Path
31
+ from typing import Dict, List, Optional, Tuple
32
+
33
+ import numpy as np
34
+ from datasets import load_dataset
35
+ from tqdm import tqdm
36
+
37
+ import flax
38
+ import jax
39
+ import jax.numpy as jnp
40
+ import optax
41
+ from flax import jax_utils, traverse_util
42
+ from flax.training import train_state
43
+ from flax.training.common_utils import get_metrics, onehot, shard
44
+ from transformers import (
45
+ CONFIG_MAPPING,
46
+ FLAX_MODEL_FOR_MASKED_LM_MAPPING,
47
+ AutoConfig,
48
+ AutoTokenizer,
49
+ FlaxAutoModelForMaskedLM,
50
+ HfArgumentParser,
51
+ PreTrainedTokenizerBase,
52
+ TensorType,
53
+ TrainingArguments,
54
+ is_tensorboard_available,
55
+ set_seed,
56
+ )
57
+ import json
58
+ from flax.training import checkpoints
59
+ from flax.jax_utils import unreplicate
60
+ from flax.training.checkpoints import save_checkpoint, restore_checkpoint
61
+ from importlib.util import find_spec
62
+
63
+
64
+ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
65
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
66
+
67
+
68
+ @dataclass
69
+ class ModelArguments:
70
+ """
71
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
72
+ """
73
+
74
+ model_name_or_path: Optional[str] = field(
75
+ default=None,
76
+ metadata={
77
+ "help": "The model checkpoint for weights initialization."
78
+ "Don't set if you want to train a model from scratch."
79
+ },
80
+ )
81
+ model_type: Optional[str] = field(
82
+ default=None,
83
+ metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
84
+ )
85
+ config_name: Optional[str] = field(
86
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
87
+ )
88
+ tokenizer_name: Optional[str] = field(
89
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
90
+ )
91
+ cache_dir: Optional[str] = field(
92
+ default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
93
+ )
94
+ use_fast_tokenizer: bool = field(
95
+ default=True,
96
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
97
+ )
98
+ dtype: Optional[str] = field(
99
+ default="float32",
100
+ metadata={
101
+ "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
102
+ },
103
+ )
104
+
105
+
106
+ @dataclass
107
+ class DataTrainingArguments:
108
+ """
109
+ Arguments pertaining to what data we are going to input our model for training and eval.
110
+ """
111
+
112
+ dataset_name: Optional[str] = field(
113
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
114
+ )
115
+ dataset_config_name: Optional[str] = field(
116
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
117
+ )
118
+ train_ref_file: Optional[str] = field(
119
+ default=None,
120
+ metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
121
+ )
122
+ validation_ref_file: Optional[str] = field(
123
+ default=None,
124
+ metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."},
125
+ )
126
+ overwrite_cache: bool = field(
127
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
128
+ )
129
+
130
+
131
+
132
+ validation_split_percentage: Optional[int] = field(
133
+ default=5,
134
+ metadata={
135
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
136
+ },
137
+ )
138
+ max_seq_length: Optional[int] = field(
139
+ default=None,
140
+ metadata={
141
+ "help": "The maximum total input sequence length after tokenization. Sequences longer "
142
+ "than this will be truncated. Default to the max input length of the model."
143
+ },
144
+ )
145
+ preprocessing_num_workers: Optional[int] = field(
146
+ default=None,
147
+ metadata={"help": "The number of processes to use for the preprocessing."},
148
+ )
149
+ mlm_probability: float = field(
150
+ default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
151
+ )
152
+ pad_to_max_length: bool = field(
153
+ default=False,
154
+ metadata={
155
+ "help": "Whether to pad all samples to `max_seq_length`. "
156
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
157
+ },
158
+ )
159
+ line_by_line: bool = field(
160
+ default=False,
161
+ metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
162
+ )
163
+ max_eval_samples: Optional[int] = field(
164
+ default=None,
165
+ metadata={
166
+ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
167
+ "value if set."
168
+ },
169
+ )
170
+
171
+
172
+
173
+
174
+
175
+
176
+ @flax.struct.dataclass
177
+ class FlaxDataCollatorForLanguageModeling:
178
+ """
179
+ Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
180
+ are not all of the same length.
181
+
182
+ Args:
183
+ tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
184
+ The tokenizer used for encoding the data.
185
+ mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
186
+ The probability with which to (randomly) mask tokens in the input.
187
+
188
+ .. note::
189
+
190
+ For best performance, this data collator should be used with a dataset having items that are dictionaries or
191
+ BatchEncoding, with the :obj:`"special_tokens_mask"` key, as returned by a
192
+ :class:`~transformers.PreTrainedTokenizer` or a :class:`~transformers.PreTrainedTokenizerFast` with the
193
+ argument :obj:`return_special_tokens_mask=True`.
194
+ """
195
+
196
+ tokenizer: PreTrainedTokenizerBase
197
+ mlm_probability: float = 0.15
198
+
199
+ def __post_init__(self):
200
+ if self.tokenizer.mask_token is None:
201
+ raise ValueError(
202
+ "This tokenizer does not have a mask token which is necessary for masked language modeling. "
203
+ "You should pass `mlm=False` to train on causal language modeling instead."
204
+ )
205
+
206
+ def __call__(self, examples: List[Dict[str, np.ndarray]], pad_to_multiple_of: int) -> Dict[str, np.ndarray]:
207
+ # Handle dict or lists with proper padding and conversion to tensor.
208
+ batch = self.tokenizer.pad(examples, pad_to_multiple_of=pad_to_multiple_of, return_tensors=TensorType.NUMPY)
209
+
210
+ # If special token mask has been preprocessed, pop it from the dict.
211
+ special_tokens_mask = batch.pop("special_tokens_mask", None)
212
+
213
+ batch["input_ids"], batch["labels"] = self.mask_tokens(
214
+ batch["input_ids"], special_tokens_mask=special_tokens_mask
215
+ )
216
+ return batch
217
+
218
+ def mask_tokens(
219
+ self, inputs: np.ndarray, special_tokens_mask: Optional[np.ndarray]
220
+ ) -> Tuple[jnp.ndarray, jnp.ndarray]:
221
+ """
222
+ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
223
+ """
224
+ labels = inputs.copy()
225
+ # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
226
+ probability_matrix = np.full(labels.shape, self.mlm_probability)
227
+ special_tokens_mask = special_tokens_mask.astype("bool")
228
+
229
+ probability_matrix[special_tokens_mask] = 0.0
230
+ masked_indices = np.random.binomial(1, probability_matrix).astype("bool")
231
+ labels[~masked_indices] = -100 # We only compute loss on masked tokens
232
+
233
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
234
+ indices_replaced = np.random.binomial(1, np.full(labels.shape, 0.8)).astype("bool") & masked_indices
235
+ inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
236
+
237
+ # 10% of the time, we replace masked input tokens with random word
238
+ indices_random = np.random.binomial(1, np.full(labels.shape, 0.5)).astype("bool")
239
+ indices_random &= masked_indices & ~indices_replaced
240
+
241
+ random_words = np.random.randint(self.tokenizer.vocab_size, size=labels.shape, dtype="i4")
242
+ inputs[indices_random] = random_words[indices_random]
243
+
244
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
245
+ return inputs, labels
246
+
247
+
248
+ def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
249
+ num_samples = len(samples_idx)
250
+ samples_to_remove = num_samples % batch_size
251
+
252
+ if samples_to_remove != 0:
253
+ samples_idx = samples_idx[:-samples_to_remove]
254
+ sections_split = num_samples // batch_size
255
+ batch_idx = np.split(samples_idx, sections_split)
256
+ return batch_idx
257
+
258
+
259
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
260
+ summary_writer.scalar("train_time", train_time, step)
261
+
262
+ train_metrics = get_metrics(train_metrics)
263
+ for key, vals in train_metrics.items():
264
+ tag = f"train_{key}"
265
+ for i, val in enumerate(vals):
266
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
267
+
268
+
269
+ def write_eval_metric(summary_writer, eval_metrics, step):
270
+ for metric_name, value in eval_metrics.items():
271
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
272
+
273
+
274
+ if __name__ == "__main__":
275
+ # See all possible arguments in src/transformers/training_args.py
276
+ # or by passing the --help flag to this script.
277
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
278
+
279
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
280
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
281
+ # If we pass only one argument to the script and it's the path to a json file,
282
+ # let's parse it to get our arguments.
283
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
284
+ else:
285
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
286
+
287
+ if (
288
+ os.path.exists(training_args.output_dir)
289
+ and os.listdir(training_args.output_dir)
290
+ and training_args.do_train
291
+ and not training_args.overwrite_output_dir
292
+ ):
293
+ raise ValueError(
294
+ f"Output directory ({training_args.output_dir}) already exists and is not empty."
295
+ "Use --overwrite_output_dir to overcome."
296
+ )
297
+
298
+ # Setup logging
299
+ logging.basicConfig(
300
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
301
+ level="NOTSET",
302
+ datefmt="[%X]",
303
+ )
304
+
305
+ # Log on each process the small summary:
306
+ logger = logging.getLogger(__name__)
307
+
308
+ # Set the verbosity to info of the Transformers logger (on main process only):
309
+ logger.info(f"Training/evaluation parameters {training_args}")
310
+
311
+ # Set seed before initializing model.
312
+ set_seed(training_args.seed)
313
+
314
+ # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
315
+ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
316
+ # (the dataset will be downloaded automatically from the datasets Hub).
317
+ #
318
+ # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
319
+ # 'text' is found. You can easily tweak this behavior (see below).
320
+ #
321
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
322
+ # download the dataset.
323
+ if data_args.dataset_name is not None:
324
+ # Downloading and loading a dataset from the hub.
325
+ datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
326
+
327
+ if "validation" not in datasets.keys():
328
+ datasets["validation"] = load_dataset(
329
+ data_args.dataset_name,
330
+ data_args.dataset_config_name,
331
+ split=f"train[:{data_args.validation_split_percentage}%]",
332
+ cache_dir=model_args.cache_dir,
333
+ )
334
+ datasets["train"] = load_dataset(
335
+ data_args.dataset_name,
336
+ data_args.dataset_config_name,
337
+ split=f"train[{data_args.validation_split_percentage}%:]",
338
+ cache_dir=model_args.cache_dir,
339
+ )
340
+ else:
341
+ import glob
342
+ import random
343
+ data_files = []
344
+ def add_jsonlines_dir(path, filespec):
345
+ global data_files
346
+ data_files += glob.glob(f"{path}/{filespec}")
347
+ data_files = list(set(data_files))
348
+ print(f"Number of files {len(data_files)} after adding {path} glob {filespec}")
349
+ #add_jsonlines_dir(f"/data/c4_cleaned2", "*.gz")
350
+ #add_jsonlines_dir(f"/data/nrc_uniq_cleaned_20210223", "*.gz")
351
+ add_jsonlines_dir(f"/data/nu_uniq_cleaned_20210225", "*.gz")
352
+ random.Random(42).shuffle(data_files)
353
+ total = len(data_files)
354
+ print(total)
355
+ perc = 0.05
356
+ val_size = int(perc * total)
357
+ train_size = total - val_size
358
+ train = data_files[5:8]
359
+ val = data_files[1:3]
360
+ print(f"Got {len(train)} training files and {perc * 100} % {len(val)} validation files")
361
+ assert list(set(train) & set(val)) == [], "Train overlaps with test"
362
+ datasets = load_dataset('json', data_files={'train': train, 'validation': val},cache_dir="/home/dat/.cache/huggingface/datasets/json/default-9add402b38836560/0.0.0/f92a4de297ac644ad9781979b79064b0e222b3af766f8ea3bee32390dca23723")
363
+
364
+ #from datasets import Dataset
365
+
366
+ #dataset = Dataset.from_file("/home/dat/.cache/huggingface/datasets/json/default-9add402b38836560/0.0.0/f92a4de297ac644ad9781979b79064b0e222b3af766f8ea3bee32390dca23723/json-train.arrow")
367
+ #dataset = Dataset.from_file("/home/dat/.cache/huggingface/datasets/json/default-9add402b38836560/0.0.0/f92a4de297ac644ad9781979b79064b0e222b3af766f8ea3bee32390dca23723/json-validation.arrow")
368
+
369
+
370
+ def mb_item(x):
371
+ return x.item() if hasattr(x, "item") else x
372
+
373
+ # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
374
+ # https://huggingface.co/docs/datasets/loading_datasets.html.
375
+
376
+ # Load pretrained model and tokenizer
377
+
378
+ # Distributed training:
379
+ # The .from_pretrained methods guarantee that only one local process can concurrently
380
+ # download model & vocab.
381
+ if model_args.config_name:
382
+ config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
383
+ elif model_args.model_name_or_path:
384
+ config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
385
+ else:
386
+ config = CONFIG_MAPPING[model_args.model_type]()
387
+ logger.warning("You are instantiating a new config instance from scratch.")
388
+
389
+ if model_args.tokenizer_name:
390
+ tokenizer = AutoTokenizer.from_pretrained(
391
+ model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
392
+ )
393
+ elif model_args.model_name_or_path:
394
+ tokenizer = AutoTokenizer.from_pretrained(
395
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
396
+ )
397
+ else:
398
+ raise ValueError(
399
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script."
400
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
401
+ )
402
+
403
+ # Preprocessing the datasets.
404
+ # First we tokenize all the texts.
405
+ if training_args.do_train:
406
+ column_names = datasets["train"].column_names
407
+ else:
408
+ column_names = datasets["validation"].column_names
409
+ text_column_name = "text" if "text" in column_names else column_names[0]
410
+
411
+ max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
412
+
413
+ if data_args.line_by_line:
414
+ # When using line_by_line, we just tokenize each nonempty line.
415
+ padding = "max_length" if data_args.pad_to_max_length else False
416
+
417
+ def tokenize_function(examples):
418
+ # Remove empty lines
419
+ examples = [line for line in examples if len(line) > 0 and not line.isspace()]
420
+ return tokenizer(
421
+ examples,
422
+ return_special_tokens_mask=True,
423
+ padding=padding,
424
+ truncation=True,
425
+ max_length=max_seq_length,
426
+ )
427
+
428
+ tokenized_datasets = datasets.map(
429
+ tokenize_function,
430
+ input_columns=[text_column_name],
431
+ batched=True,
432
+ num_proc=data_args.preprocessing_num_workers,
433
+ remove_columns=column_names,
434
+ load_from_cache_file=not data_args.overwrite_cache,
435
+ )
436
+ tokenized_datasets.save_to_disk("/data/tokenized_data")
437
+ print ("save data")
438
+ else:
439
+ # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
440
+ # We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
441
+ # efficient when it receives the `special_tokens_mask`.
442
+ def tokenize_function(examples):
443
+ return tokenizer(examples[text_column_name], return_special_tokens_mask=True)
444
+
445
+ tokenized_datasets = datasets.map(
446
+ tokenize_function,
447
+ batched=True,
448
+ num_proc=data_args.preprocessing_num_workers,
449
+ remove_columns=column_names,
450
+ load_from_cache_file=not data_args.overwrite_cache,
451
+ )
452
+
453
+ # Main data processing function that will concatenate all texts from our dataset and generate chunks of
454
+ # max_seq_length.
455
+ def group_texts(examples):
456
+ # Concatenate all texts.
457
+ concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
458
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
459
+ # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
460
+ # customize this part to your needs.
461
+ if total_length >= max_seq_length:
462
+ total_length = (total_length // max_seq_length) * max_seq_length
463
+ # Split by chunks of max_len.
464
+ result = {
465
+ k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
466
+ for k, t in concatenated_examples.items()
467
+ }
468
+ return result
469
+
470
+ # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
471
+ # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
472
+ # might be slower to preprocess.
473
+ #
474
+ # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
475
+ # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
476
+ tokenized_datasets = tokenized_datasets.map(
477
+ group_texts,
478
+ batched=True,
479
+ num_proc=data_args.preprocessing_num_workers,
480
+ load_from_cache_file=not data_args.overwrite_cache,
481
+ )
482
+
483
+ tokenized_datasets.save_to_disk("/data/tokenized_data")
484
+ print ("save data")
train_tokenizer.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import random
3
+ from tokenizers import ByteLevelBPETokenizer
4
+ from datasets import load_dataset
5
+
6
+ data_files = []
7
+ def add_jsonlines_dir(path, filespec):
8
+ global data_files
9
+ data_files += glob.glob(f"{path}/{filespec}")
10
+ data_files = list(set(data_files))
11
+ print(f"Number of files {len(data_files)} after adding {path} glob {filespec}")
12
+ add_jsonlines_dir(f"/data/c4_cleaned2", "*.gz")
13
+ add_jsonlines_dir(f"/data/nrc_uniq_cleaned_20210223", "*.gz")
14
+ add_jsonlines_dir(f"/data/nu_uniq_cleaned_20210225", "*.gz")
15
+ random.Random(42).shuffle(data_files)
16
+ total = len(data_files)
17
+ print(total)
18
+ perc = 0.05
19
+ val_size = int(perc * total)
20
+ train_size = total - val_size
21
+ train = data_files[:train_size]
22
+ val = data_files[train_size:]
23
+ print(f"Got {len(train)} training files and {perc * 100} % {len(val)} validation files")
24
+ assert list(set(train) & set(val)) == [], "Train overlaps with test"
25
+ datasets = load_dataset('json', data_files={'train': train, 'validation': val})
26
+
27
+
28
+
29
+ tokenizer = ByteLevelBPETokenizer()
30
+
31
+ def batch_iterator(batch_size=1000):
32
+ for i in range(0, len(datasets), batch_size):
33
+ yield datasets["train"][i: i + batch_size]["text"]
34
+
35
+ tokenizer.train_from_iterator(batch_iterator(), vocab_size=50358, min_frequency=2, special_tokens=[
36
+ "<s>",
37
+ "<pad>",
38
+ "</s>",
39
+ "<unk>",
40
+ "<mask>",
41
+ ])
42
+
43
+ tokenizer.save("tokenizer.json")
wandb/debug-internal.log CHANGED
@@ -1 +1 @@
1
- run-20210713_010630-14xhiyhf/logs/debug-internal.log
1
+ run-20210714_210351-1msvb4w4/logs/debug-internal.log
wandb/debug.log CHANGED
@@ -1 +1 @@
1
- run-20210713_010630-14xhiyhf/logs/debug.log
1
+ run-20210714_210351-1msvb4w4/logs/debug.log
wandb/latest-run CHANGED
@@ -1 +1 @@
1
- run-20210713_010630-14xhiyhf
1
+ run-20210714_210351-1msvb4w4
wandb/run-20210713_010630-14xhiyhf/files/output.log CHANGED
@@ -16222,3 +16222,12 @@ Training...: 64%|████████████▊ | 59500/92767 [9
16222
 
16223
  Training...: 65%|████████████▉ | 60000/92767 [9:35:07<5:11:39, 1.75it/s]
16224
  git-lfs/2.9.2 (GitHub; linux amd64; go 1.13.5)92767 [9:35:07<5:11:39, 1.75it/s]
 
 
 
 
 
 
 
 
 
16222
 
16223
  Training...: 65%|████████████▉ | 60000/92767 [9:35:07<5:11:39, 1.75it/s]
16224
  git-lfs/2.9.2 (GitHub; linux amd64; go 1.13.5)92767 [9:35:07<5:11:39, 1.75it/s]
16225
+ [10:43:30] - DEBUG - huggingface_hub.repository - [Repository] is a valid git repo
16226
+ [10:44:08] - INFO - huggingface_hub.repository - Uploading LFS objects: 100% (3/3), 1.0 GB | 43 MB/s, done.
16227
+ [10:44:09] - INFO - absl - Saving checkpoint at step: 60000
16228
+ tcmalloc: large alloc 1363968000 bytes == 0x2ed6e2000 @ 0x7f170bb8c680 0x7f170bbacbdd 0x7f143fe0e20d 0x7f143fe1c340 0x7f143fe1be87 0x7f143fe1be87 0x7f143fe1be87 0x7f143fe1be87 0x7f143fe1be87 0x7f143fe1be87 0x7f143fe1be87 0x7f143fe1be87 0x7f143fe1be87 0x7f143fe1be87 0x7f143fe1be87 0x7f143fe17bd3 0x7f143fe181fe 0x504d56 0x56acb6 0x568d9a 0x5f5b33 0x56bc9b 0x5f5956 0x56aadf 0x5f5956 0x56fb87 0x568d9a 0x5f5b33 0x56bc9b 0x568d9a 0x68cdc7
16229
+ [10:44:13] - INFO - absl - Saved checkpoint at checkpoint_60000
16230
+
16231
+
16232
+
16233
+
wandb/run-20210713_010630-14xhiyhf/logs/debug-internal.log CHANGED
@@ -22396,3 +22396,27 @@
22396
  2021-07-13 10:43:28,960 INFO Thread-8 :332390 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_010630-14xhiyhf/files/wandb-summary.json
22397
  2021-07-13 10:43:29,961 INFO Thread-8 :332390 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_010630-14xhiyhf/files/output.log
22398
  2021-07-13 10:43:31,962 INFO Thread-8 :332390 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_010630-14xhiyhf/files/output.log
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22396
  2021-07-13 10:43:28,960 INFO Thread-8 :332390 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_010630-14xhiyhf/files/wandb-summary.json
22397
  2021-07-13 10:43:29,961 INFO Thread-8 :332390 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_010630-14xhiyhf/files/output.log
22398
  2021-07-13 10:43:31,962 INFO Thread-8 :332390 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_010630-14xhiyhf/files/output.log
22399
+ 2021-07-13 10:43:36,601 DEBUG HandlerThread:332390 [handler.py:handle_request():124] handle_request: stop_status
22400
+ 2021-07-13 10:43:36,601 DEBUG SenderThread:332390 [sender.py:send_request():193] send_request: stop_status
22401
+ 2021-07-13 10:43:51,734 DEBUG HandlerThread:332390 [handler.py:handle_request():124] handle_request: stop_status
22402
+ 2021-07-13 10:43:51,734 DEBUG SenderThread:332390 [sender.py:send_request():193] send_request: stop_status
22403
+ 2021-07-13 10:43:55,447 DEBUG SenderThread:332390 [sender.py:send():179] send: stats
22404
+ 2021-07-13 10:44:06,865 DEBUG HandlerThread:332390 [handler.py:handle_request():124] handle_request: stop_status
22405
+ 2021-07-13 10:44:06,866 DEBUG SenderThread:332390 [sender.py:send_request():193] send_request: stop_status
22406
+ 2021-07-13 10:44:09,977 INFO Thread-8 :332390 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_010630-14xhiyhf/files/output.log
22407
+ 2021-07-13 10:44:14,979 INFO Thread-8 :332390 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_010630-14xhiyhf/files/output.log
22408
+ 2021-07-13 10:44:16,979 INFO Thread-8 :332390 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_010630-14xhiyhf/files/output.log
22409
+ 2021-07-13 10:44:18,980 INFO Thread-8 :332390 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_010630-14xhiyhf/files/output.log
22410
+ 2021-07-13 10:44:20,981 INFO Thread-8 :332390 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_010630-14xhiyhf/files/output.log
22411
+ 2021-07-13 10:44:22,005 DEBUG HandlerThread:332390 [handler.py:handle_request():124] handle_request: stop_status
22412
+ 2021-07-13 10:44:22,005 DEBUG SenderThread:332390 [sender.py:send_request():193] send_request: stop_status
22413
+ 2021-07-13 10:44:22,982 INFO Thread-8 :332390 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_010630-14xhiyhf/files/output.log
22414
+ 2021-07-13 10:44:23,482 WARNING MainThread:332390 [internal.py:wandb_internal():147] Internal process interrupt: 1
22415
+ 2021-07-13 10:44:24,702 WARNING MainThread:332390 [internal.py:wandb_internal():147] Internal process interrupt: 2
22416
+ 2021-07-13 10:44:24,703 ERROR MainThread:332390 [internal.py:wandb_internal():150] Internal process interrupted.
22417
+ 2021-07-13 10:44:24,982 INFO Thread-8 :332390 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_010630-14xhiyhf/files/output.log
22418
+ 2021-07-13 10:44:25,021 INFO SenderThread:332390 [sender.py:finish():945] shutting down sender
22419
+ 2021-07-13 10:44:25,022 INFO SenderThread:332390 [dir_watcher.py:finish():282] shutting down directory watcher
22420
+ 2021-07-13 10:44:25,022 INFO WriterThread:332390 [datastore.py:close():288] close: /home/dat/pino-roberta-base/wandb/run-20210713_010630-14xhiyhf/run-14xhiyhf.wandb
22421
+ 2021-07-13 10:44:25,022 INFO HandlerThread:332390 [handler.py:finish():638] shutting down handler
22422
+ 2021-07-13 10:44:25,103 INFO MainThread:332390 [internal.py:handle_exit():78] Internal process exited
wandb/run-20210713_010630-14xhiyhf/logs/debug.log CHANGED
@@ -23,3 +23,5 @@ config: {}
23
  2021-07-13 01:06:32,711 INFO MainThread:330819 [wandb_run.py:_config_callback():872] config_cb None None {'output_dir': './', 'overwrite_output_dir': True, 'do_train': False, 'do_eval': False, 'do_predict': False, 'evaluation_strategy': 'IntervalStrategy.NO', 'prediction_loss_only': False, 'per_device_train_batch_size': 2, 'per_device_eval_batch_size': 2, 'per_gpu_train_batch_size': None, 'per_gpu_eval_batch_size': None, 'gradient_accumulation_steps': 1, 'eval_accumulation_steps': None, 'learning_rate': 5e-05, 'weight_decay': 0.0095, 'adam_beta1': 0.9, 'adam_beta2': 0.98, 'adam_epsilon': 1e-08, 'max_grad_norm': 1.0, 'num_train_epochs': 5.0, 'max_steps': -1, 'lr_scheduler_type': 'SchedulerType.LINEAR', 'warmup_ratio': 0.0, 'warmup_steps': 5000, 'log_level': -1, 'log_level_replica': -1, 'log_on_each_node': True, 'logging_dir': './runs/Jul13_01-05-41_t1v-n-f5c06ea1-w-0', 'logging_strategy': 'IntervalStrategy.STEPS', 'logging_first_step': False, 'logging_steps': 500, 'save_strategy': 'IntervalStrategy.STEPS', 'save_steps': 20000, 'save_total_limit': 5, 'save_on_each_node': False, 'no_cuda': False, 'seed': 42, 'fp16': False, 'fp16_opt_level': 'O1', 'fp16_backend': 'auto', 'fp16_full_eval': False, 'local_rank': -1, 'tpu_num_cores': None, 'tpu_metrics_debug': False, 'debug': [], 'dataloader_drop_last': False, 'eval_steps': 92768, 'dataloader_num_workers': 0, 'past_index': -1, 'run_name': './', 'disable_tqdm': False, 'remove_unused_columns': True, 'label_names': None, 'load_best_model_at_end': False, 'metric_for_best_model': None, 'greater_is_better': None, 'ignore_data_skip': False, 'sharded_ddp': [], 'deepspeed': None, 'label_smoothing_factor': 0.0, 'adafactor': False, 'group_by_length': False, 'length_column_name': 'length', 'report_to': ['tensorboard', 'wandb'], 'ddp_find_unused_parameters': None, 'dataloader_pin_memory': True, 'skip_memory_metrics': True, 'use_legacy_prediction_loop': False, 'push_to_hub': True, 'resume_from_checkpoint': None, 'push_to_hub_model_id': '', 'push_to_hub_organization': None, 'push_to_hub_token': None, 'mp_parameters': ''}
24
  2021-07-13 01:06:32,712 INFO MainThread:330819 [wandb_run.py:_config_callback():872] config_cb None None {'model_name_or_path': None, 'model_type': 'big_bird', 'config_name': './', 'tokenizer_name': './', 'cache_dir': None, 'use_fast_tokenizer': True, 'dtype': 'bfloat16'}
25
  2021-07-13 01:06:32,714 INFO MainThread:330819 [wandb_run.py:_config_callback():872] config_cb None None {'dataset_name': None, 'dataset_config_name': None, 'train_file': None, 'validation_file': None, 'train_ref_file': None, 'validation_ref_file': None, 'overwrite_cache': False, 'validation_split_percentage': 5, 'max_seq_length': 4096, 'preprocessing_num_workers': 64, 'mlm_probability': 0.15, 'pad_to_max_length': False, 'line_by_line': False}
 
 
23
  2021-07-13 01:06:32,711 INFO MainThread:330819 [wandb_run.py:_config_callback():872] config_cb None None {'output_dir': './', 'overwrite_output_dir': True, 'do_train': False, 'do_eval': False, 'do_predict': False, 'evaluation_strategy': 'IntervalStrategy.NO', 'prediction_loss_only': False, 'per_device_train_batch_size': 2, 'per_device_eval_batch_size': 2, 'per_gpu_train_batch_size': None, 'per_gpu_eval_batch_size': None, 'gradient_accumulation_steps': 1, 'eval_accumulation_steps': None, 'learning_rate': 5e-05, 'weight_decay': 0.0095, 'adam_beta1': 0.9, 'adam_beta2': 0.98, 'adam_epsilon': 1e-08, 'max_grad_norm': 1.0, 'num_train_epochs': 5.0, 'max_steps': -1, 'lr_scheduler_type': 'SchedulerType.LINEAR', 'warmup_ratio': 0.0, 'warmup_steps': 5000, 'log_level': -1, 'log_level_replica': -1, 'log_on_each_node': True, 'logging_dir': './runs/Jul13_01-05-41_t1v-n-f5c06ea1-w-0', 'logging_strategy': 'IntervalStrategy.STEPS', 'logging_first_step': False, 'logging_steps': 500, 'save_strategy': 'IntervalStrategy.STEPS', 'save_steps': 20000, 'save_total_limit': 5, 'save_on_each_node': False, 'no_cuda': False, 'seed': 42, 'fp16': False, 'fp16_opt_level': 'O1', 'fp16_backend': 'auto', 'fp16_full_eval': False, 'local_rank': -1, 'tpu_num_cores': None, 'tpu_metrics_debug': False, 'debug': [], 'dataloader_drop_last': False, 'eval_steps': 92768, 'dataloader_num_workers': 0, 'past_index': -1, 'run_name': './', 'disable_tqdm': False, 'remove_unused_columns': True, 'label_names': None, 'load_best_model_at_end': False, 'metric_for_best_model': None, 'greater_is_better': None, 'ignore_data_skip': False, 'sharded_ddp': [], 'deepspeed': None, 'label_smoothing_factor': 0.0, 'adafactor': False, 'group_by_length': False, 'length_column_name': 'length', 'report_to': ['tensorboard', 'wandb'], 'ddp_find_unused_parameters': None, 'dataloader_pin_memory': True, 'skip_memory_metrics': True, 'use_legacy_prediction_loop': False, 'push_to_hub': True, 'resume_from_checkpoint': None, 'push_to_hub_model_id': '', 'push_to_hub_organization': None, 'push_to_hub_token': None, 'mp_parameters': ''}
24
  2021-07-13 01:06:32,712 INFO MainThread:330819 [wandb_run.py:_config_callback():872] config_cb None None {'model_name_or_path': None, 'model_type': 'big_bird', 'config_name': './', 'tokenizer_name': './', 'cache_dir': None, 'use_fast_tokenizer': True, 'dtype': 'bfloat16'}
25
  2021-07-13 01:06:32,714 INFO MainThread:330819 [wandb_run.py:_config_callback():872] config_cb None None {'dataset_name': None, 'dataset_config_name': None, 'train_file': None, 'validation_file': None, 'train_ref_file': None, 'validation_ref_file': None, 'overwrite_cache': False, 'validation_split_percentage': 5, 'max_seq_length': 4096, 'preprocessing_num_workers': 64, 'mlm_probability': 0.15, 'pad_to_max_length': False, 'line_by_line': False}
26
+ 2021-07-13 10:44:23,634 INFO MainThread:330819 [wandb_run.py:_atexit_cleanup():1593] got exitcode: 255
27
+ 2021-07-13 10:44:23,634 INFO MainThread:330819 [wandb_run.py:_restore():1565] restore
wandb/run-20210713_010630-14xhiyhf/run-14xhiyhf.wandb CHANGED
Binary files a/wandb/run-20210713_010630-14xhiyhf/run-14xhiyhf.wandb and b/wandb/run-20210713_010630-14xhiyhf/run-14xhiyhf.wandb differ
wandb/run-20210713_104745-1rl2j7or/files/config.yaml ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ wandb_version: 1
2
+
3
+ _wandb:
4
+ desc: null
5
+ value:
6
+ cli_version: 0.10.33
7
+ framework: huggingface
8
+ huggingface_version: 4.9.0.dev0
9
+ is_jupyter_run: false
10
+ is_kaggle_kernel: false
11
+ python_version: 3.8.10
12
+ t:
13
+ 1:
14
+ - 3
15
+ - 11
16
+ 4: 3.8.10
17
+ 5: 0.10.33
18
+ 6: 4.9.0.dev0
19
+ 8:
20
+ - 5
21
+ adafactor:
22
+ desc: null
23
+ value: false
24
+ adam_beta1:
25
+ desc: null
26
+ value: 0.9
27
+ adam_beta2:
28
+ desc: null
29
+ value: 0.98
30
+ adam_epsilon:
31
+ desc: null
32
+ value: 1.0e-08
33
+ cache_dir:
34
+ desc: null
35
+ value: null
36
+ config_name:
37
+ desc: null
38
+ value: ./
39
+ dataloader_drop_last:
40
+ desc: null
41
+ value: false
42
+ dataloader_num_workers:
43
+ desc: null
44
+ value: 0
45
+ dataloader_pin_memory:
46
+ desc: null
47
+ value: true
48
+ dataset_config_name:
49
+ desc: null
50
+ value: null
51
+ dataset_name:
52
+ desc: null
53
+ value: null
54
+ ddp_find_unused_parameters:
55
+ desc: null
56
+ value: null
57
+ debug:
58
+ desc: null
59
+ value: []
60
+ deepspeed:
61
+ desc: null
62
+ value: null
63
+ disable_tqdm:
64
+ desc: null
65
+ value: false
66
+ do_eval:
67
+ desc: null
68
+ value: false
69
+ do_predict:
70
+ desc: null
71
+ value: false
72
+ do_train:
73
+ desc: null
74
+ value: false
75
+ dtype:
76
+ desc: null
77
+ value: float32
78
+ eval_accumulation_steps:
79
+ desc: null
80
+ value: null
81
+ eval_steps:
82
+ desc: null
83
+ value: 100001
84
+ evaluation_strategy:
85
+ desc: null
86
+ value: IntervalStrategy.NO
87
+ fp16:
88
+ desc: null
89
+ value: false
90
+ fp16_backend:
91
+ desc: null
92
+ value: auto
93
+ fp16_full_eval:
94
+ desc: null
95
+ value: false
96
+ fp16_opt_level:
97
+ desc: null
98
+ value: O1
99
+ gradient_accumulation_steps:
100
+ desc: null
101
+ value: 2
102
+ greater_is_better:
103
+ desc: null
104
+ value: null
105
+ group_by_length:
106
+ desc: null
107
+ value: false
108
+ ignore_data_skip:
109
+ desc: null
110
+ value: false
111
+ label_names:
112
+ desc: null
113
+ value: null
114
+ label_smoothing_factor:
115
+ desc: null
116
+ value: 0.0
117
+ learning_rate:
118
+ desc: null
119
+ value: 5.0e-05
120
+ length_column_name:
121
+ desc: null
122
+ value: length
123
+ line_by_line:
124
+ desc: null
125
+ value: false
126
+ load_best_model_at_end:
127
+ desc: null
128
+ value: false
129
+ local_rank:
130
+ desc: null
131
+ value: -1
132
+ log_level:
133
+ desc: null
134
+ value: -1
135
+ log_level_replica:
136
+ desc: null
137
+ value: -1
138
+ log_on_each_node:
139
+ desc: null
140
+ value: true
141
+ logging_dir:
142
+ desc: null
143
+ value: ./runs/Jul13_10-47-16_t1v-n-f5c06ea1-w-0
144
+ logging_first_step:
145
+ desc: null
146
+ value: false
147
+ logging_steps:
148
+ desc: null
149
+ value: 50
150
+ logging_strategy:
151
+ desc: null
152
+ value: IntervalStrategy.STEPS
153
+ lr_scheduler_type:
154
+ desc: null
155
+ value: SchedulerType.LINEAR
156
+ max_grad_norm:
157
+ desc: null
158
+ value: 1.0
159
+ max_seq_length:
160
+ desc: null
161
+ value: 4096
162
+ max_steps:
163
+ desc: null
164
+ value: -1
165
+ metric_for_best_model:
166
+ desc: null
167
+ value: null
168
+ mlm_probability:
169
+ desc: null
170
+ value: 0.15
171
+ model_name_or_path:
172
+ desc: null
173
+ value: null
174
+ model_type:
175
+ desc: null
176
+ value: big_bird
177
+ mp_parameters:
178
+ desc: null
179
+ value: ''
180
+ no_cuda:
181
+ desc: null
182
+ value: false
183
+ num_train_epochs:
184
+ desc: null
185
+ value: 5.0
186
+ output_dir:
187
+ desc: null
188
+ value: ./
189
+ overwrite_cache:
190
+ desc: null
191
+ value: false
192
+ overwrite_output_dir:
193
+ desc: null
194
+ value: true
195
+ pad_to_max_length:
196
+ desc: null
197
+ value: false
198
+ past_index:
199
+ desc: null
200
+ value: -1
201
+ per_device_eval_batch_size:
202
+ desc: null
203
+ value: 2
204
+ per_device_train_batch_size:
205
+ desc: null
206
+ value: 2
207
+ per_gpu_eval_batch_size:
208
+ desc: null
209
+ value: null
210
+ per_gpu_train_batch_size:
211
+ desc: null
212
+ value: null
213
+ prediction_loss_only:
214
+ desc: null
215
+ value: false
216
+ preprocessing_num_workers:
217
+ desc: null
218
+ value: 64
219
+ push_to_hub:
220
+ desc: null
221
+ value: true
222
+ push_to_hub_model_id:
223
+ desc: null
224
+ value: ''
225
+ push_to_hub_organization:
226
+ desc: null
227
+ value: null
228
+ push_to_hub_token:
229
+ desc: null
230
+ value: null
231
+ remove_unused_columns:
232
+ desc: null
233
+ value: true
234
+ report_to:
235
+ desc: null
236
+ value:
237
+ - tensorboard
238
+ - wandb
239
+ resume_from_checkpoint:
240
+ desc: null
241
+ value: null
242
+ run_name:
243
+ desc: null
244
+ value: ./
245
+ save_on_each_node:
246
+ desc: null
247
+ value: false
248
+ save_steps:
249
+ desc: null
250
+ value: 20000
251
+ save_strategy:
252
+ desc: null
253
+ value: IntervalStrategy.STEPS
254
+ save_total_limit:
255
+ desc: null
256
+ value: 5
257
+ seed:
258
+ desc: null
259
+ value: 42
260
+ sharded_ddp:
261
+ desc: null
262
+ value: []
263
+ skip_memory_metrics:
264
+ desc: null
265
+ value: true
266
+ tokenizer_name:
267
+ desc: null
268
+ value: ./
269
+ tpu_metrics_debug:
270
+ desc: null
271
+ value: false
272
+ tpu_num_cores:
273
+ desc: null
274
+ value: null
275
+ train_file:
276
+ desc: null
277
+ value: null
278
+ train_ref_file:
279
+ desc: null
280
+ value: null
281
+ use_fast_tokenizer:
282
+ desc: null
283
+ value: true
284
+ use_legacy_prediction_loop:
285
+ desc: null
286
+ value: false
287
+ validation_file:
288
+ desc: null
289
+ value: null
290
+ validation_ref_file:
291
+ desc: null
292
+ value: null
293
+ validation_split_percentage:
294
+ desc: null
295
+ value: 5
296
+ warmup_ratio:
297
+ desc: null
298
+ value: 0.0
299
+ warmup_steps:
300
+ desc: null
301
+ value: 10
302
+ weight_decay:
303
+ desc: null
304
+ value: 0.0095
wandb/run-20210713_104745-1rl2j7or/files/output.log ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /home/dat/pino/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:3114: UserWarning: Explicitly requested dtype <class 'jax._src.numpy.lax_numpy.int64'> requested in zeros is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
2
+ lax._check_user_dtype_supported(dtype, "zeros")
3
+ /home/dat/pino/lib/python3.8/site-packages/jax/lib/xla_bridge.py:382: UserWarning: jax.host_count has been renamed to jax.process_count. This alias will eventually be removed; please update your code.
4
+ warnings.warn(
5
+ /home/dat/pino/lib/python3.8/site-packages/jax/lib/xla_bridge.py:369: UserWarning: jax.host_id has been renamed to jax.process_index. This alias will eventually be removed; please update your code.
6
+ warnings.warn(
7
+ Epoch ... (1/5): 0%| | 0/5 [00:00<?, ?it/s]
8
+
9
+
10
+
11
+
12
+
13
+
14
+
15
+ Training...: 60%|██████████████████ | 50/83 [01:32<00:23, 1.40it/s]
16
+
17
+
18
+
19
+ Epoch ... (1/5): 20%|█████▍ | 1/5 [02:00<08:02, 120.70s/it]
20
+
21
+ Training...: 16%|████▋ | 13/83 [00:07<00:53, 1.32it/s]
22
+
23
+
24
+
25
+
26
+
27
+
28
+
29
+ Training...: 78%|███████████████████████▍ | 65/83 [00:44<00:24, 1.38s/it]
30
+
31
+ Epoch ... (1/5): 40%|███████████▏ | 2/5 [03:06<04:25, 88.56s/it]
32
+
33
+ Training...: 22%|██████▉ | 18/83 [00:01<00:07, 9.26it/s]
34
+
35
+
36
+
37
+
38
+
39
+
40
+
41
+
42
+ Epoch ... (1/5): 60%|████████████████▊ | 3/5 [04:12<02:36, 78.08s/it]s]
43
+ Step... (150 | Loss: 7.8581647872924805, Learning Rate: 2.256410152767785e-05)
44
+
45
+ Training...: 33%|███████████ | 27/83 [00:03<00:06, 9.31it/s]
46
+
47
+
48
+
49
+
50
+
51
+
52
+
53
+ Training...: 93%|███████████████████████████████▌ | 77/83 [00:32<00:04, 1.41it/s]
54
+
55
+ Epoch ... (1/5): 80%|██████████████████████▍ | 4/5 [05:18<01:13, 73.25s/it]/it]
56
+
57
+
wandb/run-20210713_104745-1rl2j7or/files/requirements.txt ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==0.13.0
2
+ aiohttp==3.7.4.post0
3
+ astunparse==1.6.3
4
+ async-timeout==3.0.1
5
+ attrs==21.2.0
6
+ cachetools==4.2.2
7
+ certifi==2021.5.30
8
+ chardet==4.0.0
9
+ chex==0.0.8
10
+ click==8.0.1
11
+ configparser==5.0.2
12
+ cycler==0.10.0
13
+ datasets==1.9.1.dev0
14
+ dill==0.3.4
15
+ dm-tree==0.1.6
16
+ docker-pycreds==0.4.0
17
+ filelock==3.0.12
18
+ flatbuffers==1.12
19
+ flax==0.3.4
20
+ fsspec==2021.6.1
21
+ gast==0.4.0
22
+ gitdb==4.0.7
23
+ gitpython==3.1.18
24
+ google-auth-oauthlib==0.4.4
25
+ google-auth==1.32.1
26
+ google-pasta==0.2.0
27
+ grpcio==1.34.1
28
+ h5py==3.1.0
29
+ huggingface-hub==0.0.12
30
+ idna==2.10
31
+ jax==0.2.16
32
+ jaxlib==0.1.68
33
+ joblib==1.0.1
34
+ keras-nightly==2.5.0.dev2021032900
35
+ keras-preprocessing==1.1.2
36
+ kiwisolver==1.3.1
37
+ libtpu-nightly==0.1.dev20210615
38
+ markdown==3.3.4
39
+ matplotlib==3.4.2
40
+ msgpack==1.0.2
41
+ multidict==5.1.0
42
+ multiprocess==0.70.12.2
43
+ numpy==1.19.5
44
+ oauthlib==3.1.1
45
+ opt-einsum==3.3.0
46
+ optax==0.0.9
47
+ packaging==21.0
48
+ pandas==1.3.0
49
+ pathtools==0.1.2
50
+ pillow==8.3.1
51
+ pip==20.0.2
52
+ pkg-resources==0.0.0
53
+ promise==2.3
54
+ protobuf==3.17.3
55
+ psutil==5.8.0
56
+ pyarrow==4.0.1
57
+ pyasn1-modules==0.2.8
58
+ pyasn1==0.4.8
59
+ pyparsing==2.4.7
60
+ python-dateutil==2.8.1
61
+ pytz==2021.1
62
+ pyyaml==5.4.1
63
+ regex==2021.7.6
64
+ requests-oauthlib==1.3.0
65
+ requests==2.25.1
66
+ rsa==4.7.2
67
+ sacremoses==0.0.45
68
+ scipy==1.7.0
69
+ sentry-sdk==1.3.0
70
+ setuptools==44.0.0
71
+ shortuuid==1.0.1
72
+ six==1.15.0
73
+ smmap==4.0.0
74
+ subprocess32==3.5.4
75
+ tensorboard-data-server==0.6.1
76
+ tensorboard-plugin-wit==1.8.0
77
+ tensorboard==2.5.0
78
+ tensorflow-estimator==2.5.0
79
+ tensorflow==2.5.0
80
+ termcolor==1.1.0
81
+ tokenizers==0.10.3
82
+ toolz==0.11.1
83
+ tqdm==4.61.2
84
+ transformers==4.9.0.dev0
85
+ typing-extensions==3.7.4.3
86
+ urllib3==1.26.6
87
+ wandb==0.10.33
88
+ werkzeug==2.0.1
89
+ wheel==0.36.2
90
+ wrapt==1.12.1
91
+ xxhash==2.0.2
92
+ yarl==1.6.3
wandb/run-20210713_104745-1rl2j7or/files/wandb-metadata.json ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "os": "Linux-5.4.0-1043-gcp-x86_64-with-glibc2.29",
3
+ "python": "3.8.10",
4
+ "heartbeatAt": "2021-07-13T10:47:47.215746",
5
+ "startedAt": "2021-07-13T10:47:45.129053",
6
+ "docker": null,
7
+ "cpu_count": 96,
8
+ "cuda": null,
9
+ "args": [
10
+ "--push_to_hub",
11
+ "--output_dir=./",
12
+ "--model_type=big_bird",
13
+ "--config_name=./",
14
+ "--tokenizer_name=./",
15
+ "--max_seq_length=4096",
16
+ "--weight_decay=0.0095",
17
+ "--warmup_steps=10",
18
+ "--overwrite_output_dir",
19
+ "--adam_beta1=0.9",
20
+ "--adam_beta2=0.98",
21
+ "--logging_steps=50",
22
+ "--eval_steps=100001",
23
+ "--num_train_epochs=5",
24
+ "--preprocessing_num_workers=64",
25
+ "--save_steps=20000",
26
+ "--learning_rate=5e-5",
27
+ "--per_device_train_batch_size=2",
28
+ "--per_device_eval_batch_size=2",
29
+ "--save_total_limit=5",
30
+ "--gradient_accumulation_steps=2"
31
+ ],
32
+ "state": "running",
33
+ "program": "./run_mlm_flax.py",
34
+ "codePath": "run_mlm_flax.py",
35
+ "git": {
36
+ "remote": "https://huggingface.co/flax-community/pino-roberta-base",
37
+ "commit": "bc11ccfe77236f87575711b26034b9751449de4b"
38
+ },
39
+ "email": null,
40
+ "root": "/home/dat/pino-roberta-base",
41
+ "host": "t1v-n-f5c06ea1-w-0",
42
+ "username": "dat",
43
+ "executable": "/home/dat/pino/bin/python"
44
+ }
wandb/run-20210713_104745-1rl2j7or/files/wandb-summary.json ADDED
@@ -0,0 +1 @@
 
1
+ {"training_step": 200, "learning_rate": 1.0769229447760154e-05, "train_loss": 7.618040084838867, "_runtime": 333, "_timestamp": 1626173598, "_step": 6}
wandb/run-20210713_104745-1rl2j7or/logs/debug-internal.log ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2021-07-13 10:47:45,828 INFO MainThread:342403 [internal.py:wandb_internal():88] W&B internal server running at pid: 342403, started at: 2021-07-13 10:47:45.828158
2
+ 2021-07-13 10:47:45,830 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: check_version
3
+ 2021-07-13 10:47:45,830 INFO WriterThread:342403 [datastore.py:open_for_write():80] open: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/run-1rl2j7or.wandb
4
+ 2021-07-13 10:47:45,831 DEBUG SenderThread:342403 [sender.py:send():179] send: header
5
+ 2021-07-13 10:47:45,831 DEBUG SenderThread:342403 [sender.py:send_request():193] send_request: check_version
6
+ 2021-07-13 10:47:45,871 DEBUG SenderThread:342403 [sender.py:send():179] send: run
7
+ 2021-07-13 10:47:46,041 INFO SenderThread:342403 [dir_watcher.py:__init__():168] watching files in: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files
8
+ 2021-07-13 10:47:46,041 INFO SenderThread:342403 [sender.py:_start_run_threads():716] run started: 1rl2j7or with start time 1626173265
9
+ 2021-07-13 10:47:46,041 DEBUG SenderThread:342403 [sender.py:send():179] send: summary
10
+ 2021-07-13 10:47:46,041 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: run_start
11
+ 2021-07-13 10:47:46,042 INFO SenderThread:342403 [sender.py:_save_file():841] saving file wandb-summary.json with policy end
12
+ 2021-07-13 10:47:47,043 INFO Thread-8 :342403 [dir_watcher.py:_on_file_created():216] file/dir created: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/wandb-summary.json
13
+ 2021-07-13 10:47:47,215 DEBUG HandlerThread:342403 [meta.py:__init__():39] meta init
14
+ 2021-07-13 10:47:47,215 DEBUG HandlerThread:342403 [meta.py:__init__():53] meta init done
15
+ 2021-07-13 10:47:47,215 DEBUG HandlerThread:342403 [meta.py:probe():210] probe
16
+ 2021-07-13 10:47:47,217 DEBUG HandlerThread:342403 [meta.py:_setup_git():200] setup git
17
+ 2021-07-13 10:47:47,250 DEBUG HandlerThread:342403 [meta.py:_setup_git():207] setup git done
18
+ 2021-07-13 10:47:47,250 DEBUG HandlerThread:342403 [meta.py:_save_pip():57] save pip
19
+ 2021-07-13 10:47:47,251 DEBUG HandlerThread:342403 [meta.py:_save_pip():71] save pip done
20
+ 2021-07-13 10:47:47,251 DEBUG HandlerThread:342403 [meta.py:probe():252] probe done
21
+ 2021-07-13 10:47:47,255 DEBUG SenderThread:342403 [sender.py:send():179] send: files
22
+ 2021-07-13 10:47:47,255 INFO SenderThread:342403 [sender.py:_save_file():841] saving file wandb-metadata.json with policy now
23
+ 2021-07-13 10:47:47,262 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: stop_status
24
+ 2021-07-13 10:47:47,262 DEBUG SenderThread:342403 [sender.py:send_request():193] send_request: stop_status
25
+ 2021-07-13 10:47:47,394 DEBUG SenderThread:342403 [sender.py:send():179] send: config
26
+ 2021-07-13 10:47:47,394 DEBUG SenderThread:342403 [sender.py:send():179] send: config
27
+ 2021-07-13 10:47:47,394 DEBUG SenderThread:342403 [sender.py:send():179] send: config
28
+ 2021-07-13 10:47:47,719 INFO Thread-11 :342403 [upload_job.py:push():137] Uploaded file /tmp/tmpta17r5ywwandb/1f1555en-wandb-metadata.json
29
+ 2021-07-13 10:47:48,042 INFO Thread-8 :342403 [dir_watcher.py:_on_file_created():216] file/dir created: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/wandb-metadata.json
30
+ 2021-07-13 10:47:48,042 INFO Thread-8 :342403 [dir_watcher.py:_on_file_created():216] file/dir created: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/requirements.txt
31
+ 2021-07-13 10:47:48,042 INFO Thread-8 :342403 [dir_watcher.py:_on_file_created():216] file/dir created: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
32
+ 2021-07-13 10:48:02,047 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
33
+ 2021-07-13 10:48:02,398 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: stop_status
34
+ 2021-07-13 10:48:02,398 DEBUG SenderThread:342403 [sender.py:send_request():193] send_request: stop_status
35
+ 2021-07-13 10:48:04,048 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
36
+ 2021-07-13 10:48:15,296 DEBUG SenderThread:342403 [sender.py:send():179] send: stats
37
+ 2021-07-13 10:48:17,054 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/config.yaml
38
+ 2021-07-13 10:48:17,555 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: stop_status
39
+ 2021-07-13 10:48:17,556 DEBUG SenderThread:342403 [sender.py:send_request():193] send_request: stop_status
40
+ 2021-07-13 10:48:32,709 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: stop_status
41
+ 2021-07-13 10:48:32,710 DEBUG SenderThread:342403 [sender.py:send_request():193] send_request: stop_status
42
+ 2021-07-13 10:48:45,371 DEBUG SenderThread:342403 [sender.py:send():179] send: stats
43
+ 2021-07-13 10:48:47,840 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: stop_status
44
+ 2021-07-13 10:48:47,840 DEBUG SenderThread:342403 [sender.py:send_request():193] send_request: stop_status
45
+ 2021-07-13 10:49:02,980 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: stop_status
46
+ 2021-07-13 10:49:02,980 DEBUG SenderThread:342403 [sender.py:send_request():193] send_request: stop_status
47
+ 2021-07-13 10:49:15,445 DEBUG SenderThread:342403 [sender.py:send():179] send: stats
48
+ 2021-07-13 10:49:18,113 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: stop_status
49
+ 2021-07-13 10:49:18,113 DEBUG SenderThread:342403 [sender.py:send_request():193] send_request: stop_status
50
+ 2021-07-13 10:49:24,080 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
51
+ 2021-07-13 10:49:26,080 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
52
+ 2021-07-13 10:49:28,081 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
53
+ 2021-07-13 10:49:30,082 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
54
+ 2021-07-13 10:49:32,083 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
55
+ 2021-07-13 10:49:33,242 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: stop_status
56
+ 2021-07-13 10:49:33,243 DEBUG SenderThread:342403 [sender.py:send_request():193] send_request: stop_status
57
+ 2021-07-13 10:49:34,084 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
58
+ 2021-07-13 10:49:36,084 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
59
+ 2021-07-13 10:49:45,514 DEBUG SenderThread:342403 [sender.py:send():179] send: stats
60
+ 2021-07-13 10:49:48,375 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: stop_status
61
+ 2021-07-13 10:49:48,375 DEBUG SenderThread:342403 [sender.py:send_request():193] send_request: stop_status
62
+ 2021-07-13 10:49:58,179 DEBUG SenderThread:342403 [sender.py:send():179] send: history
63
+ 2021-07-13 10:49:58,180 DEBUG SenderThread:342403 [sender.py:send():179] send: summary
64
+ 2021-07-13 10:49:58,180 INFO SenderThread:342403 [sender.py:_save_file():841] saving file wandb-summary.json with policy end
65
+ 2021-07-13 10:49:59,093 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/wandb-summary.json
66
+ 2021-07-13 10:50:00,093 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
67
+ 2021-07-13 10:50:02,094 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
68
+ 2021-07-13 10:50:03,510 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: stop_status
69
+ 2021-07-13 10:50:03,510 DEBUG SenderThread:342403 [sender.py:send_request():193] send_request: stop_status
70
+ 2021-07-13 10:50:04,095 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
71
+ 2021-07-13 10:50:15,583 DEBUG SenderThread:342403 [sender.py:send():179] send: stats
72
+ 2021-07-13 10:50:18,643 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: stop_status
73
+ 2021-07-13 10:50:18,643 DEBUG SenderThread:342403 [sender.py:send_request():193] send_request: stop_status
74
+ 2021-07-13 10:50:24,102 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
75
+ 2021-07-13 10:50:28,758 DEBUG SenderThread:342403 [sender.py:send():179] send: history
76
+ 2021-07-13 10:50:28,759 DEBUG SenderThread:342403 [sender.py:send():179] send: summary
77
+ 2021-07-13 10:50:28,763 INFO SenderThread:342403 [sender.py:_save_file():841] saving file wandb-summary.json with policy end
78
+ 2021-07-13 10:50:29,104 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/wandb-summary.json
79
+ 2021-07-13 10:50:30,105 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
80
+ 2021-07-13 10:50:32,106 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
81
+ 2021-07-13 10:50:33,775 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: stop_status
82
+ 2021-07-13 10:50:33,776 DEBUG SenderThread:342403 [sender.py:send_request():193] send_request: stop_status
83
+ 2021-07-13 10:50:34,107 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
84
+ 2021-07-13 10:50:36,107 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
85
+ 2021-07-13 10:50:38,108 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
86
+ 2021-07-13 10:50:40,109 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
87
+ 2021-07-13 10:50:42,110 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
88
+ 2021-07-13 10:50:45,653 DEBUG SenderThread:342403 [sender.py:send():179] send: stats
89
+ 2021-07-13 10:50:48,905 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: stop_status
90
+ 2021-07-13 10:50:48,906 DEBUG SenderThread:342403 [sender.py:send_request():193] send_request: stop_status
91
+ 2021-07-13 10:51:04,035 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: stop_status
92
+ 2021-07-13 10:51:04,035 DEBUG SenderThread:342403 [sender.py:send_request():193] send_request: stop_status
93
+ 2021-07-13 10:51:04,964 DEBUG SenderThread:342403 [sender.py:send():179] send: history
94
+ 2021-07-13 10:51:04,964 DEBUG SenderThread:342403 [sender.py:send():179] send: summary
95
+ 2021-07-13 10:51:04,964 INFO SenderThread:342403 [sender.py:_save_file():841] saving file wandb-summary.json with policy end
96
+ 2021-07-13 10:51:05,119 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/wandb-summary.json
97
+ 2021-07-13 10:51:06,119 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
98
+ 2021-07-13 10:51:08,120 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
99
+ 2021-07-13 10:51:15,726 DEBUG SenderThread:342403 [sender.py:send():179] send: stats
100
+ 2021-07-13 10:51:19,168 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: stop_status
101
+ 2021-07-13 10:51:19,168 DEBUG SenderThread:342403 [sender.py:send_request():193] send_request: stop_status
102
+ 2021-07-13 10:51:24,126 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
103
+ 2021-07-13 10:51:26,127 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
104
+ 2021-07-13 10:51:34,303 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: stop_status
105
+ 2021-07-13 10:51:34,303 DEBUG SenderThread:342403 [sender.py:send_request():193] send_request: stop_status
106
+ 2021-07-13 10:51:35,557 DEBUG SenderThread:342403 [sender.py:send():179] send: history
107
+ 2021-07-13 10:51:35,558 DEBUG SenderThread:342403 [sender.py:send():179] send: summary
108
+ 2021-07-13 10:51:35,558 INFO SenderThread:342403 [sender.py:_save_file():841] saving file wandb-summary.json with policy end
109
+ 2021-07-13 10:51:36,131 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/wandb-summary.json
110
+ 2021-07-13 10:51:36,132 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
111
+ 2021-07-13 10:51:38,132 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
112
+ 2021-07-13 10:51:40,133 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
113
+ 2021-07-13 10:51:42,134 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
114
+ 2021-07-13 10:51:44,135 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
115
+ 2021-07-13 10:51:45,797 DEBUG SenderThread:342403 [sender.py:send():179] send: stats
116
+ 2021-07-13 10:51:46,136 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
117
+ 2021-07-13 10:51:48,137 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
118
+ 2021-07-13 10:51:49,438 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: stop_status
119
+ 2021-07-13 10:51:49,438 DEBUG SenderThread:342403 [sender.py:send_request():193] send_request: stop_status
120
+ 2021-07-13 10:51:50,137 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
121
+ 2021-07-13 10:52:04,579 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: stop_status
122
+ 2021-07-13 10:52:04,580 DEBUG SenderThread:342403 [sender.py:send_request():193] send_request: stop_status
123
+ 2021-07-13 10:52:11,761 DEBUG SenderThread:342403 [sender.py:send():179] send: history
124
+ 2021-07-13 10:52:11,762 DEBUG SenderThread:342403 [sender.py:send():179] send: summary
125
+ 2021-07-13 10:52:11,763 INFO SenderThread:342403 [sender.py:_save_file():841] saving file wandb-summary.json with policy end
126
+ 2021-07-13 10:52:12,146 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/wandb-summary.json
127
+ 2021-07-13 10:52:14,147 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
128
+ 2021-07-13 10:52:15,867 DEBUG SenderThread:342403 [sender.py:send():179] send: stats
129
+ 2021-07-13 10:52:19,709 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: stop_status
130
+ 2021-07-13 10:52:19,710 DEBUG SenderThread:342403 [sender.py:send_request():193] send_request: stop_status
131
+ 2021-07-13 10:52:24,150 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
132
+ 2021-07-13 10:52:26,151 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
133
+ 2021-07-13 10:52:34,838 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: stop_status
134
+ 2021-07-13 10:52:34,839 DEBUG SenderThread:342403 [sender.py:send_request():193] send_request: stop_status
135
+ 2021-07-13 10:52:42,378 DEBUG SenderThread:342403 [sender.py:send():179] send: history
136
+ 2021-07-13 10:52:42,378 DEBUG SenderThread:342403 [sender.py:send():179] send: summary
137
+ 2021-07-13 10:52:42,379 INFO SenderThread:342403 [sender.py:_save_file():841] saving file wandb-summary.json with policy end
138
+ 2021-07-13 10:52:43,158 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/wandb-summary.json
139
+ 2021-07-13 10:52:45,159 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
140
+ 2021-07-13 10:52:45,939 DEBUG SenderThread:342403 [sender.py:send():179] send: stats
141
+ 2021-07-13 10:52:47,160 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
142
+ 2021-07-13 10:52:49,161 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
143
+ 2021-07-13 10:52:49,969 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: stop_status
144
+ 2021-07-13 10:52:49,970 DEBUG SenderThread:342403 [sender.py:send_request():193] send_request: stop_status
145
+ 2021-07-13 10:52:51,161 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
146
+ 2021-07-13 10:52:53,162 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
147
+ 2021-07-13 10:52:55,163 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
148
+ 2021-07-13 10:52:57,164 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
149
+ 2021-07-13 10:53:05,101 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: stop_status
150
+ 2021-07-13 10:53:05,101 DEBUG SenderThread:342403 [sender.py:send_request():193] send_request: stop_status
151
+ 2021-07-13 10:53:16,014 DEBUG SenderThread:342403 [sender.py:send():179] send: stats
152
+ 2021-07-13 10:53:18,580 DEBUG SenderThread:342403 [sender.py:send():179] send: history
153
+ 2021-07-13 10:53:18,580 DEBUG SenderThread:342403 [sender.py:send():179] send: summary
154
+ 2021-07-13 10:53:18,580 INFO SenderThread:342403 [sender.py:_save_file():841] saving file wandb-summary.json with policy end
155
+ 2021-07-13 10:53:19,173 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/wandb-summary.json
156
+ 2021-07-13 10:53:20,233 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: stop_status
157
+ 2021-07-13 10:53:20,234 DEBUG SenderThread:342403 [sender.py:send_request():193] send_request: stop_status
158
+ 2021-07-13 10:53:21,173 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
159
+ 2021-07-13 10:53:25,175 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
160
+ 2021-07-13 10:53:27,176 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
161
+ 2021-07-13 10:53:29,177 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
162
+ 2021-07-13 10:53:34,237 WARNING MainThread:342403 [internal.py:wandb_internal():147] Internal process interrupt: 1
163
+ 2021-07-13 10:53:34,484 WARNING MainThread:342403 [internal.py:wandb_internal():147] Internal process interrupt: 2
164
+ 2021-07-13 10:53:34,484 ERROR MainThread:342403 [internal.py:wandb_internal():150] Internal process interrupted.
165
+ 2021-07-13 10:53:35,385 INFO WriterThread:342403 [datastore.py:close():288] close: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/run-1rl2j7or.wandb
166
+ 2021-07-13 10:53:35,409 INFO SenderThread:342403 [sender.py:finish():945] shutting down sender
167
+ 2021-07-13 10:53:35,409 INFO SenderThread:342403 [dir_watcher.py:finish():282] shutting down directory watcher
168
+ 2021-07-13 10:53:35,414 INFO HandlerThread:342403 [handler.py:finish():638] shutting down handler
169
+ 2021-07-13 10:53:36,180 INFO SenderThread:342403 [dir_watcher.py:finish():312] scan: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files
170
+ 2021-07-13 10:53:36,180 INFO SenderThread:342403 [dir_watcher.py:finish():318] scan save: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/requirements.txt requirements.txt
171
+ 2021-07-13 10:53:36,180 INFO SenderThread:342403 [dir_watcher.py:finish():318] scan save: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log output.log
172
+ 2021-07-13 10:53:36,180 INFO SenderThread:342403 [dir_watcher.py:finish():318] scan save: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/wandb-metadata.json wandb-metadata.json
173
+ 2021-07-13 10:53:36,180 INFO SenderThread:342403 [dir_watcher.py:finish():318] scan save: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/config.yaml config.yaml
174
+ 2021-07-13 10:53:36,181 INFO SenderThread:342403 [dir_watcher.py:finish():318] scan save: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/wandb-summary.json wandb-summary.json
175
+ 2021-07-13 10:53:36,181 INFO SenderThread:342403 [file_pusher.py:finish():177] shutting down file pusher
176
+ 2021-07-13 10:53:36,181 INFO SenderThread:342403 [file_pusher.py:join():182] waiting for file pusher
177
+ 2021-07-13 10:53:36,622 INFO Thread-14 :342403 [upload_job.py:push():137] Uploaded file /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/config.yaml
178
+ 2021-07-13 10:53:36,624 INFO Thread-15 :342403 [upload_job.py:push():137] Uploaded file /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/wandb-summary.json
179
+ 2021-07-13 10:53:36,634 INFO Thread-13 :342403 [upload_job.py:push():137] Uploaded file /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
180
+ 2021-07-13 10:53:36,654 INFO Thread-12 :342403 [upload_job.py:push():137] Uploaded file /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/requirements.txt
181
+ 2021-07-13 10:53:37,518 INFO MainThread:342403 [internal.py:handle_exit():78] Internal process exited
wandb/run-20210713_104745-1rl2j7or/logs/debug.log ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2021-07-13 10:47:45,130 INFO MainThread:340852 [wandb_setup.py:_flush():69] setting env: {}
2
+ 2021-07-13 10:47:45,130 INFO MainThread:340852 [wandb_setup.py:_flush():69] setting login settings: {}
3
+ 2021-07-13 10:47:45,130 INFO MainThread:340852 [wandb_init.py:_log_setup():337] Logging user logs to /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/logs/debug.log
4
+ 2021-07-13 10:47:45,130 INFO MainThread:340852 [wandb_init.py:_log_setup():338] Logging internal logs to /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/logs/debug-internal.log
5
+ 2021-07-13 10:47:45,131 INFO MainThread:340852 [wandb_init.py:init():370] calling init triggers
6
+ 2021-07-13 10:47:45,131 INFO MainThread:340852 [wandb_init.py:init():375] wandb.init called with sweep_config: {}
7
+ config: {}
8
+ 2021-07-13 10:47:45,131 INFO MainThread:340852 [wandb_init.py:init():419] starting backend
9
+ 2021-07-13 10:47:45,131 INFO MainThread:340852 [backend.py:_multiprocessing_setup():70] multiprocessing start_methods=fork,spawn,forkserver, using: spawn
10
+ 2021-07-13 10:47:45,179 INFO MainThread:340852 [backend.py:ensure_launched():135] starting backend process...
11
+ 2021-07-13 10:47:45,225 INFO MainThread:340852 [backend.py:ensure_launched():139] started backend process with pid: 342403
12
+ 2021-07-13 10:47:45,228 INFO MainThread:340852 [wandb_init.py:init():424] backend started and connected
13
+ 2021-07-13 10:47:45,231 INFO MainThread:340852 [wandb_init.py:init():472] updated telemetry
14
+ 2021-07-13 10:47:45,231 INFO MainThread:340852 [wandb_init.py:init():491] communicating current version
15
+ 2021-07-13 10:47:45,870 INFO MainThread:340852 [wandb_init.py:init():496] got version response
16
+ 2021-07-13 10:47:45,870 INFO MainThread:340852 [wandb_init.py:init():504] communicating run to backend with 30 second timeout
17
+ 2021-07-13 10:47:46,040 INFO MainThread:340852 [wandb_init.py:init():529] starting run threads in backend
18
+ 2021-07-13 10:47:47,259 INFO MainThread:340852 [wandb_run.py:_console_start():1623] atexit reg
19
+ 2021-07-13 10:47:47,260 INFO MainThread:340852 [wandb_run.py:_redirect():1497] redirect: SettingsConsole.REDIRECT
20
+ 2021-07-13 10:47:47,261 INFO MainThread:340852 [wandb_run.py:_redirect():1502] Redirecting console.
21
+ 2021-07-13 10:47:47,262 INFO MainThread:340852 [wandb_run.py:_redirect():1558] Redirects installed.
22
+ 2021-07-13 10:47:47,262 INFO MainThread:340852 [wandb_init.py:init():554] run started, returning control to user process
23
+ 2021-07-13 10:47:47,268 INFO MainThread:340852 [wandb_run.py:_config_callback():872] config_cb None None {'output_dir': './', 'overwrite_output_dir': True, 'do_train': False, 'do_eval': False, 'do_predict': False, 'evaluation_strategy': 'IntervalStrategy.NO', 'prediction_loss_only': False, 'per_device_train_batch_size': 2, 'per_device_eval_batch_size': 2, 'per_gpu_train_batch_size': None, 'per_gpu_eval_batch_size': None, 'gradient_accumulation_steps': 2, 'eval_accumulation_steps': None, 'learning_rate': 5e-05, 'weight_decay': 0.0095, 'adam_beta1': 0.9, 'adam_beta2': 0.98, 'adam_epsilon': 1e-08, 'max_grad_norm': 1.0, 'num_train_epochs': 5.0, 'max_steps': -1, 'lr_scheduler_type': 'SchedulerType.LINEAR', 'warmup_ratio': 0.0, 'warmup_steps': 10, 'log_level': -1, 'log_level_replica': -1, 'log_on_each_node': True, 'logging_dir': './runs/Jul13_10-47-16_t1v-n-f5c06ea1-w-0', 'logging_strategy': 'IntervalStrategy.STEPS', 'logging_first_step': False, 'logging_steps': 50, 'save_strategy': 'IntervalStrategy.STEPS', 'save_steps': 20000, 'save_total_limit': 5, 'save_on_each_node': False, 'no_cuda': False, 'seed': 42, 'fp16': False, 'fp16_opt_level': 'O1', 'fp16_backend': 'auto', 'fp16_full_eval': False, 'local_rank': -1, 'tpu_num_cores': None, 'tpu_metrics_debug': False, 'debug': [], 'dataloader_drop_last': False, 'eval_steps': 100001, 'dataloader_num_workers': 0, 'past_index': -1, 'run_name': './', 'disable_tqdm': False, 'remove_unused_columns': True, 'label_names': None, 'load_best_model_at_end': False, 'metric_for_best_model': None, 'greater_is_better': None, 'ignore_data_skip': False, 'sharded_ddp': [], 'deepspeed': None, 'label_smoothing_factor': 0.0, 'adafactor': False, 'group_by_length': False, 'length_column_name': 'length', 'report_to': ['tensorboard', 'wandb'], 'ddp_find_unused_parameters': None, 'dataloader_pin_memory': True, 'skip_memory_metrics': True, 'use_legacy_prediction_loop': False, 'push_to_hub': True, 'resume_from_checkpoint': None, 'push_to_hub_model_id': '', 'push_to_hub_organization': None, 'push_to_hub_token': None, 'mp_parameters': ''}
24
+ 2021-07-13 10:47:47,270 INFO MainThread:340852 [wandb_run.py:_config_callback():872] config_cb None None {'model_name_or_path': None, 'model_type': 'big_bird', 'config_name': './', 'tokenizer_name': './', 'cache_dir': None, 'use_fast_tokenizer': True, 'dtype': 'float32'}
25
+ 2021-07-13 10:47:47,271 INFO MainThread:340852 [wandb_run.py:_config_callback():872] config_cb None None {'dataset_name': None, 'dataset_config_name': None, 'train_file': None, 'validation_file': None, 'train_ref_file': None, 'validation_ref_file': None, 'overwrite_cache': False, 'validation_split_percentage': 5, 'max_seq_length': 4096, 'preprocessing_num_workers': 64, 'mlm_probability': 0.15, 'pad_to_max_length': False, 'line_by_line': False}
26
+ 2021-07-13 10:53:34,760 INFO MainThread:340852 [wandb_run.py:_atexit_cleanup():1593] got exitcode: 255
27
+ 2021-07-13 10:53:34,761 INFO MainThread:340852 [wandb_run.py:_restore():1565] restore
wandb/run-20210713_104745-1rl2j7or/run-1rl2j7or.wandb ADDED
Binary file (14.8 kB). View file
wandb/run-20210713_110212-594z6oo0/files/config.yaml ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ wandb_version: 1
2
+
3
+ _wandb:
4
+ desc: null
5
+ value:
6
+ cli_version: 0.10.33
7
+ framework: huggingface
8
+ huggingface_version: 4.9.0.dev0
9
+ is_jupyter_run: false
10
+ is_kaggle_kernel: false
11
+ python_version: 3.8.10
12
+ t:
13
+ 1:
14
+ - 3
15
+ - 11
16
+ 2:
17
+ - 3
18
+ - 11
19
+ 4: 3.8.10
20
+ 5: 0.10.33
21
+ 6: 4.9.0.dev0
22
+ 8:
23
+ - 5
24
+ adafactor:
25
+ desc: null
26
+ value: false
27
+ adam_beta1:
28
+ desc: null
29
+ value: 0.9
30
+ adam_beta2:
31
+ desc: null
32
+ value: 0.98
33
+ adam_epsilon:
34
+ desc: null
35
+ value: 1.0e-08
36
+ cache_dir:
37
+ desc: null
38
+ value: null
39
+ config_name:
40
+ desc: null
41
+ value: ./
42
+ dataloader_drop_last:
43
+ desc: null
44
+ value: false
45
+ dataloader_num_workers:
46
+ desc: null
47
+ value: 0
48
+ dataloader_pin_memory:
49
+ desc: null
50
+ value: true
51
+ dataset_config_name:
52
+ desc: null
53
+ value: null
54
+ dataset_name:
55
+ desc: null
56
+ value: null
57
+ ddp_find_unused_parameters:
58
+ desc: null
59
+ value: null
60
+ debug:
61
+ desc: null
62
+ value: []
63
+ deepspeed:
64
+ desc: null
65
+ value: null
66
+ disable_tqdm:
67
+ desc: null
68
+ value: false
69
+ do_eval:
70
+ desc: null
71
+ value: false
72
+ do_predict:
73
+ desc: null
74
+ value: false
75
+ do_train:
76
+ desc: null
77
+ value: false
78
+ dtype:
79
+ desc: null
80
+ value: float32
81
+ eval_accumulation_steps:
82
+ desc: null
83
+ value: null
84
+ eval_steps:
85
+ desc: null
86
+ value: 100001
87
+ evaluation_strategy:
88
+ desc: null
89
+ value: IntervalStrategy.NO
90
+ fp16:
91
+ desc: null
92
+ value: false
93
+ fp16_backend:
94
+ desc: null
95
+ value: auto
96
+ fp16_full_eval:
97
+ desc: null
98
+ value: false
99
+ fp16_opt_level:
100
+ desc: null
101
+ value: O1
102
+ gradient_accumulation_steps:
103
+ desc: null
104
+ value: 2
105
+ greater_is_better:
106
+ desc: null
107
+ value: null
108
+ group_by_length:
109
+ desc: null
110
+ value: false
111
+ ignore_data_skip:
112
+ desc: null
113
+ value: false
114
+ label_names:
115
+ desc: null
116
+ value: null
117
+ label_smoothing_factor:
118
+ desc: null
119
+ value: 0.0
120
+ learning_rate:
121
+ desc: null
122
+ value: 5.0e-05
123
+ length_column_name:
124
+ desc: null
125
+ value: length
126
+ line_by_line:
127
+ desc: null
128
+ value: false
129
+ load_best_model_at_end:
130
+ desc: null
131
+ value: false
132
+ local_rank:
133
+ desc: null
134
+ value: -1
135
+ log_level:
136
+ desc: null
137
+ value: -1
138
+ log_level_replica:
139
+ desc: null
140
+ value: -1
141
+ log_on_each_node:
142
+ desc: null
143
+ value: true
144
+ logging_dir:
145
+ desc: null
146
+ value: ./runs/Jul13_11-01-24_t1v-n-f5c06ea1-w-0
147
+ logging_first_step:
148
+ desc: null
149
+ value: false
150
+ logging_steps:
151
+ desc: null
152
+ value: 500
153
+ logging_strategy:
154
+ desc: null
155
+ value: IntervalStrategy.STEPS
156
+ lr_scheduler_type:
157
+ desc: null
158
+ value: SchedulerType.LINEAR
159
+ max_grad_norm:
160
+ desc: null
161
+ value: 1.0
162
+ max_seq_length:
163
+ desc: null
164
+ value: 4096
165
+ max_steps:
166
+ desc: null
167
+ value: -1
168
+ metric_for_best_model:
169
+ desc: null
170
+ value: null
171
+ mlm_probability:
172
+ desc: null
173
+ value: 0.15
174
+ model_name_or_path:
175
+ desc: null
176
+ value: null
177
+ model_type:
178
+ desc: null
179
+ value: big_bird
180
+ mp_parameters:
181
+ desc: null
182
+ value: ''
183
+ no_cuda:
184
+ desc: null
185
+ value: false
186
+ num_train_epochs:
187
+ desc: null
188
+ value: 5.0
189
+ output_dir:
190
+ desc: null
191
+ value: ./
192
+ overwrite_cache:
193
+ desc: null
194
+ value: false
195
+ overwrite_output_dir:
196
+ desc: null
197
+ value: true
198
+ pad_to_max_length:
199
+ desc: null
200
+ value: false
201
+ past_index:
202
+ desc: null
203
+ value: -1
204
+ per_device_eval_batch_size:
205
+ desc: null
206
+ value: 2
207
+ per_device_train_batch_size:
208
+ desc: null
209
+ value: 2
210
+ per_gpu_eval_batch_size:
211
+ desc: null
212
+ value: null
213
+ per_gpu_train_batch_size:
214
+ desc: null
215
+ value: null
216
+ prediction_loss_only:
217
+ desc: null
218
+ value: false
219
+ preprocessing_num_workers:
220
+ desc: null
221
+ value: 64
222
+ push_to_hub:
223
+ desc: null
224
+ value: true
225
+ push_to_hub_model_id:
226
+ desc: null
227
+ value: ''
228
+ push_to_hub_organization:
229
+ desc: null
230
+ value: null
231
+ push_to_hub_token:
232
+ desc: null
233
+ value: null
234
+ remove_unused_columns:
235
+ desc: null
236
+ value: true
237
+ report_to:
238
+ desc: null
239
+ value:
240
+ - tensorboard
241
+ - wandb
242
+ resume_from_checkpoint:
243
+ desc: null
244
+ value: null
245
+ run_name:
246
+ desc: null
247
+ value: ./
248
+ save_on_each_node:
249
+ desc: null
250
+ value: false
251
+ save_steps:
252
+ desc: null
253
+ value: 20000
254
+ save_strategy:
255
+ desc: null
256
+ value: IntervalStrategy.STEPS
257
+ save_total_limit:
258
+ desc: null
259
+ value: 5
260
+ seed:
261
+ desc: null
262
+ value: 42
263
+ sharded_ddp:
264
+ desc: null
265
+ value: []
266
+ skip_memory_metrics:
267
+ desc: null
268
+ value: true
269
+ tokenizer_name:
270
+ desc: null
271
+ value: ./
272
+ tpu_metrics_debug:
273
+ desc: null
274
+ value: false
275
+ tpu_num_cores:
276
+ desc: null
277
+ value: null
278
+ train_file:
279
+ desc: null
280
+ value: null
281
+ train_ref_file:
282
+ desc: null
283
+ value: null
284
+ use_fast_tokenizer:
285
+ desc: null
286
+ value: true
287
+ use_legacy_prediction_loop:
288
+ desc: null
289
+ value: false
290
+ validation_file:
291
+ desc: null
292
+ value: null
293
+ validation_ref_file:
294
+ desc: null
295
+ value: null
296
+ validation_split_percentage:
297
+ desc: null
298
+ value: 5
299
+ warmup_ratio:
300
+ desc: null
301
+ value: 0.0
302
+ warmup_steps:
303
+ desc: null
304
+ value: 10
305
+ weight_decay:
306
+ desc: null
307
+ value: 0.0095
wandb/run-20210713_110212-594z6oo0/files/output.log ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /home/dat/pino/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:3114: UserWarning: Explicitly requested dtype <class 'jax._src.numpy.lax_numpy.int64'> requested in zeros is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
2
+ lax._check_user_dtype_supported(dtype, "zeros")
3
+ /home/dat/pino/lib/python3.8/site-packages/jax/lib/xla_bridge.py:382: UserWarning: jax.host_count has been renamed to jax.process_count. This alias will eventually be removed; please update your code.
4
+ warnings.warn(
5
+ /home/dat/pino/lib/python3.8/site-packages/jax/lib/xla_bridge.py:369: UserWarning: jax.host_id has been renamed to jax.process_index. This alias will eventually be removed; please update your code.
6
+ warnings.warn(
7
+ Epoch ... (1/5): 0%| | 0/5 [00:00<?, ?it/s]
8
+ Training...: 0%| | 0/92767 [01:25<?, ?it/s]
9
+ Epoch ... (1/5): 0%| | 0/5 [02:57<?, ?it/s]
10
+ Traceback (most recent call last):
11
+ File "./run_mlm_flax.py", line 712, in <module>
12
+ state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
13
+ File "/home/dat/pino/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
14
+ return fun(*args, **kwargs)
15
+ File "/home/dat/pino/lib/python3.8/site-packages/jax/_src/api.py", line 1647, in f_pmapped
16
+ out = pxla.xla_pmap(
17
+ File "/home/dat/pino/lib/python3.8/site-packages/jax/core.py", line 1620, in bind
18
+ return call_bind(self, fun, *args, **params)
19
+ File "/home/dat/pino/lib/python3.8/site-packages/jax/core.py", line 1551, in call_bind
20
+ outs = primitive.process(top_trace, fun, tracers, params)
21
+ File "/home/dat/pino/lib/python3.8/site-packages/jax/core.py", line 1623, in process
22
+ return trace.process_map(self, fun, tracers, params)
23
+ File "/home/dat/pino/lib/python3.8/site-packages/jax/core.py", line 606, in process_call
24
+ return primitive.impl(f, *tracers, **params)
25
+ File "/home/dat/pino/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 637, in xla_pmap_impl
26
+ return compiled_fun(*args)
27
+ File "/home/dat/pino/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 1152, in execute_replicated
28
+ out_bufs = compiled.execute_sharded_on_local_devices(input_bufs)
29
+ jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: Resource exhausted: Attempting to reserve 12.60G at the bottom of memory. That was not possible. There are 12.15G free, 0B reserved, and 12.13G reservable.: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
30
+ The stack trace below excludes JAX-internal frames.
31
+ The preceding is the original exception that occurred, unmodified.
32
+ --------------------
33
+ The above exception was the direct cause of the following exception:
34
+ Traceback (most recent call last):
35
+ File "./run_mlm_flax.py", line 712, in <module>
36
+ state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
37
+ File "/home/dat/pino/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 1152, in execute_replicated
38
+ out_bufs = compiled.execute_sharded_on_local_devices(input_bufs)
39
+ RuntimeError: Resource exhausted: Attempting to reserve 12.60G at the bottom of memory. That was not possible. There are 12.15G free, 0B reserved, and 12.13G reservable.: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).