Fixes and defaults

Changed files:
- mc4/mc4.py (+18 -13)
- run_mlm_flax_stream.py (+8 -8)
mc4/mc4.py
CHANGED
@@ -8,7 +8,6 @@ import datasets
 import kenlm
 import numpy as np
 from numpy.random import default_rng
-rng = default_rng()
 
 
 logger = datasets.logging.get_logger(__name__)
@@ -284,11 +283,16 @@ class Mc4(datasets.GeneratorBasedBuilder):
     BUILDER_CONFIG_CLASS = Mc4Config
 
     def __init__(self, *args, writer_batch_size=None, **kwargs):
-        self.sampling_method = kwargs.pop("sampling_method")
+        self.sampling_method = kwargs.pop("sampling_method", None)
         if self.sampling_method:
-            …
-            …
-            …
+            seed = kwargs.pop("seed", None)
+            if seed is not None:
+                self.rng = default_rng(seed)
+            else:
+                self.rng = default_rng()
+            self.perplexity_model = kwargs.pop("perplexity_model", None)
+            self.sampling_factor = kwargs.pop("sampling_factor", None)
+            self.boundaries = kwargs.pop("boundaries", None)
             # Loading 5-gram model
             # http://dl.fbaipublicfiles.com/cc_net/lm/es.arpa.bin
             logger.info("loading model = %s", self.perplexity_model)
@@ -298,7 +302,7 @@ class Mc4(datasets.GeneratorBasedBuilder):
             elif self.sampling_method == "random":
                 self.should_keep_doc = self._should_keep_doc_random
             else:
-                self.should_keep_doc = self.…
+                self.should_keep_doc = self._should_keep_doc_step
 
         super().__init__(*args, writer_batch_size=writer_batch_size, **kwargs)
 
@@ -311,7 +315,7 @@ class Mc4(datasets.GeneratorBasedBuilder):
             doc_length += length
         return 10.0 ** (-doc_log_score / doc_length)
 
-    def _should_keep_doc_step(self, doc, factor=1, boundaries=None):
+    def _should_keep_doc_step(self, doc, factor=1.5e5, boundaries=None):
         perplexity = self.get_perplexity(doc)
         if boundaries is None:
             boundaries = [536394.99320948, 662247.50212365, 919250.87225178]
@@ -322,21 +326,22 @@ class Mc4(datasets.GeneratorBasedBuilder):
         elif boundaries[1] < perplexity < boundaries[2]:
             quartile_range = boundaries[2] - boundaries[1]
         elif perplexity >= boundaries[2]:
-            quartile_range = …
+            quartile_range = 10 * boundaries[2]
         probability = factor / quartile_range
-        return rng.uniform() < probability
+        return self.rng.uniform() < probability
 
-    def _should_keep_doc_gaussian(self, doc, factor=0.…
+    def _should_keep_doc_gaussian(self, doc, factor=0.78, boundaries=None):
         perplexity = self.get_perplexity(doc)
         if boundaries is not None:
             m = boundaries[1]
         else:
             m = 662247.50212365
-        …
-        …
+        exponential = np.exp(-9/2 * ((perplexity - m) / m) ** 2)
+        weighted_perplexity = factor * exponential
+        return self.rng.uniform() < weighted_perplexity
 
     def _should_keep_doc_random(self, doc, factor=None, boundaries=None):
-        return rng.uniform() <= 0.5
+        return self.rng.uniform() <= 0.5
 
     def _info(self):
         return datasets.DatasetInfo(
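For context, the keyword arguments popped in __init__ above (sampling_method, sampling_factor, boundaries, perplexity_model, seed) are meant to arrive through datasets.load_dataset when the mc4 script is loaded from a local path. A minimal sketch of such a call follows; the script path, configuration name, split, and parameter values are illustrative assumptions, not part of this commit:

    from datasets import load_dataset

    # Illustrative call only; paths and values are assumptions, not part of the commit.
    mc4 = load_dataset(
        "./mc4",
        "es",
        split="train",
        streaming=True,
        sampling_method="gaussian",   # "step", "gaussian" or "random"
        sampling_factor=0.78,         # decimals for gaussian, larger values for step
        boundaries=[536394.99320948, 662247.50212365, 919250.87225178],
        perplexity_model="./es.arpa.bin",
        seed=42,
    )

With the gaussian method, a document whose perplexity equals the midpoint m (662247.50212365 by default) is kept with probability equal to the factor (0.78 here), and the weight decays as exp(-9/2 * ((perplexity - m) / m) ** 2); a document at twice the midpoint is kept with probability of roughly 0.78 * exp(-4.5), i.e. about 0.9%.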
run_mlm_flax_stream.py
CHANGED
@@ -256,28 +256,27 @@ class FlaxDataCollatorForLanguageModeling:
         return inputs, labels
 
 
-
 @dataclass
 class SamplingArguments:
     """
     Arguments pertaining to how to perform sampling of the dataset.
     """
 
-    perplexity_model: Optional[str] …
-        default="es.arpa.bin", metadata={"help": "…
+    perplexity_model: Optional[str] = field(
+        default="./es.arpa.bin", metadata={"help": "Path to KenLM model to use to get perplexity values."}
     )
-    sampling_method: Optional[str] …
-        default=None, metadata={"help": "Sample using a 'step' or 'gaussian' perplexity function per document."}
+    sampling_method: Optional[str] = field(
+        default=None, metadata={"help": "Sample using a 'step' or 'gaussian' perplexity function per document, or 'random'."}
     )
-    sampling_factor: Optional[…
-        default=…
+    sampling_factor: Optional[float] = field(
+        default=None, metadata={"help": "Sampling factor. Integers for step function, decimals for gaussian."}
     )
     boundaries: Optional[str] = field(
         default="536394.99320948,662247.50212365,919250.87225178", metadata={"help": "Quartile boundaries"}
     )
 
     def __post_init__(self):
-        self.boundaries = [float(q) for q in self.boundaries.split(",")]
+        self.boundaries = [float(q.strip()) for q in self.boundaries.split(",")]
 
 
 def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
@@ -397,6 +396,7 @@ if __name__ == "__main__":
         sampling_factor=sampling_args.sampling_factor,
         boundaries=sampling_args.boundaries,
         perplexity_model=sampling_args.perplexity_model,
+        seed=training_args.seed,
     )
 
     if model_args.config_name:
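For completeness, the new SamplingArguments dataclass follows the HfArgumentParser pattern used elsewhere in the script, so the sampling options become command-line flags. A minimal sketch, assuming SamplingArguments is defined as above and added to the parser (the script's ModelArguments and DataTrainingArguments groups are omitted here):

    from transformers import HfArgumentParser, TrainingArguments

    # Parse only the sampling and training argument groups; the real script
    # also includes ModelArguments and DataTrainingArguments in this tuple.
    parser = HfArgumentParser((SamplingArguments, TrainingArguments))
    sampling_args, training_args = parser.parse_args_into_dataclasses()

    print(sampling_args.sampling_method)  # e.g. "gaussian"
    print(sampling_args.boundaries)       # converted to floats by __post_init__

On the command line this would accept flags such as --sampling_method gaussian --sampling_factor 0.78 --perplexity_model ./es.arpa.bin --boundaries "536394.99320948,662247.50212365,919250.87225178"; the parsed values are then forwarded to load_dataset, together with seed=training_args.seed, as shown in the second hunk.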