versae commited on
Commit
8bd9e95
1 Parent(s): 7b22f12

Fixes to mc4 fork

Browse files
Files changed (2) hide show
  1. mc4/mc4.py +11 -12
  2. run_mlm_flax_stream.py +1 -1
mc4/mc4.py CHANGED
@@ -283,20 +283,20 @@ class Mc4(datasets.GeneratorBasedBuilder):
283
  BUILDER_CONFIG_CLASS = Mc4Config
284
 
285
  def __init__(self, *args, writer_batch_size=None, **kwargs):
286
- self.filepaths = kwargs.pop(filepaths, {})
287
  self.sampling_method = kwargs.pop("sampling_method", None)
 
 
 
 
288
  if self.sampling_method:
289
- seed = kwargs.pop("seed", None)
290
- if seed is not None:
291
- self.rng = default_rng(seed)
292
  else:
293
  self.rng = default_rng()
294
  if self.sampling_method == "random":
295
  self.should_keep_doc = self._should_keep_doc_random
296
  else:
297
- self.perplexity_model = kwargs.pop("perplexity_model", None)
298
- self.sampling_factor = kwargs.pop("sampling_factor", None)
299
- self.boundaries = kwargs.pop("boundaries", None)
300
  # Loading 5-gram model
301
  # http://dl.fbaipublicfiles.com/cc_net/lm/es.arpa.bin
302
  logger.info("loading model = %s", self.perplexity_model)
@@ -305,7 +305,6 @@ class Mc4(datasets.GeneratorBasedBuilder):
305
  self.should_keep_doc = self._should_keep_doc_gaussian
306
  else:
307
  self.should_keep_doc = self._should_keep_doc_step
308
-
309
  super().__init__(*args, writer_batch_size=writer_batch_size, **kwargs)
310
 
311
  def get_perplexity(self, doc):
@@ -375,14 +374,14 @@ class Mc4(datasets.GeneratorBasedBuilder):
375
  for lang in self.config.languages
376
  for index in range(_N_SHARDS_PER_SPLIT[lang][split])
377
  ]
378
- if "train" in self.filepaths:
379
- train_downloaded_files = self.filepaths["train"]
380
  if not isinstance(train_downloaded_files, (tuple, list)):
381
  train_downloaded_files = [train_downloaded_files]
382
  else:
383
  train_downloaded_files = dl_manager.download(data_urls["train"])
384
- if "validation" in self.filepaths:
385
- validation_downloaded_files = self.filepaths["validation"]
386
  if not isinstance(validation_downloaded_files, (tuple, list)):
387
  validation_downloaded_files = [validation_downloaded_files]
388
  else:
283
  BUILDER_CONFIG_CLASS = Mc4Config
284
 
285
  def __init__(self, *args, writer_batch_size=None, **kwargs):
286
+ self.data_files = kwargs.pop("data_files", {})
287
  self.sampling_method = kwargs.pop("sampling_method", None)
288
+ self.perplexity_model = kwargs.pop("perplexity_model", None)
289
+ self.sampling_factor = kwargs.pop("sampling_factor", None)
290
+ self.boundaries = kwargs.pop("boundaries", None)
291
+ self.seed = kwargs.pop("seed", None)
292
  if self.sampling_method:
293
+ if self.seed is not None:
294
+ self.rng = default_rng(self.seed)
 
295
  else:
296
  self.rng = default_rng()
297
  if self.sampling_method == "random":
298
  self.should_keep_doc = self._should_keep_doc_random
299
  else:
 
 
 
300
  # Loading 5-gram model
301
  # http://dl.fbaipublicfiles.com/cc_net/lm/es.arpa.bin
302
  logger.info("loading model = %s", self.perplexity_model)
305
  self.should_keep_doc = self._should_keep_doc_gaussian
306
  else:
307
  self.should_keep_doc = self._should_keep_doc_step
 
308
  super().__init__(*args, writer_batch_size=writer_batch_size, **kwargs)
309
 
310
  def get_perplexity(self, doc):
374
  for lang in self.config.languages
375
  for index in range(_N_SHARDS_PER_SPLIT[lang][split])
376
  ]
377
+ if "train" in self.data_files:
378
+ train_downloaded_files = self.data_files["train"]
379
  if not isinstance(train_downloaded_files, (tuple, list)):
380
  train_downloaded_files = [train_downloaded_files]
381
  else:
382
  train_downloaded_files = dl_manager.download(data_urls["train"])
383
+ if "validation" in self.data_files:
384
+ validation_downloaded_files = self.data_files["validation"]
385
  if not isinstance(validation_downloaded_files, (tuple, list)):
386
  validation_downloaded_files = [validation_downloaded_files]
387
  else:
run_mlm_flax_stream.py CHANGED
@@ -402,7 +402,7 @@ if __name__ == "__main__":
402
  boundaries=sampling_args.boundaries,
403
  perplexity_model=sampling_args.perplexity_model,
404
  seed=training_args.seed,
405
- filepaths={"train": filepaths},
406
  )
407
 
408
  if model_args.config_name:
402
  boundaries=sampling_args.boundaries,
403
  perplexity_model=sampling_args.perplexity_model,
404
  seed=training_args.seed,
405
+ data_files=filepaths,
406
  )
407
 
408
  if model_args.config_name: