Fixed a couple of conditionals
- mc4/mc4.py +2 -2
- run_mlm_flax_stream.py +1 -1
mc4/mc4.py
@@ -376,13 +376,13 @@ class Mc4(datasets.GeneratorBasedBuilder):
             for lang in self.config.languages
             for index in range(_N_SHARDS_PER_SPLIT[lang][split])
         ]
-        if "train" in self.data_files:
+        if self.data_files and "train" in self.data_files:
             train_downloaded_files = self.data_files["train"]
             if not isinstance(train_downloaded_files, (tuple, list)):
                 train_downloaded_files = [train_downloaded_files]
         else:
             train_downloaded_files = dl_manager.download(data_urls["train"])
-        if "validation" in self.data_files:
+        if self.data_files and "validation" in self.data_files:
             validation_downloaded_files = self.data_files["validation"]
             if not isinstance(validation_downloaded_files, (tuple, list)):
                 validation_downloaded_files = [validation_downloaded_files]
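The added `self.data_files and` guard matters because Python's `in` operator raises a `TypeError` when its right-hand side is `None`, and here `self.data_files` can evidently be unset (otherwise the old test would not have needed guarding). A minimal standalone sketch of the failure mode and the fix, using plain variables that stand in for the builder attributes rather than the actual class:

# Membership tests against None raise TypeError; this is the bug the commit fixes.
data_files = None  # stands in for self.data_files when no local files are passed

try:
    "train" in data_files  # the old conditional's membership test
except TypeError as err:
    print(err)  # -> argument of type 'NoneType' is not iterable

# Fixed conditional: `and` short-circuits, so the membership test never sees None.
if data_files and "train" in data_files:
    train_downloaded_files = data_files["train"]
else:
    # stand-in for the dl_manager.download(data_urls["train"]) fallback
    train_downloaded_files = ["https://example.com/shard-0.json.gz"]
print(train_downloaded_files)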
run_mlm_flax_stream.py
@@ -588,7 +588,7 @@ if __name__ == "__main__":
     # Setup train state
     state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)
     saved_step = 0
-    if "checkpoint" in model_args.model_name_or_path:
+    if model_args.model_name_or_path and "checkpoint" in model_args.model_name_or_path:
         params, opt_state, saved_step, args, data_collator = restore_checkpoint(model_args.model_name_or_path, state)
     # Create learning rate schedule
     warmup_fn = optax.linear_schedule(
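Same failure mode in the training script: substring tests on strings are membership tests, so the old `"checkpoint" in model_args.model_name_or_path` raises `TypeError` whenever the model path is unset. A hedged sketch with a plain variable in place of the parsed args:

# Substring tests on str are membership tests, so None raises the same TypeError.
model_name_or_path = None  # stands in for model_args.model_name_or_path when unset

# Old: `"checkpoint" in model_name_or_path` raises TypeError for None.
# New: the truthiness check short-circuits before the membership test runs.
if model_name_or_path and "checkpoint" in model_name_or_path:
    print("resuming from", model_name_or_path)  # restore_checkpoint(...) would run here
else:
    print("starting from a fresh train state")  # saved_step stays 0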