versae's picture
Model at 210k steps, mlm acc 0.6509
d6c5011
raw
history blame
14.1 kB
"""mC4 dataset based on Common Crawl."""
import gzip
import json
import datasets
import kenlm
import numpy as np
from numpy.random import default_rng
logger = datasets.logging.get_logger(__name__)
_DESCRIPTION = """\
A colossal, cleaned version of Common Crawl's web crawl corpus.
Based on Common Crawl dataset: "https://commoncrawl.org".
This is the processed version of Google's mC4 dataset by AllenAI.
"""
_CITATION = """
@article{2019t5,
author = {Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu},
title = {Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer},
journal = {arXiv e-prints},
year = {2019},
archivePrefix = {arXiv},
eprint = {1910.10683},
}
"""
_URL = "https://github.com/allenai/allennlp/discussions/5056"
_DATA_URL = "https://huggingface.co/datasets/allenai/c4/resolve/1ddc917116b730e1859edef32896ec5c16be51d0/multilingual/c4-{language}{split_suffix}.tfrecord-{index:05d}-of-{n_shards:05d}.json.gz"
_LANGUAGES = [
"af",
"am",
"ar",
"az",
"be",
"bg",
"bg-Latn",
"bn",
"ca",
"ceb",
"co",
"cs",
"cy",
"da",
"de",
"el",
"el-Latn",
"en",
"eo",
"es",
"et",
"eu",
"fa",
"fi",
"fil",
"fr",
"fy",
"ga",
"gd",
"gl",
"gu",
"ha",
"haw",
"hi",
"hi-Latn",
"hmn",
"ht",
"hu",
"hy",
"id",
"ig",
"is",
"it",
"iw",
"ja",
"ja-Latn",
"jv",
"ka",
"kk",
"km",
"kn",
"ko",
"ku",
"ky",
"la",
"lb",
"lo",
"lt",
"lv",
"mg",
"mi",
"mk",
"ml",
"mn",
"mr",
"ms",
"mt",
"my",
"ne",
"nl",
"no",
"ny",
"pa",
"pl",
"ps",
"pt",
"ro",
"ru",
"ru-Latn",
"sd",
"si",
"sk",
"sl",
"sm",
"sn",
"so",
"sq",
"sr",
"st",
"su",
"sv",
"sw",
"ta",
"te",
"tg",
"th",
"tr",
"uk",
"und",
"ur",
"uz",
"vi",
"xh",
"yi",
"yo",
"zh",
"zh-Latn",
"zu",
]
_N_SHARDS_PER_SPLIT = {
"af": {"train": 64, "validation": 1},
"am": {"train": 16, "validation": 1},
"ar": {"train": 1024, "validation": 4},
"az": {"train": 256, "validation": 1},
"be": {"train": 128, "validation": 1},
"bg": {"train": 1024, "validation": 1},
"bg-Latn": {"train": 4, "validation": 1},
"bn": {"train": 512, "validation": 1},
"ca": {"train": 512, "validation": 1},
"ceb": {"train": 8, "validation": 1},
"co": {"train": 8, "validation": 1},
"cs": {"train": 1024, "validation": 2},
"cy": {"train": 256, "validation": 1},
"da": {"train": 1024, "validation": 1},
"de": {"train": 2048, "validation": 16},
"el": {"train": 1024, "validation": 2},
"el-Latn": {"train": 16, "validation": 1},
"en": {"train": 11264, "validation": 128},
"eo": {"train": 32, "validation": 1},
"es": {"train": 2048, "validation": 16},
"et": {"train": 256, "validation": 1},
"eu": {"train": 64, "validation": 1},
"fa": {"train": 1024, "validation": 2},
"fi": {"train": 1024, "validation": 1},
"fil": {"train": 64, "validation": 1},
"fr": {"train": 2048, "validation": 16},
"fy": {"train": 16, "validation": 1},
"ga": {"train": 16, "validation": 1},
"gd": {"train": 16, "validation": 1},
"gl": {"train": 128, "validation": 1},
"gu": {"train": 64, "validation": 1},
"ha": {"train": 8, "validation": 1},
"haw": {"train": 2, "validation": 1},
"hi": {"train": 1024, "validation": 2},
"hi-Latn": {"train": 16, "validation": 1},
"hmn": {"train": 8, "validation": 1},
"ht": {"train": 8, "validation": 1},
"hu": {"train": 1024, "validation": 2},
"hy": {"train": 128, "validation": 1},
"id": {"train": 1024, "validation": 4},
"ig": {"train": 4, "validation": 1},
"is": {"train": 128, "validation": 1},
"it": {"train": 1024, "validation": 8},
"iw": {"train": 1024, "validation": 1},
"ja": {"train": 1024, "validation": 8},
"ja-Latn": {"train": 8, "validation": 1},
"jv": {"train": 8, "validation": 1},
"ka": {"train": 256, "validation": 1},
"kk": {"train": 256, "validation": 1},
"km": {"train": 64, "validation": 1},
"kn": {"train": 64, "validation": 1},
"ko": {"train": 1024, "validation": 1},
"ku": {"train": 16, "validation": 1},
"ky": {"train": 64, "validation": 1},
"la": {"train": 64, "validation": 1},
"lb": {"train": 32, "validation": 1},
"lo": {"train": 8, "validation": 1},
"lt": {"train": 512, "validation": 1},
"lv": {"train": 256, "validation": 1},
"mg": {"train": 8, "validation": 1},
"mi": {"train": 4, "validation": 1},
"mk": {"train": 128, "validation": 1},
"ml": {"train": 128, "validation": 1},
"mn": {"train": 128, "validation": 1},
"mr": {"train": 1024, "validation": 1},
"ms": {"train": 512, "validation": 1},
"mt": {"train": 128, "validation": 1},
"my": {"train": 64, "validation": 1},
"ne": {"train": 256, "validation": 1},
"nl": {"train": 1024, "validation": 4},
"no": {"train": 1024, "validation": 1},
"ny": {"train": 4, "validation": 1},
"pa": {"train": 32, "validation": 1},
"pl": {"train": 1024, "validation": 4},
"ps": {"train": 16, "validation": 1},
"pt": {"train": 1024, "validation": 4},
"ro": {"train": 1024, "validation": 2},
"ru": {"train": 4096, "validation": 32},
"ru-Latn": {"train": 32, "validation": 1},
"sd": {"train": 64, "validation": 1},
"si": {"train": 64, "validation": 1},
"sk": {"train": 512, "validation": 1},
"sl": {"train": 256, "validation": 1},
"sm": {"train": 4, "validation": 1},
"sn": {"train": 8, "validation": 1},
"so": {"train": 64, "validation": 1},
"sq": {"train": 128, "validation": 1},
"sr": {"train": 256, "validation": 1},
"st": {"train": 2, "validation": 1},
"su": {"train": 4, "validation": 1},
"sv": {"train": 1024, "validation": 2},
"sw": {"train": 32, "validation": 1},
"ta": {"train": 256, "validation": 1},
"te": {"train": 128, "validation": 1},
"tg": {"train": 64, "validation": 1},
"th": {"train": 1024, "validation": 1},
"tr": {"train": 1024, "validation": 4},
"uk": {"train": 1024, "validation": 2},
"und": {"train": 3072, "validation": 32},
"ur": {"train": 128, "validation": 1},
"uz": {"train": 32, "validation": 1},
"vi": {"train": 1024, "validation": 4},
"xh": {"train": 2, "validation": 1},
"yi": {"train": 16, "validation": 1},
"yo": {"train": 2, "validation": 1},
"zh": {"train": 1024, "validation": 2},
"zh-Latn": {"train": 8, "validation": 1},
"zu": {"train": 8, "validation": 1},
}
class Mc4Config(datasets.BuilderConfig):
"""BuilderConfig for mC4."""
def __init__(self, *args, languages, **kwargs):
"""BuilderConfig for mC4.
Args:
languages (:obj:`List[str]`): list of languages to load
**kwargs: keyword arguments forwarded to super.
"""
super().__init__(
*args,
name="+".join(languages),
**kwargs,
)
self.languages = languages
class Mc4(datasets.GeneratorBasedBuilder):
"""mC4, a colossal, cleaned version of Common Crawl's web crawl corpus."""
BUILDER_CONFIGS = [Mc4Config(languages=[lang]) for lang in _LANGUAGES]
BUILDER_CONFIG_CLASS = Mc4Config
def __init__(self, *args, writer_batch_size=None, **kwargs):
self.data_files = kwargs.pop("data_files", {})
self.sampling_method = kwargs.pop("sampling_method", None)
self.perplexity_model = kwargs.pop("perplexity_model", None)
self.sampling_factor = kwargs.pop("sampling_factor", None)
self.boundaries = kwargs.pop("boundaries", None)
self.seed = kwargs.pop("seed", None)
if self.sampling_method:
if self.seed is not None:
self.rng = default_rng(self.seed)
else:
self.rng = default_rng()
if self.sampling_method == "random":
self.should_keep_doc = self._should_keep_doc_random
else:
# Loading 5-gram model
# http://dl.fbaipublicfiles.com/cc_net/lm/es.arpa.bin
logger.info("loading model = %s", self.perplexity_model)
self.pp_model = kenlm.Model(self.perplexity_model)
if self.sampling_method == "gaussian":
self.should_keep_doc = self._should_keep_doc_gaussian
else:
self.should_keep_doc = self._should_keep_doc_step
super().__init__(*args, writer_batch_size=writer_batch_size, **kwargs)
def get_perplexity(self, doc):
doc_log_score, doc_length = 0, 0
for line in doc.split("\n"):
log_score = self.pp_model.score(line)
length = len(line.split()) + 1
doc_log_score += log_score
doc_length += length
return 10.0 ** (-doc_log_score / doc_length)
def _should_keep_doc_step(self, doc, factor=1.5e5, boundaries=None):
perplexity = self.get_perplexity(doc)
if boundaries is None:
boundaries = [536394.99320948, 662247.50212365, 919250.87225178]
if perplexity <= boundaries[0]:
quartile_range = boundaries[0]
elif boundaries[0] < perplexity < boundaries[1]:
quartile_range = boundaries[1] - boundaries[0]
elif boundaries[1] < perplexity < boundaries[2]:
quartile_range = boundaries[2] - boundaries[1]
elif perplexity >= boundaries[2]:
quartile_range = 10 * boundaries[2]
probability = factor / quartile_range
return self.rng.uniform() < probability
def _should_keep_doc_gaussian(self, doc, factor=0.78, boundaries=None):
perplexity = self.get_perplexity(doc)
if boundaries is not None:
m = boundaries[1]
else:
m = 662247.50212365
exponential = np.exp(-9/2 * ((perplexity - m) / m) ** 2)
weighted_perplexity = factor * exponential
return self.rng.uniform() < weighted_perplexity
def _should_keep_doc_random(self, doc, factor=None, boundaries=None):
if factor is None:
factor = 0.5
return self.rng.uniform() <= factor
def _info(self):
return datasets.DatasetInfo(
description=_DESCRIPTION,
features=datasets.Features(
{
"text": datasets.Value("string"),
"timestamp": datasets.Value("string"),
"url": datasets.Value("string"),
}
),
supervised_keys=None,
homepage=_URL,
citation=_CITATION,
)
def _split_generators(self, dl_manager):
data_urls = {}
for split in ["train", "validation"]:
data_urls[split] = [
_DATA_URL.format(
language=self.config.name,
split_suffix="-validation" if split == "validation" else "",
index=index,
n_shards=_N_SHARDS_PER_SPLIT[lang][split],
)
for lang in self.config.languages
for index in range(_N_SHARDS_PER_SPLIT[lang][split])
]
if "train" in self.data_files:
train_downloaded_files = self.data_files["train"]
if not isinstance(train_downloaded_files, (tuple, list)):
train_downloaded_files = [train_downloaded_files]
else:
train_downloaded_files = dl_manager.download(data_urls["train"])
if "validation" in self.data_files:
validation_downloaded_files = self.data_files["validation"]
if not isinstance(validation_downloaded_files, (tuple, list)):
validation_downloaded_files = [validation_downloaded_files]
else:
validation_downloaded_files = dl_manager.download(data_urls["validation"])
return [
datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepaths": train_downloaded_files}),
datasets.SplitGenerator(
name=datasets.Split.VALIDATION, gen_kwargs={"filepaths": validation_downloaded_files}
),
]
def _generate_examples(self, filepaths):
"""This function returns the examples in the raw (text) form by iterating on all the files."""
id_ = 0
for filepath in filepaths:
logger.info("generating examples from = %s", filepath)
if filepath.endswith("jsonl"):
with open(filepath, "r", encoding="utf-8") as f:
for line in f:
if line:
example = json.loads(line)
yield id_, example
id_ += 1
else:
with gzip.open(open(filepath, "rb"), "rt", encoding="utf-8") as f:
if self.sampling_method:
logger.info("sampling method = %s", self.sampling_method)
for line in f:
if line:
example = json.loads(line)
if self.should_keep_doc(
example["text"],
factor=self.sampling_factor,
boundaries=self.boundaries):
yield id_, example
id_ += 1
else:
for line in f:
if line:
example = json.loads(line)
yield id_, example
id_ += 1