Adding reading streaming files from local disk
Browse files- mc4/mc4.py +36 -14
- run_mlm_flax_stream.py +8 -2
mc4/mc4.py
CHANGED
@@ -283,6 +283,7 @@ class Mc4(datasets.GeneratorBasedBuilder):
|
|
283 |
BUILDER_CONFIG_CLASS = Mc4Config
|
284 |
|
285 |
def __init__(self, *args, writer_batch_size=None, **kwargs):
|
|
|
286 |
self.sampling_method = kwargs.pop("sampling_method", None)
|
287 |
if self.sampling_method:
|
288 |
seed = kwargs.pop("seed", None)
|
@@ -290,19 +291,20 @@ class Mc4(datasets.GeneratorBasedBuilder):
|
|
290 |
self.rng = default_rng(seed)
|
291 |
else:
|
292 |
self.rng = default_rng()
|
293 |
-
self.
|
294 |
-
self.sampling_factor = kwargs.pop("sampling_factor", None)
|
295 |
-
self.boundaries = kwargs.pop("boundaries", None)
|
296 |
-
# Loading 5-gram model
|
297 |
-
# http://dl.fbaipublicfiles.com/cc_net/lm/es.arpa.bin
|
298 |
-
logger.info("loading model = %s", self.perplexity_model)
|
299 |
-
self.pp_model = kenlm.Model(self.perplexity_model)
|
300 |
-
if self.sampling_method == "gaussian":
|
301 |
-
self.should_keep_doc = self._should_keep_doc_gaussian
|
302 |
-
elif self.sampling_method == "random":
|
303 |
self.should_keep_doc = self._should_keep_doc_random
|
304 |
else:
|
305 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
306 |
|
307 |
super().__init__(*args, writer_batch_size=writer_batch_size, **kwargs)
|
308 |
|
@@ -341,7 +343,9 @@ class Mc4(datasets.GeneratorBasedBuilder):
|
|
341 |
return self.rng.uniform() < weighted_perplexity
|
342 |
|
343 |
def _should_keep_doc_random(self, doc, factor=None, boundaries=None):
|
344 |
-
|
|
|
|
|
345 |
|
346 |
def _info(self):
|
347 |
return datasets.DatasetInfo(
|
@@ -371,8 +375,18 @@ class Mc4(datasets.GeneratorBasedBuilder):
|
|
371 |
for lang in self.config.languages
|
372 |
for index in range(_N_SHARDS_PER_SPLIT[lang][split])
|
373 |
]
|
374 |
-
|
375 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
376 |
return [
|
377 |
datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepaths": train_downloaded_files}),
|
378 |
datasets.SplitGenerator(
|
@@ -385,6 +399,14 @@ class Mc4(datasets.GeneratorBasedBuilder):
|
|
385 |
id_ = 0
|
386 |
for filepath in filepaths:
|
387 |
logger.info("generating examples from = %s", filepath)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
388 |
with gzip.open(open(filepath, "rb"), "rt", encoding="utf-8") as f:
|
389 |
if self.sampling_method:
|
390 |
logger.info("sampling method = %s", self.sampling_method)
|
|
|
283 |
BUILDER_CONFIG_CLASS = Mc4Config
|
284 |
|
285 |
def __init__(self, *args, writer_batch_size=None, **kwargs):
|
286 |
+
self.filepaths = kwargs.pop(filepaths, {})
|
287 |
self.sampling_method = kwargs.pop("sampling_method", None)
|
288 |
if self.sampling_method:
|
289 |
seed = kwargs.pop("seed", None)
|
|
|
291 |
self.rng = default_rng(seed)
|
292 |
else:
|
293 |
self.rng = default_rng()
|
294 |
+
if self.sampling_method == "random":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
295 |
self.should_keep_doc = self._should_keep_doc_random
|
296 |
else:
|
297 |
+
self.perplexity_model = kwargs.pop("perplexity_model", None)
|
298 |
+
self.sampling_factor = kwargs.pop("sampling_factor", None)
|
299 |
+
self.boundaries = kwargs.pop("boundaries", None)
|
300 |
+
# Loading 5-gram model
|
301 |
+
# http://dl.fbaipublicfiles.com/cc_net/lm/es.arpa.bin
|
302 |
+
logger.info("loading model = %s", self.perplexity_model)
|
303 |
+
self.pp_model = kenlm.Model(self.perplexity_model)
|
304 |
+
if self.sampling_method == "gaussian":
|
305 |
+
self.should_keep_doc = self._should_keep_doc_gaussian
|
306 |
+
else:
|
307 |
+
self.should_keep_doc = self._should_keep_doc_step
|
308 |
|
309 |
super().__init__(*args, writer_batch_size=writer_batch_size, **kwargs)
|
310 |
|
|
|
343 |
return self.rng.uniform() < weighted_perplexity
|
344 |
|
345 |
def _should_keep_doc_random(self, doc, factor=None, boundaries=None):
|
346 |
+
if factor is None:
|
347 |
+
factor = 0.5
|
348 |
+
return self.rng.uniform() <= factor
|
349 |
|
350 |
def _info(self):
|
351 |
return datasets.DatasetInfo(
|
|
|
375 |
for lang in self.config.languages
|
376 |
for index in range(_N_SHARDS_PER_SPLIT[lang][split])
|
377 |
]
|
378 |
+
if "train" in self.filepaths:
|
379 |
+
train_downloaded_files = self.filepaths["train"]
|
380 |
+
if not isinstance(train_downloaded_files, (tuple, list)):
|
381 |
+
train_downloaded_files = [train_downloaded_files]
|
382 |
+
else:
|
383 |
+
train_downloaded_files = dl_manager.download(data_urls["train"])
|
384 |
+
if "validation" in self.filepaths:
|
385 |
+
validation_downloaded_files = self.filepaths["validation"]
|
386 |
+
if not isinstance(validation_downloaded_files, (tuple, list)):
|
387 |
+
validation_downloaded_files = [validation_downloaded_files]
|
388 |
+
else:
|
389 |
+
validation_downloaded_files = dl_manager.download(data_urls["validation"])
|
390 |
return [
|
391 |
datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepaths": train_downloaded_files}),
|
392 |
datasets.SplitGenerator(
|
|
|
399 |
id_ = 0
|
400 |
for filepath in filepaths:
|
401 |
logger.info("generating examples from = %s", filepath)
|
402 |
+
if filepath.endswith("json") or filepath.endswith("jsonl"):
|
403 |
+
with open(filepath, "r", encoding="utf-8") as f:
|
404 |
+
for line in f:
|
405 |
+
if line:
|
406 |
+
example = json.loads(line)
|
407 |
+
yield id_, example
|
408 |
+
id_ += 1
|
409 |
+
else:
|
410 |
with gzip.open(open(filepath, "rb"), "rt", encoding="utf-8") as f:
|
411 |
if self.sampling_method:
|
412 |
logger.info("sampling method = %s", self.sampling_method)
|
run_mlm_flax_stream.py
CHANGED
@@ -178,10 +178,10 @@ class DataTrainingArguments:
|
|
178 |
else:
|
179 |
if self.train_file is not None:
|
180 |
extension = self.train_file.split(".")[-1]
|
181 |
-
assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
|
182 |
if self.validation_file is not None:
|
183 |
extension = self.validation_file.split(".")[-1]
|
184 |
-
assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
|
185 |
|
186 |
|
187 |
@flax.struct.dataclass
|
@@ -386,6 +386,11 @@ if __name__ == "__main__":
|
|
386 |
# 'text' is found. You can easily tweak this behavior (see below).
|
387 |
if data_args.dataset_name is not None:
|
388 |
# Downloading and loading a dataset from the hub.
|
|
|
|
|
|
|
|
|
|
|
389 |
dataset = load_dataset(
|
390 |
data_args.dataset_name,
|
391 |
data_args.dataset_config_name,
|
@@ -397,6 +402,7 @@ if __name__ == "__main__":
|
|
397 |
boundaries=sampling_args.boundaries,
|
398 |
perplexity_model=sampling_args.perplexity_model,
|
399 |
seed=training_args.seed,
|
|
|
400 |
)
|
401 |
|
402 |
if model_args.config_name:
|
|
|
178 |
else:
|
179 |
if self.train_file is not None:
|
180 |
extension = self.train_file.split(".")[-1]
|
181 |
+
assert extension in ["csv", "json", "txt", "gz"], "`train_file` should be a csv, a json or a txt file."
|
182 |
if self.validation_file is not None:
|
183 |
extension = self.validation_file.split(".")[-1]
|
184 |
+
assert extension in ["csv", "json", "txt", "gz"], "`validation_file` should be a csv, a json or a txt file."
|
185 |
|
186 |
|
187 |
@flax.struct.dataclass
|
|
|
386 |
# 'text' is found. You can easily tweak this behavior (see below).
|
387 |
if data_args.dataset_name is not None:
|
388 |
# Downloading and loading a dataset from the hub.
|
389 |
+
filepaths = {}
|
390 |
+
if data_args.train_file:
|
391 |
+
filepaths["train"] = data_args.train_file
|
392 |
+
if data_args.validation_file:
|
393 |
+
filepaths["validation"] = data_args.validation_file
|
394 |
dataset = load_dataset(
|
395 |
data_args.dataset_name,
|
396 |
data_args.dataset_config_name,
|
|
|
402 |
boundaries=sampling_args.boundaries,
|
403 |
perplexity_model=sampling_args.perplexity_model,
|
404 |
seed=training_args.seed,
|
405 |
+
filepaths=filepaths,
|
406 |
)
|
407 |
|
408 |
if model_args.config_name:
|