import io
import json
import os
from glob import glob

import datasets
import zstandard as zstd
from datasets import GeneratorBasedBuilder
from datasets.utils import Version
from huggingface_hub import snapshot_download

# This script's file name must match the last component of REPO_NAME
# (e.g. uspto.py for ".../uspto").
REPO_NAME = "Multi-Domain-Expert-Layers/uspto"


class PileDomainDataset(GeneratorBasedBuilder):
    """Builder reading zstd-compressed JSON-lines shards from the REPO_NAME dataset repo."""

    VERSION = Version("1.0.0")

    def _info(self):
        """Return dataset metadata; each example exposes a single string ``text`` field."""
        return datasets.DatasetInfo(
            description="Pile Domain Dataset",
            features=datasets.Features(
                {
                    "text": datasets.Value("string"),
                }
            ),
            supervised_keys=None,
        )

    def _split_generators(self, dl_manager):
        """Download the dataset repo and describe every split it provides.

        ``dl_manager`` is unused: the whole repository snapshot is fetched with
        ``huggingface_hub.snapshot_download`` instead.
        """
        dl_path = snapshot_download(repo_id=REPO_NAME, repo_type="dataset")

        # (split name, directory relative to the snapshot, shard-name filter).
        # A filter of None means "read every shard in the directory".
        layout = [
            (datasets.Split.TRAIN, "data/train", None),
            ("validation", "data/val", None),
            ("validation_pile", "data/val", "pile"),
            ("validation_domain", "data/val", "domain"),
            ("test_pile", "data/test", "pile"),
            ("test_domain", "data/test", "domain"),
        ]
        return [
            datasets.SplitGenerator(
                name=name,
                gen_kwargs={
                    "data_dir": os.path.join(dl_path, subdir),
                    "split": shard_filter,
                },
            )
            for name, subdir, shard_filter in layout
        ]

    def _generate_examples(self, data_dir, split):
        """Yield ``(key, record)`` pairs from the ``*.jsonl.zst`` shards in *data_dir*.

        Args:
            data_dir: Directory containing zstd-compressed JSON-lines shards.
            split: When not ``None``, only shards whose *file name* contains
                this substring (e.g. ``"pile"`` or ``"domain"``) are read.

        Yields:
            ``(idx, record)`` where ``idx`` is a running integer starting at 0
            across all selected shards and ``record`` is the object decoded
            from one non-blank line.
        """
        # Sort so that shard order, and therefore example keys, are reproducible.
        file_paths = sorted(glob(os.path.join(data_dir, "*.jsonl.zst")))
        if split is not None:
            # Test only the file name: the directory part of the path (cache
            # root, home directory, ...) may contain the split word by accident.
            file_paths = [p for p in file_paths if split in os.path.basename(p)]

        dctx = zstd.ZstdDecompressor()
        idx = 0
        for path in file_paths:
            with open(path, "rb") as fh, io.BufferedReader(dctx.stream_reader(fh)) as lines:
                # Stream line by line rather than readlines() so a large shard
                # is never fully decompressed into memory at once.
                for line in lines:
                    if not line.strip():
                        continue  # tolerate blank / trailing empty lines
                    yield idx, json.loads(line)
                    idx += 1