# Uploaded by xzyao ("Upload folder using huggingface_hub"), commit 9fb22af (verified).
import io
import json
import os
from glob import glob
import datasets
import zstandard as zstd
from datasets import GeneratorBasedBuilder
from datasets.utils import Version
from huggingface_hub import snapshot_download
# Hub dataset repository this loading script reads from. The dataset name at the
# end of REPO_NAME must match this script's file name (e.g. a ".../uspto" repo
# needs a script named uspto.py).
REPO_NAME: str = "Multi-Domain-Expert-Layers/arxiv"
class PileDomainDataset(GeneratorBasedBuilder):
    """Builder for one Pile domain stored as zstd-compressed JSONL shards.

    The shards are fetched from the Hub dataset repository ``REPO_NAME`` and are
    read from its ``data/train``, ``data/val`` and ``data/test`` directories.
    The ``*_pile`` / ``*_domain`` splits are selected by whether a shard's file
    name contains ``"pile"`` or ``"domain"``.
    """

    VERSION = Version("1.0.0")

    def _info(self):
        """Describe the dataset: a single string column named ``text``."""
        return datasets.DatasetInfo(
            description="Pile Domain Dataset",
            features=datasets.Features({"text": datasets.Value("string")}),
            supervised_keys=None,
        )

    def _split_generators(self, dl_manager):
        """Download the repository snapshot and declare one generator per split.

        ``dl_manager`` is not used: the whole repository is fetched with
        ``huggingface_hub.snapshot_download`` instead.
        """
        dl_path = snapshot_download(repo_id=REPO_NAME, repo_type="dataset")
        # (split name, sub-directory of data/, shard file-name filter or None for all).
        # Note that plain "validation" is the union of every shard in data/val.
        split_specs = [
            (datasets.Split.TRAIN, "train", None),
            ("validation", "val", None),
            ("validation_pile", "val", "pile"),
            ("validation_domain", "val", "domain"),
            ("test_pile", "test", "pile"),
            ("test_domain", "test", "domain"),
        ]
        return [
            datasets.SplitGenerator(
                name=name,
                gen_kwargs={
                    "data_dir": os.path.join(dl_path, "data", subdir),
                    "split": shard_filter,
                },
            )
            for name, subdir, shard_filter in split_specs
        ]

    def _generate_examples(self, data_dir, split):
        """Yield ``(key, record)`` for every JSON line in the ``*.jsonl.zst`` shards.

        Args:
            data_dir: Directory containing the compressed JSONL shards.
            split: If not None, only shards whose *file name* contains this
                substring are read (e.g. ``"pile"`` or ``"domain"``).

        Yields:
            A running integer key starting at 0 (unique across all shards read
            for this split) and the decoded JSON object of each non-blank line.
        """
        dctx = zstd.ZstdDecompressor()
        # glob() returns paths in arbitrary, filesystem-dependent order; sort so
        # example keys and ordering are reproducible across machines.
        file_paths = sorted(glob(os.path.join(data_dir, "*.jsonl.zst")))
        if split is not None:
            # Match on the file name only: the directory part (e.g. the user's
            # cache or home path) could itself contain "pile" or "domain".
            file_paths = [p for p in file_paths if split in os.path.basename(p)]

        idx = 0
        for file_path in file_paths:
            # NOTE(review): stream_reader stops at the end of the first zstd frame
            # unless read_across_frames=True; confirm the shards are single-frame.
            with open(file_path, "rb") as fh, dctx.stream_reader(fh) as reader:
                # Iterate lazily rather than readlines(), which would hold the
                # entire decompressed shard in memory at once.
                for line in io.BufferedReader(reader):
                    if not line.strip():
                        # Blank lines (e.g. a stray trailing newline) are not JSON.
                        continue
                    yield idx, json.loads(line)
                    idx += 1