boris committed
Commit d811136
2 Parent(s): e31a84f 62dad48

Merge pull request #25 from borisdayma/fix-config

seq2seq/do_big_run.sh CHANGED
@@ -1,16 +1,16 @@
 python run_seq2seq_flax.py \
     --max_source_length 128 \
-    --train_file /data/CC12M/encoded-small-train.tsv \ # ignored for now in our script
-    --validation_file /data/CC12M/encoded-small-valid.tsv \ # ignored for now in our script
+    --train_file /data/CC12M/encoded-small-train.tsv \
+    --validation_file /data/CC12M/encoded-small-valid.tsv \
     --output_dir output \
     --per_device_train_batch_size 56 \
     --per_device_eval_batch_size 56 \
     --preprocessing_num_workers 80 \
-    --warmup_steps 125 \
+    --warmup_steps 250 \
     --gradient_accumulation_steps 8 \
     --do_train \
     --do_eval \
     --adafactor \
-    --num_train_epochs 10 \
+    --num_train_epochs 6 \
     --log_model \
-    --learning_rate 0.001
+    --learning_rate 0.005
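For context, the effective batch size implied by these flags can be worked out directly. The sketch below assumes 8 TPU devices, matching the `assert jax.device_count() == 8` in run_seq2seq_flax.py; the device count is otherwise an assumption.

# Back-of-the-envelope arithmetic for do_big_run.sh (sketch, not part of the scripts).
per_device_train_batch_size = 56
gradient_accumulation_steps = 8
num_devices = 8  # assumption: a single TPU v3-8 host

batch_per_step = per_device_train_batch_size * num_devices        # 448 examples per forward/backward step
effective_batch = batch_per_step * gradient_accumulation_steps    # 3584 examples per parameter update
print(batch_per_step, effective_batch)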
seq2seq/do_small_run.sh CHANGED
@@ -1,7 +1,7 @@
 python run_seq2seq_flax.py \
     --max_source_length 128 \
-    --train_file /data/CC12M/encoded-small-train.tsv \ # ignored for now in our script
-    --validation_file /data/CC12M/encoded-small-valid.tsv \ # ignored for now in our script
+    --train_file /data/CC12M/encoded-small-train.tsv \
+    --validation_file /data/CC12M/encoded-small-valid.tsv \
     --output_dir output \
     --per_device_train_batch_size 56 \
     --per_device_eval_batch_size 56 \
@@ -13,4 +13,4 @@ python run_seq2seq_flax.py \
     --adafactor \
     --num_train_epochs 1 \
     --max_train_samples 20000 \
-    --learning_rate 0.003
+    --learning_rate 0.005
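The --adafactor, --warmup_steps and --learning_rate flags above only describe the optimizer at the CLI level. As a rough illustration (not necessarily how run_seq2seq_flax.py wires them internally), a linear-warmup-then-constant Adafactor setup could be built in optax like this; the schedule shape is an assumption:

import optax

# Illustrative mapping of the shell flags to an optax optimizer (sketch only).
learning_rate = 0.005   # --learning_rate in do_big_run.sh / do_small_run.sh
warmup_steps = 250      # --warmup_steps in do_big_run.sh

schedule = optax.join_schedules(
    schedules=[
        optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=warmup_steps),
        optax.constant_schedule(learning_rate),
    ],
    boundaries=[warmup_steps],
)
optimizer = optax.adafactor(learning_rate=schedule)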
seq2seq/run_seq2seq_flax.py CHANGED
@@ -31,6 +31,7 @@ from dataclasses import dataclass, field
 from functools import partial
 from pathlib import Path
 from typing import Callable, Optional
+import json
 
 import datasets
 import nltk  # Here to have a nice missing dependency error message early on
@@ -44,6 +45,7 @@ import optax
 import transformers
 from filelock import FileLock
 from flax import jax_utils, traverse_util
+from flax.serialization import from_bytes, to_bytes
 import flax.linen as nn
 from flax.jax_utils import unreplicate
 from flax.training import train_state
@@ -282,8 +284,6 @@ class CustomFlaxBartModule(FlaxBartModule):
         # the decoder has a different config
         decoder_config = BartConfig(self.config.to_dict())
         decoder_config.max_position_embeddings = OUTPUT_LENGTH
-        decoder_config.min_length = OUTPUT_LENGTH
-        decoder_config.max_length = OUTPUT_LENGTH
         decoder_config.vocab_size = OUTPUT_VOCAB_SIZE
         self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)
 
@@ -363,7 +363,7 @@ def main():
     else:
         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
 
-    logger.warning(f"eval_steps has been manually hardcoded")  # TODO: remove it later, convenient for now
+    logger.warning(f"WARNING: eval_steps has been manually hardcoded")  # TODO: remove it later, convenient for now
     training_args.eval_steps = 400
 
     if (
@@ -412,7 +412,7 @@ def main():
     # (the dataset will be downloaded automatically from the datasets Hub).
     #
     data_files = {}
-    logger.warning(f"Datasets path have been manually hardcoded")  # TODO: remove it later, convenient for now
+    logger.warning(f"WARNING: Datasets path have been manually hardcoded")  # TODO: remove it later, convenient for now
     if data_args.train_file is not None:
         data_files["train"] = ["/data/CC3M/training-encoded.tsv", "/data/CC12M/encoded-train.tsv"]
     if data_args.validation_file is not None:
@@ -434,14 +434,15 @@ def main():
     # Set up our new model config
     config = BartConfig.from_pretrained(model_args.model_name_or_path)
     config.tie_word_embeddings = False
-    config.decoder_start_token_id = BOS_TOKEN_ID
-    config.bos_token_id = BOS_TOKEN_ID  # should not be used
+    config.decoder_start_token_id = BOS_TOKEN_ID  # for first token
+    config.bos_token_id = BOS_TOKEN_ID  # should not be used (due to forced_bos_token_id)
     config.pos_token_id = BOS_TOKEN_ID  # should not be needed (as we generate until max_length)
     config.eos_token_id = BOS_TOKEN_ID + 1  # unreachable
     config.forced_bos_token_id = None  # we don't need this token
     config.forced_eos_token_id = None  # we don't need this token
-    #config.min_length = data_args.max_target_length  # Set only in decoder?
-    #config.max_length = data_args.max_target_length  # Set only in decoder?
+    config.force_bos_token_to_be_generated = False  # otherwise it sets bos_token_id at loading
+    config.min_length = data_args.max_target_length
+    config.max_length = data_args.max_target_length
 
     print(f"TPUs: {jax.device_count()}")
     assert jax.device_count() == 8, "TPUs in use, please check running processes"
@@ -779,7 +780,7 @@ def main():
 
         return eval_metrics
 
-    def run_save_model(step, epoch, eval_metrics=None):
+    def run_save_model(state, step, epoch, eval_metrics=None):
         if jax.process_index() == 0:
             params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
 
@@ -789,6 +790,13 @@ def main():
                 params=params,
             )
 
+            # save state
+            state = unreplicate(state)
+            with (Path(training_args.output_dir) / 'opt_state.msgpack').open('wb') as f:
+                f.write(to_bytes(state.opt_state))
+            with (Path(training_args.output_dir) / 'training_state.json').open('w') as f:
+                json.dump({'step': state.step.item()}, f)
+
             # save to W&B
             if data_args.log_model:
                 metadata = {'step': step, 'epoch': epoch}
@@ -799,6 +807,8 @@ def main():
                 )
                 artifact.add_file(str(Path(training_args.output_dir) / 'flax_model.msgpack'))
                 artifact.add_file(str(Path(training_args.output_dir) / 'config.json'))
+                artifact.add_file(str(Path(training_args.output_dir) / 'opt_state.msgpack'))
+                artifact.add_file(str(Path(training_args.output_dir) / 'training_state.json'))
                 wandb.run.log_artifact(artifact)
 
                 # save to the hub
@@ -835,7 +845,7 @@ def main():
                 run_evaluation()
 
             if global_step % data_args.save_model_steps == 0:
-                run_save_model(global_step, epoch)
+                run_save_model(state, global_step, epoch)
 
         # log final train metrics
         wandb_log(unreplicate(train_metric), step=global_step, prefix='train')
@@ -850,7 +860,7 @@ def main():
         eval_metrics = run_evaluation()
 
         # save checkpoint after each epoch and push checkpoint to the hub
-        run_save_model(global_step, epoch, eval_metrics)
+        run_save_model(state, global_step, epoch, eval_metrics)
 
 
     # ======================== Prediction loop ==============================
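The new opt_state.msgpack and training_state.json files carry enough information to resume training. A minimal restore sketch follows; restore_checkpoint is a hypothetical helper, and `state` is assumed to be a freshly built flax TrainState created with the same optimizer definition as the saved run (from_bytes needs a matching template to restore into):

import json
from pathlib import Path

from flax.serialization import from_bytes


def restore_checkpoint(state, output_dir="output"):
    """Sketch: reload the optimizer state and step saved by run_save_model.

    Assumes `state` has the same structure (params/optimizer) as the run
    that wrote opt_state.msgpack and training_state.json.
    """
    output_dir = Path(output_dir)
    with (output_dir / "opt_state.msgpack").open("rb") as f:
        opt_state = from_bytes(state.opt_state, f.read())
    with (output_dir / "training_state.json").open("r") as f:
        step = json.load(f)["step"]
    return state.replace(opt_state=opt_state, step=step)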