acul3 committed
Commit
f68dce5
1 Parent(s): 1a45e65

add stream

Files changed (1)
  1. run_mlm_flax_stream.py +11 -28
run_mlm_flax_stream.py CHANGED
@@ -308,7 +308,7 @@ def advance_iter_and_group_samples(train_iterator, num_samples, max_seq_length):
     while i < num_total_tokens:
         tokenized_samples = next(train_iterator)
         i += len(tokenized_samples["input_ids"])
-        print(tokenized_samples)
+
         # concatenate tokenized samples to list
         samples = {k: samples[k] + tokenized_samples[k] for k in tokenized_samples.keys()}
 
@@ -451,30 +451,13 @@ if __name__ == "__main__":
     # 'text' is found. You can easily tweak this behavior (see below).
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
-        filepaths = {}
-        if data_args.train_file:
-            filepaths["train"] = data_args.train_file
-        if data_args.validation_file:
-            filepaths["validation"] = data_args.validation_file
-        try:
-            dataset = load_dataset(
-                data_args.dataset_name,
-                data_args.dataset_config_name,
-                cache_dir=model_args.cache_dir,
-                streaming=True,
-                split="train",
-            )
-        except Exception as exc:
-            logger.warning(
-                f"Unable to load local dataset with perplexity sampling support. Using huggingface.co/datasets/{data_args.dataset_name}: {exc}"
-            )
-            dataset = load_dataset(
-                data_args.dataset_name,
-                data_args.dataset_config_name,
-                cache_dir=model_args.cache_dir,
-                streaming=True,
-                split="train",
-            )
+        dataset = load_dataset(
+            data_args.dataset_name,
+            data_args.dataset_config_name,
+            cache_dir=model_args.cache_dir,
+            streaming=True,
+            split="train",
+        )
 
     if model_args.config_name:
         config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
@@ -505,13 +488,13 @@ if __name__ == "__main__":
         return tokenizer(
             examples[data_args.text_column_name],
             max_length=512,
-            truncation=True,
             return_special_tokens_mask=True
         )
 
     tokenized_datasets = dataset.map(
         tokenize_function,
         batched=True,
+        remove_columns=list(dataset.features.keys()),
     )
 
     shuffle_seed = training_args.seed
@@ -524,8 +507,8 @@ if __name__ == "__main__":
     # Enable Weight&Biases
     import wandb
     wandb.init(
-        entity='munggok',
-        project='roberta-indo-base',
+        entity='wandb',
+        project='hf-flax-bertin-roberta-es',
         sync_tensorboard=True,
     )
     wandb.config.update(training_args)
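Note (not part of the commit): as context for the change, here is a minimal sketch of the streaming pipeline this diff switches to, assuming recent versions of the `datasets` and `transformers` libraries. The dataset name, config, tokenizer checkpoint, and text column below are illustrative placeholders, not values from this repository.

```python
# Minimal sketch of a streaming tokenization pipeline (placeholders, not repo values).
from datasets import load_dataset
from transformers import AutoTokenizer

dataset = load_dataset(
    "oscar",                       # placeholder dataset name
    "unshuffled_deduplicated_es",  # placeholder config
    streaming=True,                # returns an IterableDataset; nothing is downloaded up front
    split="train",
)

tokenizer = AutoTokenizer.from_pretrained("roberta-base")  # placeholder checkpoint

def tokenize_function(examples):
    # No truncation here: as in the script above, samples are later grouped
    # into fixed-length blocks by advance_iter_and_group_samples.
    return tokenizer(
        examples["text"],
        max_length=512,
        return_special_tokens_mask=True,
    )

tokenized_dataset = dataset.map(
    tokenize_function,
    batched=True,
    # Drop the raw columns so only the tokenizer output remains.
    remove_columns=list(dataset.features.keys()),
)

# A streaming dataset cannot be indexed; shuffle with a buffer and iterate lazily.
shuffled = tokenized_dataset.shuffle(buffer_size=10_000, seed=42)
for sample in shuffled.take(2):
    print(sample.keys())
```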