versae committed on
Commit
a5b19d7
1 Parent(s): f562f06

Fixes and defaults

Browse files
Files changed (2) hide show
  1. mc4/mc4.py +18 -13
  2. run_mlm_flax_stream.py +8 -8
mc4/mc4.py CHANGED
@@ -8,7 +8,6 @@ import datasets
8
  import kenlm
9
  import numpy as np
10
  from numpy.random import default_rng
11
- rng = default_rng()
12
 
13
 
14
  logger = datasets.logging.get_logger(__name__)
@@ -284,11 +283,16 @@ class Mc4(datasets.GeneratorBasedBuilder):
284
  BUILDER_CONFIG_CLASS = Mc4Config
285
 
286
  def __init__(self, *args, writer_batch_size=None, **kwargs):
287
- self.sampling_method = kwargs.pop("sampling_method")
288
  if self.sampling_method:
289
- self.perplexity_model = kwargs.pop("perplexity_model")
290
- self.sampling_factor = kwargs.pop("sampling_factor")
291
- self.boundaries = kwargs.pop("boundaries")
 
 
 
 
 
292
  # Loading 5-gram model
293
  # http://dl.fbaipublicfiles.com/cc_net/lm/es.arpa.bin
294
  logger.info("loading model = %s", self.perplexity_model)
@@ -298,7 +302,7 @@ class Mc4(datasets.GeneratorBasedBuilder):
298
  elif self.sampling_method == "random":
299
  self.should_keep_doc = self._should_keep_doc_random
300
  else:
301
- self.should_keep_doc = self._should_keep_doc_gaussian
302
 
303
  super().__init__(*args, writer_batch_size=writer_batch_size, **kwargs)
304
 
@@ -311,7 +315,7 @@ class Mc4(datasets.GeneratorBasedBuilder):
311
  doc_length += length
312
  return 10.0 ** (-doc_log_score / doc_length)
313
 
314
- def _should_keep_doc_step(self, doc, factor=1, boundaries=None):
315
  perplexity = self.get_perplexity(doc)
316
  if boundaries is None:
317
  boundaries = [536394.99320948, 662247.50212365, 919250.87225178]
@@ -322,21 +326,22 @@ class Mc4(datasets.GeneratorBasedBuilder):
322
  elif boundaries[1] < perplexity < boundaries[2]:
323
  quartile_range = boundaries[2] - boundaries[1]
324
  elif perplexity >= boundaries[2]:
325
- quartile_range = 100 * boundaries[2]
326
  probability = factor / quartile_range
327
- return rng.uniform() < probability
328
 
329
- def _should_keep_doc_gaussian(self, doc, factor=0.4, boundaries=None):
330
  perplexity = self.get_perplexity(doc)
331
  if boundaries is not None:
332
  m = boundaries[1]
333
  else:
334
  m = 662247.50212365
335
- weighted_perplexity = factor * np.exp(-9/2*((perplexity-m)/m)**2)
336
- return rng.uniform() < weighted_perplexity
 
337
 
338
  def _should_keep_doc_random(self, doc, factor=None, boundaries=None):
339
- return rng.uniform() <= 0.5
340
 
341
  def _info(self):
342
  return datasets.DatasetInfo(
 
8
  import kenlm
9
  import numpy as np
10
  from numpy.random import default_rng
 
11
 
12
 
13
  logger = datasets.logging.get_logger(__name__)
 
283
  BUILDER_CONFIG_CLASS = Mc4Config
284
 
285
  def __init__(self, *args, writer_batch_size=None, **kwargs):
286
+ self.sampling_method = kwargs.pop("sampling_method", None)
287
  if self.sampling_method:
288
+ seed = kwargs.pop("seed", None)
289
+ if seed is not None:
290
+ self.rng = default_rng(seed)
291
+ else:
292
+ self.rng = default_rng()
293
+ self.perplexity_model = kwargs.pop("perplexity_model", None)
294
+ self.sampling_factor = kwargs.pop("sampling_factor", None)
295
+ self.boundaries = kwargs.pop("boundaries", None)
296
  # Loading 5-gram model
297
  # http://dl.fbaipublicfiles.com/cc_net/lm/es.arpa.bin
298
  logger.info("loading model = %s", self.perplexity_model)
 
302
  elif self.sampling_method == "random":
303
  self.should_keep_doc = self._should_keep_doc_random
304
  else:
305
+ self.should_keep_doc = self._should_keep_doc_step
306
 
307
  super().__init__(*args, writer_batch_size=writer_batch_size, **kwargs)
308
 
 
315
  doc_length += length
316
  return 10.0 ** (-doc_log_score / doc_length)
317
 
318
+ def _should_keep_doc_step(self, doc, factor=1.5e5, boundaries=None):
319
  perplexity = self.get_perplexity(doc)
320
  if boundaries is None:
321
  boundaries = [536394.99320948, 662247.50212365, 919250.87225178]
 
326
  elif boundaries[1] < perplexity < boundaries[2]:
327
  quartile_range = boundaries[2] - boundaries[1]
328
  elif perplexity >= boundaries[2]:
329
+ quartile_range = 10 * boundaries[2]
330
  probability = factor / quartile_range
331
+ return self.rng.uniform() < probability
332
 
333
+ def _should_keep_doc_gaussian(self, doc, factor=0.78, boundaries=None):
334
  perplexity = self.get_perplexity(doc)
335
  if boundaries is not None:
336
  m = boundaries[1]
337
  else:
338
  m = 662247.50212365
339
+ exponential = np.exp(-9/2 * ((perplexity - m) / m) ** 2)
340
+ weighted_perplexity = factor * exponential
341
+ return self.rng.uniform() < weighted_perplexity
342
 
343
  def _should_keep_doc_random(self, doc, factor=None, boundaries=None):
344
+ return self.rng.uniform() <= 0.5
345
 
346
  def _info(self):
347
  return datasets.DatasetInfo(
run_mlm_flax_stream.py CHANGED
@@ -256,28 +256,27 @@ class FlaxDataCollatorForLanguageModeling:
256
  return inputs, labels
257
 
258
 
259
-
260
  @dataclass
261
  class SamplingArguments:
262
  """
263
  Arguments pertaining to how to perform sampling of the dataset.
264
  """
265
 
266
- perplexity_model: Optional[str] = field(
267
- default="es.arpa.bin", metadata={"help": "kenlm model to use to get perplexity values."}
268
  )
269
- sampling_method: Optional[str] = field(
270
- default=None, metadata={"help": "Sample using a 'step' or 'gaussian' perplexity function per document."}
271
  )
272
- sampling_factor: Optional[int] = field(
273
- default=1, metadata={"help": "Sampling factor. Integers for step function, decimals for gaussian."}
274
  )
275
  boundaries: Optional[str] = field(
276
  default="536394.99320948,662247.50212365,919250.87225178", metadata={"help": "Quartile boundaries"}
277
  )
278
 
279
  def __post_init__(self):
280
- self.boundaries = [float(q) for q in self.boundaries.split(",")]
281
 
282
 
283
  def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
@@ -397,6 +396,7 @@ if __name__ == "__main__":
397
  sampling_factor=sampling_args.sampling_factor,
398
  boundaries=sampling_args.boundaries,
399
  perplexity_model=sampling_args.perplexity_model,
 
400
  )
401
 
402
  if model_args.config_name:
 
256
  return inputs, labels
257
 
258
 
 
259
  @dataclass
260
  class SamplingArguments:
261
  """
262
  Arguments pertaining to how to perform sampling of the dataset.
263
  """
264
 
265
+ perplexity_model: Optional[str] = field(
266
+ default="./es.arpa.bin", metadata={"help": "Path to KenLM model to use to get perplexity values."}
267
  )
268
+ sampling_method: Optional[str] = field(
269
+ default=None, metadata={"help": "Sample using a 'step' or 'gaussian' perplexity function per document, or 'random'."}
270
  )
271
+ sampling_factor: Optional[float] = field(
272
+ default=None, metadata={"help": "Sampling factor. Integers for step function, decimals for gaussian."}
273
  )
274
  boundaries: Optional[str] = field(
275
  default="536394.99320948,662247.50212365,919250.87225178", metadata={"help": "Quartile boundaries"}
276
  )
277
 
278
  def __post_init__(self):
279
+ self.boundaries = [float(q.strip()) for q in self.boundaries.split(",")]
280
 
281
 
282
  def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
 
396
  sampling_factor=sampling_args.sampling_factor,
397
  boundaries=sampling_args.boundaries,
398
  perplexity_model=sampling_args.perplexity_model,
399
+ seed=training_args.seed,
400
  )
401
 
402
  if model_args.config_name: