acul3 committed on
Commit
585036a
1 Parent(s): 2745364

add support for v3-32

Browse files
Files changed (1) hide show
  1. run_mlm_flax_stream.py +5 -25
run_mlm_flax_stream.py CHANGED
@@ -262,29 +262,6 @@ class FlaxDataCollatorForLanguageModeling:
262
  return inputs, labels
263
 
264
 
265
- @dataclass
266
- class SamplingArguments:
267
- """
268
- Arguments pertaining to how to perform sampling of the dataset.
269
- """
270
-
271
- perplexity_model: Optional[str] = field(
272
- default="./es.arpa.bin", metadata={"help": "Path to KenLM model to use to get perplexity values."}
273
- )
274
- sampling_method: Optional[str] = field(
275
- default=None, metadata={"help": "Sample using a 'step' or 'gaussian' perplexity function per document, or 'random'."}
276
- )
277
- sampling_factor: Optional[float] = field(
278
- default=None, metadata={"help": "Sampling factor. Integers for step function, decimals for gaussian."}
279
- )
280
- boundaries: Optional[str] = field(
281
- default="536394.99320948,662247.50212365,919250.87225178", metadata={"help": "Quartile boundaries"}
282
- )
283
-
284
- def __post_init__(self):
285
- self.boundaries = [float(q.strip()) for q in self.boundaries.split(",")]
286
-
287
-
288
  def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
289
  num_samples = len(samples_idx)
290
  samples_to_remove = num_samples % batch_size
@@ -310,7 +287,9 @@ def advance_iter_and_group_samples(train_iterator, num_samples, max_seq_length):
310
  i += len(tokenized_samples["input_ids"])
311
 
312
  # concatenate tokenized samples to list
313
- samples = {k: samples[k] + tokenized_samples[k] for k in tokenized_samples.keys()}
 
 
314
 
315
  # Concatenated tokens are split to lists of length `max_seq_length`.
316
  # Note that remainedr of % max_seq_length are thrown away.
@@ -404,7 +383,7 @@ if __name__ == "__main__":
404
  # or by passing the --help flag to this script.
405
  # We now keep distinct sets of args, for a cleaner separation of concerns.
406
 
407
- parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments, SamplingArguments))
408
  if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
409
  # If we pass only one argument to the script and it's the path to a json file,
410
  # let's parse it to get our arguments.
@@ -528,6 +507,7 @@ if __name__ == "__main__":
528
 
529
  # Data collator
530
  # This one will take care of randomly masking the tokens.
 
531
  data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
532
 
533
  # Initialize our training
 
262
  return inputs, labels
263
 
264
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
265
  def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
266
  num_samples = len(samples_idx)
267
  samples_to_remove = num_samples % batch_size
 
287
  i += len(tokenized_samples["input_ids"])
288
 
289
  # concatenate tokenized samples to list
290
+ samples = {
291
+ k: samples[k] + tokenized_samples[k] for k in ["input_ids", "attention_mask", "special_tokens_mask"]
292
+ }
293
 
294
  # Concatenated tokens are split to lists of length `max_seq_length`.
295
  # Note that remainedr of % max_seq_length are thrown away.
 
383
  # or by passing the --help flag to this script.
384
  # We now keep distinct sets of args, for a cleaner separation of concerns.
385
 
386
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
387
  if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
388
  # If we pass only one argument to the script and it's the path to a json file,
389
  # let's parse it to get our arguments.
 
507
 
508
  # Data collator
509
  # This one will take care of randomly masking the tokens.
510
+ print("DATA COLLATOR")
511
  data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
512
 
513
  # Initialize our training