#!/usr/bin/env python3
"""Map exp2 cached JSONL token spans across tokenizers (Qwen -> Llama).
Background
----------
`exp/exp2/run_exp.py` expects cached datasets to provide token-level generation spans:
- indices_to_explain: [start_tok, end_tok] (generation-token indices; closed interval)
- sink_span / thinking_span: same tokenizer convention as indices_to_explain
These spans are computed under a specific tokenizer (often Qwen3-8B). When switching
to a different model/tokenizer (e.g., Llama-3.1-8B-Instruct), the stored spans can
become out-of-range and crash exp2 attribution (IndexError in token-span checks).
This script remaps spans by:
1) Tokenizing `target` with the OLD tokenizer to obtain offset_mapping
2) Converting the OLD token span into a character span in `target`
3) Tokenizing `target` with the NEW tokenizer and mapping the character span back
into NEW token indices
Outputs are written under `exp/exp5/data/` by default, keeping `exp/exp2/` untouched.
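
Example usage (illustrative; the input path is a placeholder, tokenizer paths are the script defaults):

    python exp/exp5/map_exp2_cache_token_spans.py \
        --in_jsonl exp/exp2/data/<cache>.jsonl \
        --old_tokenizer_model /opt/share/models/Qwen/Qwen3-8B \
        --new_tokenizer_model /opt/share/models/meta-llama/Llama-3.1-8B-Instruct \
        --allow_fallback_answer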
"""
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple
from transformers import AutoTokenizer
REPO_ROOT = Path(__file__).resolve().parents[2]
if str(REPO_ROOT) not in sys.path:
sys.path.insert(0, str(REPO_ROOT))
def _split_args(values: Iterable[str]) -> List[str]:
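    """Flatten CLI values, additionally splitting comma-separated entries."""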
out: List[str] = []
for v in values:
for part in str(v).split(","):
part = part.strip()
if part:
out.append(part)
return out
def _load_tokenizer(tokenizer_model: str):
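    """Load a tokenizer from a local path when it exists, otherwise from the model name."""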
path = Path(tokenizer_model)
if path.exists():
return AutoTokenizer.from_pretrained(path.as_posix(), local_files_only=True)
# May require network access; keep as fallback for environments that allow it.
return AutoTokenizer.from_pretrained(tokenizer_model)
def _is_token_span(span: Any) -> bool:
return (
isinstance(span, list)
and len(span) == 2
and all(isinstance(x, int) for x in span)
and span[0] >= 0
and span[1] >= span[0]
)
def _pick_old_span(obj: Dict[str, Any]) -> Optional[List[int]]:
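    """Pick the span to remap: prefer indices_to_explain, then fall back to sink_span."""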
span = obj.get("indices_to_explain")
if _is_token_span(span):
return list(span)
span = obj.get("sink_span")
if _is_token_span(span):
return list(span)
return None
def _offsets_to_char_span(offsets: Any, token_span: List[int]) -> Optional[Tuple[int, int]]:
"""Convert a token span [start,end] to a character span [char_start,char_end) using offsets."""
if offsets is None:
return None
if not isinstance(offsets, list):
return None
start_tok, end_tok = token_span
if end_tok >= len(offsets):
return None
char_starts: List[int] = []
char_ends: List[int] = []
for idx in range(start_tok, end_tok + 1):
off = offsets[idx]
if off is None:
continue
if not (isinstance(off, (list, tuple)) and len(off) == 2):
continue
try:
s, e = int(off[0]), int(off[1])
except Exception:
continue
if e <= s:
continue
char_starts.append(s)
char_ends.append(e)
if not char_starts or not char_ends:
return None
return min(char_starts), max(char_ends)
def _char_span_to_token_span(offsets: Any, char_span: Tuple[int, int]) -> Optional[List[int]]:
"""Convert a character span [char_start,char_end) to a token span [start,end] by overlap."""
if offsets is None:
return None
if not isinstance(offsets, list):
return None
char_start, char_end = int(char_span[0]), int(char_span[1])
if char_end <= char_start:
return None
hit: List[int] = []
for tok_idx, off in enumerate(offsets):
if off is None:
continue
if not (isinstance(off, (list, tuple)) and len(off) == 2):
continue
try:
s, e = int(off[0]), int(off[1])
except Exception:
continue
if e <= s:
continue
if s < char_end and e > char_start:
hit.append(int(tok_idx))
if not hit:
return None
return [min(hit), max(hit)]
def _validate_span_with_eos(tokenizer, target: str, token_span: List[int]) -> bool:
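    """Check that token_span lies within the generation length of target + EOS."""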
eos = tokenizer.eos_token or ""
gen_ids = tokenizer(target + eos, add_special_tokens=False).input_ids
gen_len = int(len(gen_ids))
return 0 <= token_span[0] <= token_span[1] < gen_len
def _guess_answer_text(obj: Dict[str, Any]) -> Optional[str]:
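    """Best-effort answer text: metadata.boxed_answer, then reference_answer,
    then the last line of `target`."""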
meta = obj.get("metadata") or {}
if isinstance(meta, dict):
boxed = (meta.get("boxed_answer") or "").strip()
if boxed:
return boxed
ref = (meta.get("reference_answer") or "").strip()
if ref:
return ref
tgt = obj.get("target")
if isinstance(tgt, str) and tgt.strip():
# Common exp2 cache convention: last line is the final answer.
last_line = tgt.strip().splitlines()[-1].strip()
return last_line or None
return None
def _fallback_map_via_answer_text(
obj: Dict[str, Any],
*,
new_tokenizer,
) -> Optional[List[int]]:
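    """Recompute a sink span under the NEW tokenizer by locating the answer text
    in `target` via exp2's attach_spans_from_answer."""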
tgt = obj.get("target")
if not isinstance(tgt, str) or not tgt:
return None
from exp.exp2.dataset_utils import CachedExample, attach_spans_from_answer # lazy import
answer_text = _guess_answer_text(obj)
ex = CachedExample(
prompt=str(obj.get("prompt") or ""),
target=tgt,
indices_to_explain=None,
attr_mask_indices=obj.get("attr_mask_indices"),
sink_span=None,
thinking_span=None,
metadata=obj.get("metadata") or {},
)
out = attach_spans_from_answer(ex, new_tokenizer, answer_text)
if out.sink_span is None:
return None
if not _is_token_span(out.sink_span):
return None
return list(out.sink_span)
def _map_one_obj(
obj: Dict[str, Any],
*,
old_tokenizer,
new_tokenizer,
allow_fallback_answer: bool,
) -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
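    """Map one cached example onto the new tokenizer's token indices.

    Returns (mapped_obj, None) on success or (None, error_reason) on failure.
    """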
target = obj.get("target")
if not isinstance(target, str) or not target:
return None, "missing_target"
old_span = _pick_old_span(obj)
if old_span is None:
return None, "missing_old_span"
# 1) Old token span -> char span in target.
old_enc = old_tokenizer(target, add_special_tokens=False, return_offsets_mapping=True)
old_offsets = old_enc.get("offset_mapping")
char_span = _offsets_to_char_span(old_offsets, old_span)
if char_span is None:
if not allow_fallback_answer:
return None, "old_span_to_char_failed"
new_span = _fallback_map_via_answer_text(obj, new_tokenizer=new_tokenizer)
if new_span is None:
return None, "fallback_answer_failed"
if not _validate_span_with_eos(new_tokenizer, target, new_span):
return None, "fallback_answer_span_invalid"
mapped = dict(obj)
mapped["indices_to_explain"] = new_span
mapped["sink_span"] = new_span
mapped["thinking_span"] = [0, new_span[0] - 1] if new_span[0] > 0 else None
meta = mapped.get("metadata")
if not isinstance(meta, dict):
meta = {}
meta = dict(meta)
meta["exp5_span_map_method"] = "answer_text"
mapped["metadata"] = meta
return mapped, None
    # 2) Char span -> new token span.
    new_enc = new_tokenizer(target, add_special_tokens=False, return_offsets_mapping=True)
    new_offsets = new_enc.get("offset_mapping")
    new_span = _char_span_to_token_span(new_offsets, char_span)
    map_method = "token_span_char_align"
    if new_span is None:
        if not allow_fallback_answer:
            return None, "char_to_new_span_failed"
        new_span = _fallback_map_via_answer_text(obj, new_tokenizer=new_tokenizer)
        if new_span is None:
            return None, "fallback_answer_failed"
        map_method = "answer_text"
    if not _validate_span_with_eos(new_tokenizer, target, new_span):
        if not allow_fallback_answer:
            return None, "new_span_invalid"
        fb = _fallback_map_via_answer_text(obj, new_tokenizer=new_tokenizer)
        if fb is None or not _validate_span_with_eos(new_tokenizer, target, fb):
            return None, "fallback_answer_span_invalid"
        new_span = fb
        map_method = "answer_text"
    mapped = dict(obj)
    mapped["indices_to_explain"] = new_span
    mapped["sink_span"] = new_span
    mapped["thinking_span"] = [0, new_span[0] - 1] if new_span[0] > 0 else None
    meta = mapped.get("metadata")
    if not isinstance(meta, dict):
        meta = {}
    meta = dict(meta)
    # Record which method actually produced the new span (char alignment vs. answer-text fallback).
    meta["exp5_span_map_method"] = map_method
    mapped["metadata"] = meta
    return mapped, None
def _read_jsonl(path: Path) -> Iterable[Dict[str, Any]]:
with path.open("r", encoding="utf-8") as f:
for line_no, line in enumerate(f, start=1):
if not line.strip():
continue
try:
obj = json.loads(line)
except json.JSONDecodeError as exc: # pragma: no cover
raise RuntimeError(f"Invalid JSON at {path}:{line_no}: {exc}") from exc
if not isinstance(obj, dict):
raise RuntimeError(f"Expected JSON object per line at {path}:{line_no}.")
yield obj
def _write_jsonl(path: Path, rows: Iterable[Dict[str, Any]]) -> int:
path.parent.mkdir(parents=True, exist_ok=True)
count = 0
with path.open("w", encoding="utf-8") as f:
for obj in rows:
f.write(json.dumps(obj, ensure_ascii=False) + "\n")
count += 1
return count
def _default_old_tokenizer() -> str:
# Repo defaults used in exp2 README examples for span extraction.
return "/opt/share/models/Qwen/Qwen3-8B"
def _default_new_tokenizer() -> str:
return "/opt/share/models/meta-llama/Llama-3.1-8B-Instruct"
def main() -> None:
    ap = argparse.ArgumentParser(description="Map exp2 cache token spans from an old tokenizer to a new tokenizer.")
ap.add_argument(
"--in_jsonl",
type=str,
nargs="+",
required=True,
help="One or more exp2 cached JSONL files (comma-separated also accepted).",
)
ap.add_argument(
"--out_dir",
type=str,
default="exp/exp5/data",
help="Output directory for mapped JSONL files.",
)
ap.add_argument(
"--old_tokenizer_model",
type=str,
default=_default_old_tokenizer(),
help="Tokenizer used to produce the original token spans (default: Qwen3-8B local path).",
)
ap.add_argument(
"--new_tokenizer_model",
type=str,
default=_default_new_tokenizer(),
help="Tokenizer to map spans into (default: Llama-3.1-8B-Instruct local path).",
)
ap.add_argument("--strict", action="store_true", help="Fail on the first example that cannot be mapped.")
ap.add_argument(
"--allow_fallback_answer",
action="store_true",
help=(
"If span alignment fails, try to recompute spans by locating metadata.boxed_answer in target "
"(useful when caches were not built with the assumed old tokenizer)."
),
)
ap.add_argument(
"--overwrite",
action="store_true",
help="Overwrite output files if they already exist.",
)
args = ap.parse_args()
in_paths = [Path(p) for p in _split_args(args.in_jsonl)]
out_dir = Path(args.out_dir)
old_tok = _load_tokenizer(str(args.old_tokenizer_model))
new_tok = _load_tokenizer(str(args.new_tokenizer_model))
# exp2 convention: ensure a pad token exists for downstream perturbation.
if new_tok.pad_token is None and new_tok.eos_token is not None:
new_tok.pad_token = new_tok.eos_token
summary: Dict[str, Any] = {
"old_tokenizer_model": str(args.old_tokenizer_model),
"new_tokenizer_model": str(args.new_tokenizer_model),
"datasets": [],
}
for in_path in in_paths:
if not in_path.exists():
raise SystemExit(f"Missing input JSONL: {in_path}")
out_path = out_dir / in_path.name
if out_path.exists() and not bool(args.overwrite):
raise SystemExit(f"Refusing to overwrite existing output: {out_path} (use --overwrite)")
total = 0
mapped_ok = 0
dropped = 0
errors: Dict[str, int] = {}
mapped_rows: List[Dict[str, Any]] = []
for obj in _read_jsonl(in_path):
total += 1
mapped, err = _map_one_obj(
obj,
old_tokenizer=old_tok,
new_tokenizer=new_tok,
allow_fallback_answer=bool(args.allow_fallback_answer),
)
if err is not None or mapped is None:
errors[err or "unknown_error"] = errors.get(err or "unknown_error", 0) + 1
if bool(args.strict):
raise SystemExit(f"Failed to map {in_path} example #{total}: {err}")
dropped += 1
continue
mapped_ok += 1
mapped_rows.append(mapped)
written = _write_jsonl(out_path, mapped_rows)
if written != mapped_ok: # pragma: no cover
raise SystemExit(f"Internal error: written={written} != mapped_ok={mapped_ok}")
record = {
"in_jsonl": str(in_path),
"out_jsonl": str(out_path),
"total": int(total),
"mapped_ok": int(mapped_ok),
"dropped": int(dropped),
"errors": errors,
}
summary["datasets"].append(record)
print(json.dumps(record, ensure_ascii=False))
# Human-readable compact summary at end.
print(json.dumps(summary, ensure_ascii=False, indent=2))
if __name__ == "__main__":
main()