ds6b-attackplan-qlora / scripts /split_attackplan_jsonl.py
adetuire1's picture
Upload folder using huggingface_hub
fba140f verified
# -*- coding: utf-8 -*-
"""
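split_attackplan_jsonl.py
Shuffle AttackPlan JSONL records and split them into datasets/train.jsonl,
datasets/val.jsonl and datasets/test.jsonl (the datasets/ directory is created
relative to the current working directory). --train (default 0.70) and --val
(default 0.15) are the train/val fractions; all remaining records go to test.
--seed (default 7) makes the shuffle reproducible. After writing, the script
reports how many distinct "plan" signatures appear in more than one split.
Usage (IPython):
%run scripts/split_attackplan_jsonl.py --src "C:/Users/adetu/Dropbox/Ire_Research/my_code/scripts/train_attackplan.jsonl"
Usage (shell):
python scripts/split_attackplan_jsonl.py --src path/to/train_attackplan.jsonl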
"""
from __future__ import annotations
import argparse, json, random, hashlib
from pathlib import Path
def plan_sig(plan: dict) -> str:
    """Return a hex fingerprint of a record's ``"plan"`` field.

    The ``"plan"`` value (an empty list when the key is absent) is serialised
    to JSON with sorted keys, so dicts that differ only in key order map to the
    same signature. The MD5 digest is used purely as a cheap duplicate
    detector across splits, not for any security purpose.
    """
    steps = plan.get("plan", [])
    canonical = json.dumps(steps, sort_keys=True)
    digest = hashlib.md5(canonical.encode("utf-8"))
    return digest.hexdigest()
def main():
ap = argparse.ArgumentParser()
ap.add_argument("--src", type=str, required=True)
ap.add_argument("--seed", type=int, default=7)
ap.add_argument("--train", type=float, default=0.70)
ap.add_argument("--val", type=float, default=0.15)
args = ap.parse_args()
src = Path(args.src)
lines = [ln for ln in src.read_text(encoding="utf-8-sig").splitlines() if ln.strip()]
random.Random(args.seed).shuffle(lines)
n = len(lines)
ntr = int(args.train * n)
nv = int(args.val * n)
test_start = ntr + nv
outdir = Path("datasets"); outdir.mkdir(exist_ok=True)
Path(outdir, "train.jsonl").write_text("\n".join(lines[:ntr]) + "\n", encoding="utf-8")
Path(outdir, "val.jsonl").write_text("\n".join(lines[ntr:test_start]) + "\n", encoding="utf-8")
Path(outdir, "test.jsonl").write_text("\n".join(lines[test_start:]) + "\n", encoding="utf-8")
# duplicate signature report
import json as _json
buckets = {"train":[], "val":[], "test":[]}
for name, chunk in [("train", lines[:ntr]), ("val", lines[ntr:test_start]), ("test", lines[test_start:])]:
for ln in chunk:
try:
obj = _json.loads(ln)
if isinstance(obj, dict) and "plan" in obj:
buckets[name].append(plan_sig(obj))
except Exception:
pass
inter = set(buckets["train"]) & set(buckets["val"]) | set(buckets["train"]) & set(buckets["test"]) | set(buckets["val"]) & set(buckets["test"])
print(f"[done] train/val/test = {ntr}/{nv}/{n - ntr - nv}. duplicate_plans_across_splits={len(inter)}")
if __name__ == "__main__":
    # Run the CLI only when the file is executed directly (plain `python` or
    # IPython `%run`); importing this module defines the functions and nothing else.
    main()