versae committed on
Commit
4e4228c
1 Parent(s): a5b19d7

Adding reading streaming files from local disk

Files changed (2)
  1. mc4/mc4.py +36 -14
  2. run_mlm_flax_stream.py +8 -2
mc4/mc4.py CHANGED
@@ -283,6 +283,7 @@ class Mc4(datasets.GeneratorBasedBuilder):
     BUILDER_CONFIG_CLASS = Mc4Config
 
     def __init__(self, *args, writer_batch_size=None, **kwargs):
+        self.filepaths = kwargs.pop("filepaths", {})
         self.sampling_method = kwargs.pop("sampling_method", None)
         if self.sampling_method:
             seed = kwargs.pop("seed", None)
@@ -290,19 +291,20 @@ class Mc4(datasets.GeneratorBasedBuilder):
                 self.rng = default_rng(seed)
             else:
                 self.rng = default_rng()
-            self.perplexity_model = kwargs.pop("perplexity_model", None)
-            self.sampling_factor = kwargs.pop("sampling_factor", None)
-            self.boundaries = kwargs.pop("boundaries", None)
-            # Loading 5-gram model
-            # http://dl.fbaipublicfiles.com/cc_net/lm/es.arpa.bin
-            logger.info("loading model = %s", self.perplexity_model)
-            self.pp_model = kenlm.Model(self.perplexity_model)
-            if self.sampling_method == "gaussian":
-                self.should_keep_doc = self._should_keep_doc_gaussian
-            elif self.sampling_method == "random":
+            if self.sampling_method == "random":
                 self.should_keep_doc = self._should_keep_doc_random
             else:
-                self.should_keep_doc = self._should_keep_doc_step
+                self.perplexity_model = kwargs.pop("perplexity_model", None)
+                self.sampling_factor = kwargs.pop("sampling_factor", None)
+                self.boundaries = kwargs.pop("boundaries", None)
+                # Loading 5-gram model
+                # http://dl.fbaipublicfiles.com/cc_net/lm/es.arpa.bin
+                logger.info("loading model = %s", self.perplexity_model)
+                self.pp_model = kenlm.Model(self.perplexity_model)
+                if self.sampling_method == "gaussian":
+                    self.should_keep_doc = self._should_keep_doc_gaussian
+                else:
+                    self.should_keep_doc = self._should_keep_doc_step
 
         super().__init__(*args, writer_batch_size=writer_batch_size, **kwargs)
 
@@ -341,7 +343,9 @@
         return self.rng.uniform() < weighted_perplexity
 
     def _should_keep_doc_random(self, doc, factor=None, boundaries=None):
-        return self.rng.uniform() <= 0.5
+        if factor is None:
+            factor = 0.5
+        return self.rng.uniform() <= factor
 
     def _info(self):
         return datasets.DatasetInfo(
@@ -371,8 +375,18 @@
             for lang in self.config.languages
             for index in range(_N_SHARDS_PER_SPLIT[lang][split])
         ]
-        train_downloaded_files = dl_manager.download(data_urls["train"])
-        validation_downloaded_files = dl_manager.download(data_urls["validation"])
+        if "train" in self.filepaths:
+            train_downloaded_files = self.filepaths["train"]
+            if not isinstance(train_downloaded_files, (tuple, list)):
+                train_downloaded_files = [train_downloaded_files]
+        else:
+            train_downloaded_files = dl_manager.download(data_urls["train"])
+        if "validation" in self.filepaths:
+            validation_downloaded_files = self.filepaths["validation"]
+            if not isinstance(validation_downloaded_files, (tuple, list)):
+                validation_downloaded_files = [validation_downloaded_files]
+        else:
+            validation_downloaded_files = dl_manager.download(data_urls["validation"])
         return [
             datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepaths": train_downloaded_files}),
             datasets.SplitGenerator(
@@ -385,6 +399,14 @@
         id_ = 0
         for filepath in filepaths:
             logger.info("generating examples from = %s", filepath)
-            with gzip.open(open(filepath, "rb"), "rt", encoding="utf-8") as f:
-                if self.sampling_method:
-                    logger.info("sampling method = %s", self.sampling_method)
+            if filepath.endswith("json") or filepath.endswith("jsonl"):
+                with open(filepath, "r", encoding="utf-8") as f:
+                    for line in f:
+                        if line:
+                            example = json.loads(line)
+                            yield id_, example
+                            id_ += 1
+            else:
+                with gzip.open(open(filepath, "rb"), "rt", encoding="utf-8") as f:
+                    if self.sampling_method:
+                        logger.info("sampling method = %s", self.sampling_method)
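The _should_keep_doc_random hunk above makes the keep probability configurable while defaulting to the previous hard-coded 0.5. A minimal standalone sketch of that behaviour, using numpy's default_rng as the script does (everything outside the diff is illustrative):

import numpy as np

rng = np.random.default_rng(seed=42)

def should_keep_doc_random(doc, factor=None):
    # Keep each document with probability `factor`; None falls back to
    # the previous fixed 50% subsampling.
    if factor is None:
        factor = 0.5
    return rng.uniform() <= factor

docs = [f"doc-{i}" for i in range(1000)]
kept = [doc for doc in docs if should_keep_doc_random(doc)]
print(f"kept {len(kept)} of {len(docs)} documents")  # roughly half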
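The _split_generators hunk applies the same override to both splits: caller-supplied local paths win over the download manager, and a bare string is normalized to a one-element list. The pattern in isolation (a sketch; resolve_split and its arguments are illustrative names, not part of the script):

def resolve_split(split, filepaths, data_urls, download):
    # Local paths, when given, short-circuit the download manager.
    if split in filepaths:
        files = filepaths[split]
        if not isinstance(files, (tuple, list)):
            files = [files]
        return list(files)
    return download(data_urls[split])

# With a local shard supplied, `download` is never called.
print(resolve_split("train", {"train": "train-000.json.gz"}, {}, download=None))
# -> ['train-000.json.gz']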
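The _generate_examples hunk gives local files two code paths: .json/.jsonl files are read as plain JSON lines, while everything else keeps the original gzip route. The same logic as a self-contained reader (a sketch under the diff's assumption of one JSON document per line):

import gzip
import json

def iter_examples(filepath):
    # Plain JSON-lines files are opened directly; anything else is
    # assumed to be a gzip-compressed JSON-lines shard, as in mC4.
    if filepath.endswith("json") or filepath.endswith("jsonl"):
        f = open(filepath, "r", encoding="utf-8")
    else:
        f = gzip.open(open(filepath, "rb"), "rt", encoding="utf-8")
    with f:
        id_ = 0
        for line in f:
            if line:
                yield id_, json.loads(line)
                id_ += 1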
run_mlm_flax_stream.py CHANGED
@@ -178,10 +178,10 @@ class DataTrainingArguments:
         else:
             if self.train_file is not None:
                 extension = self.train_file.split(".")[-1]
-                assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
+                assert extension in ["csv", "json", "txt", "gz"], "`train_file` should be a csv, json, txt or gz file."
             if self.validation_file is not None:
                 extension = self.validation_file.split(".")[-1]
-                assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
+                assert extension in ["csv", "json", "txt", "gz"], "`validation_file` should be a csv, json, txt or gz file."
 
 
 @flax.struct.dataclass
@@ -386,6 +386,11 @@ if __name__ == "__main__":
     # 'text' is found. You can easily tweak this behavior (see below).
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
+        filepaths = {}
+        if data_args.train_file:
+            filepaths["train"] = data_args.train_file
+        if data_args.validation_file:
+            filepaths["validation"] = data_args.validation_file
         dataset = load_dataset(
             data_args.dataset_name,
             data_args.dataset_config_name,
@@ -397,6 +402,7 @@
             boundaries=sampling_args.boundaries,
             perplexity_model=sampling_args.perplexity_model,
             seed=training_args.seed,
+            filepaths=filepaths,
         )
 
     if model_args.config_name:
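One detail of the relaxed asserts above: only the substring after the last dot is checked, so a double suffix such as .json.gz is accepted as "gz". For illustration (paths are hypothetical):

def file_extension(path):
    # Same rule as DataTrainingArguments above: the final suffix decides.
    return path.split(".")[-1]

assert file_extension("/data/mc4/train-000.json.gz") == "gz"  # now accepted
assert file_extension("/data/mc4/validation.json") == "json"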
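Taken together, the two files let the streaming MLM script read shards from local disk: --train_file/--validation_file are collected into a filepaths dict, which load_dataset forwards to Mc4.__init__, where it bypasses dl_manager.download. A hypothetical call with local files (paths, config name and keyword arguments are illustrative, mirroring the load_dataset call in the script):

from datasets import load_dataset

dataset = load_dataset(
    "./mc4/mc4.py",  # the local builder script changed above
    "es",
    streaming=True,
    sampling_method="random",
    seed=42,
    filepaths={
        "train": "/data/mc4/train-000.json.gz",  # read locally, not downloaded
        "validation": "/data/mc4/validation.jsonl",
    },
)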