versae commited on
Commit
daff9ab
1 Parent(s): 641b942

Fixed a couple of conditonals

Browse files
Files changed (2) hide show
  1. mc4/mc4.py +2 -2
  2. run_mlm_flax_stream.py +1 -1
mc4/mc4.py CHANGED
@@ -376,13 +376,13 @@ class Mc4(datasets.GeneratorBasedBuilder):
376
  for lang in self.config.languages
377
  for index in range(_N_SHARDS_PER_SPLIT[lang][split])
378
  ]
379
- if "train" in self.data_files:
380
  train_downloaded_files = self.data_files["train"]
381
  if not isinstance(train_downloaded_files, (tuple, list)):
382
  train_downloaded_files = [train_downloaded_files]
383
  else:
384
  train_downloaded_files = dl_manager.download(data_urls["train"])
385
- if "validation" in self.data_files:
386
  validation_downloaded_files = self.data_files["validation"]
387
  if not isinstance(validation_downloaded_files, (tuple, list)):
388
  validation_downloaded_files = [validation_downloaded_files]
376
  for lang in self.config.languages
377
  for index in range(_N_SHARDS_PER_SPLIT[lang][split])
378
  ]
379
+ if self.data_files and "train" in self.data_files:
380
  train_downloaded_files = self.data_files["train"]
381
  if not isinstance(train_downloaded_files, (tuple, list)):
382
  train_downloaded_files = [train_downloaded_files]
383
  else:
384
  train_downloaded_files = dl_manager.download(data_urls["train"])
385
+ if self.data_files and "validation" in self.data_files:
386
  validation_downloaded_files = self.data_files["validation"]
387
  if not isinstance(validation_downloaded_files, (tuple, list)):
388
  validation_downloaded_files = [validation_downloaded_files]
run_mlm_flax_stream.py CHANGED
@@ -588,7 +588,7 @@ if __name__ == "__main__":
588
  # Setup train state
589
  state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)
590
  saved_step = 0
591
- if "checkpoint" in model_args.model_name_or_path:
592
  params, opt_state, saved_step, args, data_collator = restore_checkpoint(model_args.model_name_or_path, state)
593
  # Create learning rate schedule
594
  warmup_fn = optax.linear_schedule(
588
  # Setup train state
589
  state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)
590
  saved_step = 0
591
+ if model_args.model_name_or_path and "checkpoint" in model_args.model_name_or_path:
592
  params, opt_state, saved_step, args, data_collator = restore_checkpoint(model_args.model_name_or_path, state)
593
  # Create learning rate schedule
594
  warmup_fn = optax.linear_schedule(