versae commited on
Commit
60b6f6b
1 Parent(s): 79555ba

Adding random sampling

Browse files
Files changed (1) hide show
  1. mc4/mc4.py +5 -0
mc4/mc4.py CHANGED
@@ -293,6 +293,8 @@ class Mc4(datasets.GeneratorBasedBuilder):
293
  self.pp_model = kenlm.Model(self.perplexity_model)
294
  if self.sampling_method == "gaussian":
295
  self.should_keep_doc = self._should_keep_doc_gaussian
 
 
296
  else:
297
  self.should_keep_doc = self._should_keep_doc_gaussian
298
 
@@ -332,6 +334,9 @@ class Mc4(datasets.GeneratorBasedBuilder):
332
  weighted_perplexity = factor * np.exp(-9/2*((perplexity-m)/m)**2)
333
  return np.random.uniform() < weighted_perplexity
334
 
 
 
 
335
  def _info(self):
336
  return datasets.DatasetInfo(
337
  description=_DESCRIPTION,
 
293
  self.pp_model = kenlm.Model(self.perplexity_model)
294
  if self.sampling_method == "gaussian":
295
  self.should_keep_doc = self._should_keep_doc_gaussian
296
+ elif self.sampling_method == "random":
297
+ self.should_keep_doc = self._should_keep_doc_random
298
  else:
299
  self.should_keep_doc = self._should_keep_doc_gaussian
300
 
 
334
  weighted_perplexity = factor * np.exp(-9/2*((perplexity-m)/m)**2)
335
  return np.random.uniform() < weighted_perplexity
336
 
337
+ def _should_keep_doc_random(self, doc, factor=None, boundaries=None):
338
+ return np.random() <= 0.5
339
+
340
  def _info(self):
341
  return datasets.DatasetInfo(
342
  description=_DESCRIPTION,