vulnops / scripts /prepare_mlx_data.py
Adhitya-Vardhan
Initial commit: VulnOps OpenEnv benchmark
d63a1ba
"""Prepare MLX-LM-compatible train/valid files from existing SFT data."""
from __future__ import annotations
import argparse
import json
from pathlib import Path
from typing import Dict, List
from transformers import AutoTokenizer
# Directory one level above the folder that contains this script; the
# --source-root / --output-root CLI values are joined onto it in main().
ROOT = Path(__file__).resolve().parents[1]
# Text spliced between the kept head and the kept tail of a prompt that had
# to be shortened, so the cut is visible in the resulting training example.
TRUNCATION_MARKER = "\n...[truncated observation]...\n"
def load_jsonl(path: Path) -> List[Dict[str, object]]:
    """Read a JSON Lines file and return its decoded records in file order.

    Lines that are empty or contain only whitespace are ignored. The file is
    read as UTF-8. Any line that is not valid JSON makes ``json.loads`` raise.
    """
    records: List[Dict[str, object]] = []
    with path.open("r", encoding="utf-8") as stream:
        for raw_line in stream:
            # Blank separators carry no record.
            if not raw_line.strip():
                continue
            records.append(json.loads(raw_line))
    return records
def dump_jsonl(path: Path, rows: List[Dict[str, object]]) -> None:
    """Write *rows* to *path* as JSON Lines, one key-sorted object per line.

    Missing parent directories are created first. An existing file at *path*
    is overwritten; an empty *rows* list produces an empty file.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as sink:
        sink.writelines(
            f"{json.dumps(entry, sort_keys=True)}\n" for entry in rows
        )
def trim_prompt_to_budget(prompt: str, tokenizer, budget: int) -> str:
    """Shorten *prompt* so that it spans at most *budget* tokens.

    The prompt is tokenized without special tokens. If it already fits it is
    returned unchanged. Otherwise the middle is cut out: roughly the first
    55% of the usable budget is kept from the start, the rest from the end,
    and ``TRUNCATION_MARKER`` is placed between the two pieces. When the
    budget is too small to hold the marker plus a meaningful head and tail,
    only the last *budget* tokens are kept.

    Args:
        prompt: Text to shorten.
        tokenizer: Object providing ``encode(text, add_special_tokens=False)``
            and ``decode(ids, ...)``.
        budget: Maximum number of prompt tokens to keep.

    Returns:
        The original prompt, a decoded shortened prompt, or ``""`` when
        *budget* is not positive.

    Note:
        The result is produced by decoding token ids, so re-tokenizing it is
        not guaranteed to give exactly the same number of tokens.
    """
    # A non-positive budget leaves no room for any token. Without this guard
    # ``prompt_ids[-0:]`` would return the *whole* prompt, and a negative
    # budget would drop tokens from the front instead of keeping the tail.
    if budget <= 0:
        return ""

    prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
    if len(prompt_ids) <= budget:
        return prompt

    marker_ids = tokenizer.encode(TRUNCATION_MARKER, add_special_tokens=False)
    marker_len = len(marker_ids)
    # Require more than 8 tokens beyond the marker before splitting into
    # head + marker + tail; below that, just keep the most recent tokens.
    if budget <= marker_len + 8:
        return tokenizer.decode(prompt_ids[-budget:])

    remaining = budget - marker_len
    head_len = max(1, int(remaining * 0.55))
    tail_len = max(1, remaining - head_len)
    trimmed_ids = prompt_ids[:head_len] + marker_ids + prompt_ids[-tail_len:]
    # head + marker + tail already adds up to the budget here; the slice is a
    # safety net so the result can never exceed it.
    return tokenizer.decode(trimmed_ids[:budget], skip_special_tokens=False)
def rendered_length(prompt: str, completion: str, tokenizer) -> int:
    """Return the length of the chat-templated user/assistant exchange.

    *prompt* becomes the ``user`` turn and *completion* the ``assistant``
    turn. The value is ``len()`` of whatever
    ``tokenizer.apply_chat_template(..., return_dict=False)`` returns,
    presumably a list of token ids (verify against the tokenizer in use).
    """
    turns = (("user", prompt), ("assistant", completion))
    conversation = [
        {"role": speaker, "content": body} for speaker, body in turns
    ]
    rendered = tokenizer.apply_chat_template(conversation, return_dict=False)
    return len(rendered)
def normalize_record(
    record: Dict[str, object],
    tokenizer,
    max_seq_length: int,
) -> tuple[Dict[str, object] | None, Dict[str, int]]:
    """Fit one SFT record into *max_seq_length* tokens by trimming its prompt.

    The completion is never shortened. The prompt budget starts at
    ``max_seq_length`` minus the completion's token count minus 32 tokens of
    headroom (presumably for chat-template overhead; confirm). While the
    chat-templated exchange is still too long, the budget is reduced by 10%
    per round, but never below 64 tokens.

    Returns:
        ``(normalized, stats)``. ``normalized`` is a copy of *record* with
        ``"prompt"`` replaced by the fitted prompt and ``"text"`` set to
        ``"<prompt>\\n<completion>"``, or ``None`` when the record cannot be
        made to fit. ``stats`` holds 0/1 flags under ``"trimmed"`` and
        ``"dropped"``.

    Raises:
        KeyError: If *record* lacks ``"prompt"`` or ``"completion"``.
    """
    original_prompt = str(record["prompt"])
    answer = str(record["completion"])
    outcome = {"trimmed": 0, "dropped": 0}

    def too_long(candidate: str) -> bool:
        # True when the templated prompt + completion overflows the limit.
        return rendered_length(candidate, answer, tokenizer) > max_seq_length

    answer_token_count = len(tokenizer.encode(answer, add_special_tokens=False))
    budget = max_seq_length - answer_token_count - 32
    if budget <= 0:
        # The completion alone (plus headroom) already fills the sequence.
        outcome["dropped"] = 1
        return None, outcome

    fitted_prompt = trim_prompt_to_budget(original_prompt, tokenizer, budget)
    while too_long(fitted_prompt) and budget > 64:
        budget = max(64, int(budget * 0.9))
        fitted_prompt = trim_prompt_to_budget(original_prompt, tokenizer, budget)

    if too_long(fitted_prompt):
        outcome["dropped"] = 1
        return None, outcome

    outcome["trimmed"] = int(fitted_prompt != original_prompt)
    fitted_record = {
        **record,
        "prompt": fitted_prompt,
        "text": f"{fitted_prompt}\n{answer}",
    }
    return fitted_record, outcome
def transform_split(src: Path, dst: Path, tokenizer, max_seq_length: int) -> Dict[str, int]:
    """Normalize every record of the *src* JSONL file and write them to *dst*.

    Each record goes through ``normalize_record``; records it rejects are
    left out of the output. *dst* is always written, even when no record
    survives.

    Returns:
        Counters keyed ``"input_examples"``, ``"written_examples"``,
        ``"trimmed_examples"`` and ``"dropped_examples"``.
    """
    source_rows = load_jsonl(src)
    kept_rows: List[Dict[str, object]] = []
    trimmed_total = 0
    dropped_total = 0

    for source_row in source_rows:
        converted, flags = normalize_record(source_row, tokenizer, max_seq_length)
        trimmed_total += flags["trimmed"]
        dropped_total += flags["dropped"]
        if converted is None:
            continue
        kept_rows.append(converted)

    dump_jsonl(dst, kept_rows)
    return {
        "input_examples": len(source_rows),
        "written_examples": len(kept_rows),
        "trimmed_examples": trimmed_total,
        "dropped_examples": dropped_total,
    }
def main() -> None:
    """Command-line entry point: build MLX-LM data files from SFT JSONL data.

    Options:
        --source-root: Directory holding ``train.jsonl`` (and ``eval.jsonl``),
            joined onto ``ROOT``.
        --output-root: Directory receiving ``train.jsonl`` (and
            ``valid.jsonl``), joined onto ``ROOT``.
        --model: Tokenizer name or path passed to ``AutoTokenizer``.
        --max-seq-length: Token limit for each rendered example.
        --include-valid: Also convert ``eval.jsonl`` into ``valid.jsonl``.
            Without it, an existing ``valid.jsonl`` in the output directory
            is deleted.
        --force: Regenerate outputs that already exist; otherwise existing
            outputs are left alone.

    A summary is written to ``prepare_stats.json`` next to the output
    directory and printed to stdout.

    Raises:
        FileNotFoundError: If a required source file is missing. This is
            checked before anything is loaded or written.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--source-root", default="artifacts/lora_qwen3_4b/data")
    parser.add_argument("--output-root", default="artifacts/mlx_qwen3_4b/data")
    parser.add_argument("--model", default="Qwen/Qwen3.5-4B")
    parser.add_argument("--max-seq-length", type=int, default=1024)
    parser.add_argument("--include-valid", action="store_true")
    parser.add_argument("--force", action="store_true")
    args = parser.parse_args()

    source_root = (ROOT / args.source_root).resolve()
    output_root = (ROOT / args.output_root).resolve()

    mapping = {source_root / "train.jsonl": output_root / "train.jsonl"}
    if args.include_valid:
        mapping[source_root / "eval.jsonl"] = output_root / "valid.jsonl"

    # Validate every input up front. Checking lazily inside the processing
    # loop meant a missing eval file was only noticed after the tokenizer had
    # been loaded and train.jsonl had already been written.
    for src in mapping:
        if not src.exists():
            raise FileNotFoundError(f"Missing source file: {src}")

    output_root.mkdir(parents=True, exist_ok=True)
    # NOTE(review): trust_remote_code=True allows code shipped with the model
    # repository to run locally; only point --model at a trusted source.
    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)

    splits: Dict[str, Dict[str, int]] = {}
    for src, dst in mapping.items():
        # Existing outputs are kept unless --force is given.
        # NOTE(review): a skipped split gets no entry in the summary, so the
        # stats file written below no longer contains its earlier numbers.
        if dst.exists() and not args.force:
            continue
        splits[dst.stem] = transform_split(src, dst, tokenizer, args.max_seq_length)

    # Remove a stale validation file when validation data was not requested.
    valid_path = output_root / "valid.jsonl"
    if not args.include_valid and valid_path.exists():
        valid_path.unlink()

    summary: Dict[str, object] = {
        "model": args.model,
        "max_seq_length": args.max_seq_length,
        "splits": splits,
    }
    rendered_summary = json.dumps(summary, indent=2, sort_keys=True)
    summary_path = output_root.parent / "prepare_stats.json"
    summary_path.write_text(rendered_summary + "\n", encoding="utf-8")
    print(output_root)
    print(rendered_summary)
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()