|
|
|
"""
|
|
split_attackplan_jsonl.py
|
|
Shuffle and split AttackPlan JSONL into datasets/train|val|test.jsonl
|
|
|
|
Usage:
|
|
%run scripts/split_attackplan_jsonl.py --src "C:/Users/adetu/Dropbox/Ire_Research/my_code/scripts/train_attackplan.jsonl"
|
|
"""
|
|
from __future__ import annotations
|
|
import argparse, json, random, hashlib
|
|
from pathlib import Path
|
|
|
|
def plan_sig(plan: dict) -> str:
    """Return a stable fingerprint of a record's ``"plan"`` entry.

    The value stored under ``"plan"`` (an empty list when the key is absent)
    is serialized to JSON with sorted keys, so two records whose plans differ
    only in dictionary key order get the same fingerprint.

    Args:
        plan: A decoded JSONL record.

    Returns:
        The hexadecimal MD5 digest of the UTF-8 encoded canonical JSON.
    """
    # MD5 is used purely as a cheap duplicate-detection key, not for security.
    canonical = json.dumps(plan.get("plan", []), sort_keys=True)
    return hashlib.md5(canonical.encode("utf-8")).hexdigest()
|
|
|
|
def main(argv=None):
    """Shuffle a JSONL file and write train/val/test splits.

    Non-blank lines of ``--src`` are shuffled with a seeded RNG and cut into
    three consecutive slices: the first ``int(train * n)`` lines go to
    ``train.jsonl``, the next ``int(val * n)`` to ``val.jsonl`` and the
    remainder to ``test.jsonl``. After writing, the plan fingerprints of the
    three splits are compared and the number of fingerprints that occur in
    more than one split is reported.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, as before.

    Raises:
        SystemExit: Raised by argparse when arguments are missing or when
            ``--train``/``--val`` are negative or sum to more than 1.
    """
    ap = argparse.ArgumentParser(
        description="Shuffle and split a JSONL file into train/val/test."
    )
    ap.add_argument("--src", type=str, required=True,
                    help="path of the JSONL file to split")
    ap.add_argument("--seed", type=int, default=7,
                    help="seed for the shuffle")
    ap.add_argument("--train", type=float, default=0.70,
                    help="fraction of lines written to train.jsonl")
    ap.add_argument("--val", type=float, default=0.15,
                    help="fraction of lines written to val.jsonl")
    ap.add_argument("--outdir", type=str, default="datasets",
                    help="directory that receives the three split files")
    args = ap.parse_args(argv)

    # Without this check an oversized --val silently truncates the val split,
    # empties the test split and yields a negative test count in the summary.
    # The small tolerance absorbs float noise in sums that should equal 1.
    if args.train < 0 or args.val < 0 or args.train + args.val > 1.0 + 1e-9:
        ap.error("--train and --val must be non-negative and sum to at most 1")

    src = Path(args.src)
    # utf-8-sig transparently drops a leading BOM if the file has one.
    raw = src.read_text(encoding="utf-8-sig")
    lines = [ln for ln in raw.splitlines() if ln.strip()]
    random.Random(args.seed).shuffle(lines)

    n = len(lines)
    ntr = int(args.train * n)
    nv = int(args.val * n)
    test_start = ntr + nv

    splits = {
        "train": lines[:ntr],
        "val": lines[ntr:test_start],
        "test": lines[test_start:],
    }

    outdir = Path(args.outdir)
    outdir.mkdir(parents=True, exist_ok=True)
    for name, chunk in splits.items():
        # An empty split becomes an empty file rather than a lone blank line,
        # which strict JSONL readers would fail to parse.
        text = ("\n".join(chunk) + "\n") if chunk else ""
        (outdir / f"{name}.jsonl").write_text(text, encoding="utf-8")

    # Collect the plan fingerprints present in each split.
    sigs = {name: set() for name in splits}
    unparsable = 0
    for name, chunk in splits.items():
        for ln in chunk:
            try:
                obj = json.loads(ln)
            except json.JSONDecodeError:
                # Malformed lines are still written to the split files; they
                # are only left out of the duplicate check.
                unparsable += 1
                continue
            if isinstance(obj, dict) and "plan" in obj:
                sigs[name].add(plan_sig(obj))

    inter = (
        (sigs["train"] & sigs["val"])
        | (sigs["train"] & sigs["test"])
        | (sigs["val"] & sigs["test"])
    )
    if unparsable:
        print(f"[warn] {unparsable} line(s) are not valid JSON and were skipped in the duplicate check")
    print(f"[done] train/val/test = {ntr}/{nv}/{n - ntr - nv}. duplicate_plans_across_splits={len(inter)}")
|
|
|
|
# Run the splitter only when executed as a script (including IPython's %run),
# not when this module is imported.
if __name__ == "__main__":
    main()
|
|
|