add stream
run_mlm_flax_stream.py  CHANGED  (+11 -28)
@@ -308,7 +308,7 @@ def advance_iter_and_group_samples(train_iterator, num_samples, max_seq_length):
     while i < num_total_tokens:
         tokenized_samples = next(train_iterator)
         i += len(tokenized_samples["input_ids"])
-
+
         # concatenate tokenized samples to list
         samples = {k: samples[k] + tokenized_samples[k] for k in tokenized_samples.keys()}
 
@@ -451,30 +451,13 @@ if __name__ == "__main__":
     # 'text' is found. You can easily tweak this behavior (see below).
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
-
-
-
-
-
-
-
-                data_args.dataset_name,
-                data_args.dataset_config_name,
-                cache_dir=model_args.cache_dir,
-                streaming=True,
-                split="train",
-            )
-        except Exception as exc:
-            logger.warning(
-                f"Unable to load local dataset with perplexity sampling support. Using huggingface.co/datasets/{data_args.dataset_name}: {exc}"
-            )
-            dataset = load_dataset(
-                data_args.dataset_name,
-                data_args.dataset_config_name,
-                cache_dir=model_args.cache_dir,
-                streaming=True,
-                split="train",
-            )
+        dataset = load_dataset(
+            data_args.dataset_name,
+            data_args.dataset_config_name,
+            cache_dir=model_args.cache_dir,
+            streaming=True,
+            split="train",
+        )
 
     if model_args.config_name:
         config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
@@ -505,13 +488,13 @@ if __name__ == "__main__":
         return tokenizer(
             examples[data_args.text_column_name],
             max_length=512,
-            truncation=True,
             return_special_tokens_mask=True
         )
 
     tokenized_datasets = dataset.map(
         tokenize_function,
         batched=True,
+        remove_columns=list(dataset.features.keys()),
     )
 
     shuffle_seed = training_args.seed
@@ -524,8 +507,8 @@ if __name__ == "__main__":
     # Enable Weight&Biases
     import wandb
     wandb.init(
-        entity='
-        project='roberta-
+        entity='wandb',
+        project='hf-flax-bertin-roberta-es',
         sync_tensorboard=True,
     )
     wandb.config.update(training_args)
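For context on the first hunk (which only touches a blank line): the surrounding loop accumulates tokenized samples from the streaming iterator until enough tokens are available, after which they are cut into fixed-length blocks. A toy sketch of that accumulate-then-group idea, with made-up token ids and a hypothetical `group_into_blocks` helper that is not part of the script:

```python
from collections import defaultdict

def group_into_blocks(samples, max_seq_length):
    # Concatenate each field across samples, then split into fixed-length blocks,
    # in the spirit of the grouping done around advance_iter_and_group_samples.
    grouped = {}
    for key, sequences in samples.items():
        concatenated = [token for seq in sequences for token in seq]
        usable = (len(concatenated) // max_seq_length) * max_seq_length
        grouped[key] = [
            concatenated[i : i + max_seq_length] for i in range(0, usable, max_seq_length)
        ]
    return grouped

samples = defaultdict(list)
# Toy "tokenized" batches standing in for next(train_iterator) results.
for batch in [
    {"input_ids": [[0, 5, 6, 2], [0, 7, 2]]},
    {"input_ids": [[0, 8, 9, 10, 2]]},
]:
    # Same accumulation pattern as in the diff's context lines.
    samples = {k: samples[k] + batch[k] for k in batch.keys()}

print(group_into_blocks(samples, max_seq_length=4)["input_ids"])
# [[0, 5, 6, 2], [0, 7, 2, 0], [8, 9, 10, 2]]
```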
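The second and third hunks drop the local/perplexity-sampling fallback in favor of a plain streaming `load_dataset` call, tokenize without an explicit `truncation=True`, and strip the raw columns during `dataset.map`. A minimal end-to-end sketch of that pipeline, assuming `wikimedia/wikipedia` (config `20231101.es`) and the `roberta-base` tokenizer purely as stand-ins for the script's `data_args`/`model_args` values:

```python
from datasets import load_dataset
from transformers import AutoTokenizer

# Streaming load, mirroring the new load_dataset(..., streaming=True, split="train") call.
# Dataset name/config are assumed stand-ins for data_args.dataset_name / dataset_config_name.
dataset = load_dataset(
    "wikimedia/wikipedia",
    "20231101.es",
    streaming=True,
    split="train",
)

# Assumed checkpoint standing in for the tokenizer configured via model_args.
tokenizer = AutoTokenizer.from_pretrained("roberta-base")

def tokenize_function(examples):
    # Mirrors the updated tokenize_function: max_length is kept, only the
    # explicit truncation=True flag is removed by the diff.
    return tokenizer(
        examples["text"],
        max_length=512,
        return_special_tokens_mask=True,
    )

# The script uses remove_columns=list(dataset.features.keys()); peeking at the first
# streamed example is an equivalent, slightly more defensive way to get the raw columns.
raw_columns = list(next(iter(dataset)).keys())

tokenized_dataset = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=raw_columns,
)

# One streamed, tokenized example: only tokenizer outputs remain.
print(next(iter(tokenized_dataset)).keys())
```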
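The last hunk fills in the Weights & Biases entity and project and keeps TensorBoard syncing on. A small isolated sketch of that logging setup; the offline mode, the config dict, and the logged metric are illustrative additions so the snippet runs without a W&B account:

```python
import wandb

# Entity/project values are the ones set in the diff; sync_tensorboard=True mirrors
# TensorBoard event files written during training into the same W&B run.
# mode="offline" is only here so the sketch runs without logging in.
run = wandb.init(
    entity="wandb",
    project="hf-flax-bertin-roberta-es",
    sync_tensorboard=True,
    mode="offline",
)

# The script calls wandb.config.update(training_args); a plain dict stands in here.
wandb.config.update({"per_device_train_batch_size": 128, "learning_rate": 6e-4})

# Illustrative metric; in the script, metrics come from the Flax training loop.
wandb.log({"train/loss": 2.31})
run.finish()
```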