boris commited on
Commit
eac6890
1 Parent(s): 85748ef

feat: use_auth_token + seed for dataset and model

Browse files
Files changed (1) hide show
  1. dev/seq2seq/run_seq2seq_flax.py +35 -12
dev/seq2seq/run_seq2seq_flax.py CHANGED
@@ -129,6 +129,12 @@ class DataTrainingArguments:
129
  default=False,
130
  metadata={"help": "Whether to stream the dataset."},
131
  )
 
 
 
 
 
 
132
  max_source_length: Optional[int] = field(
133
  default=128,
134
  metadata={
@@ -256,9 +262,18 @@ class TrainingArguments:
256
  metadata={"help": "Log model to wandb at `save_steps` frequency."},
257
  )
258
 
259
- seed: int = field(
260
  default=42,
261
- metadata={"help": "Random seed that will be set at the beginning of training."},
 
 
 
 
 
 
 
 
 
262
  )
263
 
264
  push_to_hub: bool = field(
@@ -304,7 +319,9 @@ class TrainState(train_state.TrainState):
304
 
305
 
306
  def data_loader(
307
- rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False
 
 
308
  ):
309
  """
310
  Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
@@ -312,7 +329,7 @@ def data_loader(
312
  """
313
  steps_per_epoch = len(dataset) // batch_size
314
 
315
- if shuffle:
316
  batch_idx = jax.random.permutation(rng, len(dataset))
317
  else:
318
  batch_idx = jnp.arange(len(dataset))
@@ -432,6 +449,7 @@ def main():
432
  data_args.dataset_repo_or_path,
433
  data_files=data_files,
434
  streaming=data_args.streaming,
 
435
  )
436
 
437
  # Set up wandb run
@@ -483,7 +501,7 @@ def main():
483
 
484
  # Create a custom model and initialize it randomly
485
  model = CustomFlaxBartForConditionalGeneration(
486
- config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
487
  )
488
 
489
  # Load tokenizer
@@ -561,7 +579,14 @@ def main():
561
  else train_dataset.select(range(data_args.max_train_samples))
562
  )
563
  if data_args.streaming:
564
- train_dataset = train_dataset.shuffle(1000, training_args.seed)
 
 
 
 
 
 
 
565
  if model.config.normalize_text:
566
  train_dataset = (
567
  train_dataset.map(normalize_text)
@@ -627,7 +652,7 @@ def main():
627
  )
628
 
629
  # Initialize our training
630
- rng = jax.random.PRNGKey(training_args.seed)
631
  rng, dropout_rng = jax.random.split(rng)
632
 
633
  # Store some constant
@@ -808,7 +833,7 @@ def main():
808
  if data_args.streaming:
809
  eval_loader = data_loader_streaming(eval_dataset, eval_batch_size)
810
  else:
811
- eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
812
  eval_steps = (
813
  len_eval_dataset // eval_batch_size
814
  if len_eval_dataset is not None
@@ -927,10 +952,8 @@ def main():
927
  train_dataset.set_epoch(epoch) # shuffle dataset
928
  train_loader = data_loader_streaming(train_dataset, train_batch_size)
929
  else:
930
- rng, input_rng = jax.random.split(rng)
931
- train_loader = data_loader(
932
- input_rng, train_dataset, train_batch_size, shuffle=True
933
- )
934
  # train
935
  for batch in tqdm(
936
  train_loader,
 
129
  default=False,
130
  metadata={"help": "Whether to stream the dataset."},
131
  )
132
+ use_auth_token: bool = field(
133
+ default=False,
134
+ metadata={
135
+ "help": "Whether to use the authentication token for private datasets."
136
+ },
137
+ )
138
  max_source_length: Optional[int] = field(
139
  default=128,
140
  metadata={
 
262
  metadata={"help": "Log model to wandb at `save_steps` frequency."},
263
  )
264
 
265
+ seed_model: int = field(
266
  default=42,
267
+ metadata={
268
+ "help": "Random seed for the model that will be set at the beginning of training."
269
+ },
270
+ )
271
+ # default seed of None ensures we don't repeat the same items if script was interrupted during an epoch
272
+ seed_dataset: int = field(
273
+ default=None,
274
+ metadata={
275
+ "help": "Random seed for the dataset that will be set at the beginning of training."
276
+ },
277
  )
278
 
279
  push_to_hub: bool = field(
 
319
 
320
 
321
  def data_loader(
322
+ dataset: Dataset,
323
+ batch_size: int,
324
+ rng: jax.random.PRNGKey = None,
325
  ):
326
  """
327
  Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
 
329
  """
330
  steps_per_epoch = len(dataset) // batch_size
331
 
332
+ if rng is not None:
333
  batch_idx = jax.random.permutation(rng, len(dataset))
334
  else:
335
  batch_idx = jnp.arange(len(dataset))
 
449
  data_args.dataset_repo_or_path,
450
  data_files=data_files,
451
  streaming=data_args.streaming,
452
+ use_auth_token=data_args.use_auth_token,
453
  )
454
 
455
  # Set up wandb run
 
501
 
502
  # Create a custom model and initialize it randomly
503
  model = CustomFlaxBartForConditionalGeneration(
504
+ config, seed=training_args.seed_model, dtype=getattr(jnp, model_args.dtype)
505
  )
506
 
507
  # Load tokenizer
 
579
  else train_dataset.select(range(data_args.max_train_samples))
580
  )
581
  if data_args.streaming:
582
+ train_dataset = train_dataset.shuffle(1000, training_args.seed_dataset)
583
+ else:
584
+ seed_dataset = (
585
+ training_args.seed_dataset
586
+ if training_args.seed_dataset is not None
587
+ else np.random.get_state()[1][0]
588
+ )
589
+ rng_dataset = jax.random.PRNGKey(seed_dataset)
590
  if model.config.normalize_text:
591
  train_dataset = (
592
  train_dataset.map(normalize_text)
 
652
  )
653
 
654
  # Initialize our training
655
+ rng = jax.random.PRNGKey(training_args.seed_model)
656
  rng, dropout_rng = jax.random.split(rng)
657
 
658
  # Store some constant
 
833
  if data_args.streaming:
834
  eval_loader = data_loader_streaming(eval_dataset, eval_batch_size)
835
  else:
836
+ eval_loader = data_loader(eval_dataset, eval_batch_size)
837
  eval_steps = (
838
  len_eval_dataset // eval_batch_size
839
  if len_eval_dataset is not None
 
952
  train_dataset.set_epoch(epoch) # shuffle dataset
953
  train_loader = data_loader_streaming(train_dataset, train_batch_size)
954
  else:
955
+ rng_dataset, input_rng = jax.random.split(rng_dataset)
956
+ train_loader = data_loader(train_dataset, train_batch_size, rng=input_rng)
 
 
957
  # train
958
  for batch in tqdm(
959
  train_loader,