versae committed on
Commit 61f6971
Parents (2): 300e533, 8bd9e95

Merge branch 'main' of https://huggingface.co/flax-community/bertin-roberta-large-spanish into main

Files changed (2)
  1. mc4/mc4.py +52 -31
  2. run_mlm_flax_stream.py +8 -2
mc4/mc4.py CHANGED
@@ -283,27 +283,28 @@ class Mc4(datasets.GeneratorBasedBuilder):
     BUILDER_CONFIG_CLASS = Mc4Config
 
     def __init__(self, *args, writer_batch_size=None, **kwargs):
+        self.data_files = kwargs.pop("data_files", {})
         self.sampling_method = kwargs.pop("sampling_method", None)
+        self.perplexity_model = kwargs.pop("perplexity_model", None)
+        self.sampling_factor = kwargs.pop("sampling_factor", None)
+        self.boundaries = kwargs.pop("boundaries", None)
+        self.seed = kwargs.pop("seed", None)
         if self.sampling_method:
-            seed = kwargs.pop("seed", None)
-            if seed is not None:
-                self.rng = default_rng(seed)
+            if self.seed is not None:
+                self.rng = default_rng(self.seed)
             else:
                 self.rng = default_rng()
-            self.perplexity_model = kwargs.pop("perplexity_model", None)
-            self.sampling_factor = kwargs.pop("sampling_factor", None)
-            self.boundaries = kwargs.pop("boundaries", None)
-            # Loading 5-gram model
-            # http://dl.fbaipublicfiles.com/cc_net/lm/es.arpa.bin
-            logger.info("loading model = %s", self.perplexity_model)
-            self.pp_model = kenlm.Model(self.perplexity_model)
-            if self.sampling_method == "gaussian":
-                self.should_keep_doc = self._should_keep_doc_gaussian
-            elif self.sampling_method == "random":
+            if self.sampling_method == "random":
                 self.should_keep_doc = self._should_keep_doc_random
             else:
-                self.should_keep_doc = self._should_keep_doc_step
-
+                # Loading 5-gram model
+                # http://dl.fbaipublicfiles.com/cc_net/lm/es.arpa.bin
+                logger.info("loading model = %s", self.perplexity_model)
+                self.pp_model = kenlm.Model(self.perplexity_model)
+                if self.sampling_method == "gaussian":
+                    self.should_keep_doc = self._should_keep_doc_gaussian
+                else:
+                    self.should_keep_doc = self._should_keep_doc_step
         super().__init__(*args, writer_batch_size=writer_batch_size, **kwargs)
 
     def get_perplexity(self, doc):
@@ -341,7 +342,9 @@ class Mc4(datasets.GeneratorBasedBuilder):
         return self.rng.uniform() < weighted_perplexity
 
     def _should_keep_doc_random(self, doc, factor=None, boundaries=None):
-        return self.rng.uniform() <= 0.5
+        if factor is None:
+            factor = 0.5
+        return self.rng.uniform() <= factor
 
     def _info(self):
         return datasets.DatasetInfo(
@@ -371,8 +374,18 @@ class Mc4(datasets.GeneratorBasedBuilder):
             for lang in self.config.languages
             for index in range(_N_SHARDS_PER_SPLIT[lang][split])
         ]
-        train_downloaded_files = dl_manager.download(data_urls["train"])
-        validation_downloaded_files = dl_manager.download(data_urls["validation"])
+        if "train" in self.data_files:
+            train_downloaded_files = self.data_files["train"]
+            if not isinstance(train_downloaded_files, (tuple, list)):
+                train_downloaded_files = [train_downloaded_files]
+        else:
+            train_downloaded_files = dl_manager.download(data_urls["train"])
+        if "validation" in self.data_files:
+            validation_downloaded_files = self.data_files["validation"]
+            if not isinstance(validation_downloaded_files, (tuple, list)):
+                validation_downloaded_files = [validation_downloaded_files]
+        else:
+            validation_downloaded_files = dl_manager.download(data_urls["validation"])
         return [
             datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepaths": train_downloaded_files}),
             datasets.SplitGenerator(
@@ -385,21 +398,29 @@ class Mc4(datasets.GeneratorBasedBuilder):
         id_ = 0
         for filepath in filepaths:
             logger.info("generating examples from = %s", filepath)
-            with gzip.open(open(filepath, "rb"), "rt", encoding="utf-8") as f:
-                if self.sampling_method:
-                    logger.info("sampling method = %s", self.sampling_method)
-                    for line in f:
-                        if line:
-                            example = json.loads(line)
-                            if self.should_keep_doc(
-                                    example["text"],
-                                    factor=self.sampling_factor,
-                                    boundaries=self.boundaries):
-                                yield id_, example
-                                id_ += 1
-                else:
+            if filepath.endswith("jsonl"):
+                with open(filepath, "r", encoding="utf-8") as f:
                     for line in f:
                         if line:
                             example = json.loads(line)
                             yield id_, example
                             id_ += 1
+            else:
+                with gzip.open(open(filepath, "rb"), "rt", encoding="utf-8") as f:
+                    if self.sampling_method:
+                        logger.info("sampling method = %s", self.sampling_method)
+                        for line in f:
+                            if line:
+                                example = json.loads(line)
+                                if self.should_keep_doc(
+                                        example["text"],
+                                        factor=self.sampling_factor,
+                                        boundaries=self.boundaries):
+                                    yield id_, example
+                                    id_ += 1
+                    else:
+                        for line in f:
+                            if line:
+                                example = json.loads(line)
+                                yield id_, example
+                                id_ += 1
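
Note: the hunks above only keep the signature of get_perplexity, which scores documents with the kenlm 5-gram model loaded in __init__. For reference, a minimal sketch of cc_net-style document scoring with kenlm follows; the helper name doc_perplexity and the per-line normalisation are assumptions for illustration, not the implementation in mc4.py.

import kenlm

def doc_perplexity(model: kenlm.Model, doc: str) -> float:
    # Hypothetical helper: accumulate log10 scores per line and normalise
    # by token count (each line contributes one extra end-of-sentence token).
    log_score, tokens = 0.0, 0
    for line in doc.split("\n"):
        log_score += model.score(line)
        tokens += len(line.split()) + 1
    return 10.0 ** (-log_score / max(tokens, 1))

pp_model = kenlm.Model("es.arpa.bin")  # the 5-gram model referenced above
print(doc_perplexity(pp_model, "Este es un documento de ejemplo."))
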
run_mlm_flax_stream.py CHANGED
@@ -178,10 +178,10 @@ class DataTrainingArguments:
         else:
             if self.train_file is not None:
                 extension = self.train_file.split(".")[-1]
-                assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
+                assert extension in ["csv", "json", "jsonl", "txt", "gz"], "`train_file` should be a csv, a json (lines) or a txt file."
             if self.validation_file is not None:
                 extension = self.validation_file.split(".")[-1]
-                assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
+                assert extension in ["csv", "json", "jsonl", "txt", "gz"], "`validation_file` should be a csv, a json (lines) or a txt file."
 
 
 @flax.struct.dataclass
@@ -386,6 +386,11 @@ if __name__ == "__main__":
     # 'text' is found. You can easily tweak this behavior (see below).
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
+        filepaths = {}
+        if data_args.train_file:
+            filepaths["train"] = data_args.train_file
+        if data_args.validation_file:
+            filepaths["validation"] = data_args.validation_file
         dataset = load_dataset(
             data_args.dataset_name,
             data_args.dataset_config_name,
@@ -397,6 +402,7 @@
             boundaries=sampling_args.boundaries,
             perplexity_model=sampling_args.perplexity_model,
             seed=training_args.seed,
+            data_files=filepaths,
         )
 
     if model_args.config_name:
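
Taken together, the two diffs let local files flow from the training script into the dataset builder: --train_file/--validation_file are collected into filepaths, forwarded to the builder as data_files, local *.jsonl shards skip dl_manager.download(), and gzipped shards keep the perplexity-sampling path. Below is a minimal sketch of the resulting load_dataset call; the script path, config name, streaming flag, and concrete sampling values are assumptions not shown in the diffs.

from datasets import load_dataset

dataset = load_dataset(
    "./mc4",                            # local mc4.py script (assumed path)
    "es",                               # language config (assumed)
    streaming=True,                     # assumed, matching the streaming MLM script
    sampling_method="gaussian",         # "random" would skip loading the kenlm model
    sampling_factor=0.5,                # illustrative value
    boundaries=[100.0, 200.0, 300.0],   # illustrative perplexity boundaries
    perplexity_model="es.arpa.bin",     # kenlm 5-gram model referenced in mc4.py
    seed=42,                            # illustrative
    data_files={"train": "train.jsonl", "validation": "validation.jsonl"},
)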