xzyao's picture
Upload folder using huggingface_hub
a91ef18 verified
raw
history blame
No virus
2.8 kB
import io
import json
import os
from glob import glob
import datasets
import zstandard as zstd
from datasets import GeneratorBasedBuilder
from datasets.utils import Version
from huggingface_hub import snapshot_download
# Dataset repository id handed to `snapshot_download(..., repo_type="dataset")`
# in `PileDomainDataset._split_generators`.
# Requirement: REPO_NAME and this script's file name must be the same
# (e.g. uspto.py).
REPO_NAME = "Multi-Domain-Expert-Layers/github"
class PileDomainDataset(GeneratorBasedBuilder):
    """Builder that reads zstd-compressed JSON-lines shards from ``REPO_NAME``.

    The repository snapshot is fetched with ``snapshot_download`` and the
    shards are looked up under ``data/train``, ``data/val`` and ``data/test``.
    Six splits are exposed: ``train``, ``validation``, ``validation_pile``,
    ``validation_domain``, ``test_pile`` and ``test_domain``.
    """

    VERSION = Version("1.0.0")

    def _info(self):
        """Return the dataset metadata: a single string feature named ``text``."""
        return datasets.DatasetInfo(
            description="Pile Domain Dataset",
            features=datasets.Features(
                {
                    "text": datasets.Value("string"),
                }
            ),
            supervised_keys=None,
        )

    def _split_generators(self, dl_manager):
        """Download the repository snapshot and describe every split.

        Args:
            dl_manager: Download manager supplied by ``datasets``. It is not
                used; the snapshot is fetched via ``snapshot_download``.

        Returns:
            A list of ``datasets.SplitGenerator``. Each one passes a
            ``data_dir`` and a ``split`` filter (``None``, ``"pile"`` or
            ``"domain"``) on to ``_generate_examples``.
        """
        dl_path = snapshot_download(repo_id=REPO_NAME, repo_type="dataset")
        train_dir = os.path.join(dl_path, "data/train")
        val_dir = os.path.join(dl_path, "data/val")
        test_dir = os.path.join(dl_path, "data/test")
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                gen_kwargs={"data_dir": train_dir, "split": None},
            ),
            datasets.SplitGenerator(
                name="validation",
                gen_kwargs={"data_dir": val_dir, "split": None},
            ),
            datasets.SplitGenerator(
                name="validation_pile",
                gen_kwargs={"data_dir": val_dir, "split": "pile"},
            ),
            datasets.SplitGenerator(
                name="validation_domain",
                gen_kwargs={"data_dir": val_dir, "split": "domain"},
            ),
            datasets.SplitGenerator(
                name="test_pile",
                gen_kwargs={"data_dir": test_dir, "split": "pile"},
            ),
            datasets.SplitGenerator(
                name="test_domain",
                gen_kwargs={"data_dir": test_dir, "split": "domain"},
            ),
        ]

    def _generate_examples(self, data_dir, split):
        """Yield ``(key, record)`` pairs from the ``*.jsonl.zst`` files in a directory.

        Args:
            data_dir: Directory that is searched (non-recursively) for
                ``*.jsonl.zst`` files.
            split: ``None`` to read every file, otherwise a substring that a
                file's name must contain for the file to be read.

        Yields:
            Tuples of a running integer key (starting at 0 and continuing
            across files) and the object decoded from one JSON line.
        """
        dctx = zstd.ZstdDecompressor()
        # Sorted so that the example order, and therefore the keys, do not
        # depend on the order in which the filesystem lists the files.
        file_paths = sorted(glob(os.path.join(data_dir, "*.jsonl.zst")))
        if split is not None:
            # Match on the file name only: testing the full path would let a
            # parent directory containing e.g. "pile" or "domain" select
            # every file regardless of its own name.
            file_paths = [p for p in file_paths if split in os.path.basename(p)]
        idx = 0
        for path in file_paths:
            # Closing the BufferedReader also closes the zstd stream reader.
            with open(path, "rb") as f, io.BufferedReader(dctx.stream_reader(f)) as buffer:
                # Iterate lazily instead of readlines(), which would hold the
                # whole decompressed file in memory at once.
                for line in buffer:
                    if not line.strip():
                        # Blank lines carry no record and are not valid JSON.
                        continue
                    # NOTE(review): the decoded object is yielded unchanged;
                    # presumably every record matches the declared {"text"}
                    # feature schema - confirm against the data files.
                    yield idx, json.loads(line)
                    idx += 1