| |
|
|
| """ |
| CRUXEval-O dataset: deterministic train/val split + ground-truth execution |
| traces -> (input_ids, labels) for teacher-forcing CODI. |
| |
| Neutral data layer shared by training (``cwm.training.data``) and eval |
| (``evals.cruxeval.run_eval_codi``); depends on nothing in either, so the |
| split and trace format never drift. Thin HuggingFace-tokenizer wrapper over |
| the verbatim Table 9 trace generator (``.ground_truth`` / ``.trace_format``): |
| build the seeded prompt, tokenize ``prompt + render_frames_to_generation(frames)``, |
| and mask the prompt out of the labels (teacher-forced, so labels == input_ids |
| with the prompt prefix set to ``-100``). |
| """ |
|
|
| from __future__ import annotations |
|
|
| from .ground_truth import ground_truth_trace, make_trace_context |
| from .trace_format import ( |
| ACTION_SEP, |
| LINE_SEP, |
| TraceEvent, |
| render_frames_to_generation, |
| ) |
|
|
# Label value written over the prompt positions so that only trace tokens
# carry real labels (see ``build_example``).
IGNORE_INDEX = -100


def _prompt_str(code: str, input_str: str) -> str:
    """Return the prompt text that precedes the generated trace.

    The prompt is the trace-context opener followed by
    ``make_trace_context(code, input_str)`` and a fixed first frame: a call
    separator, a literal ``{}``, an action separator and the ``def main():``
    line, closed by a frame separator.
    """
    context = make_trace_context(code, input_str)
    opener = f"<|trace_context_start|>{context}"
    first_frame = "<|frame_sep|><|call_sep|>{}<|action_sep|>def main():\n<|frame_sep|>"
    return opener + first_frame
|
|
|
|
def _tokenize_trace(code, input_str, tokenizer, *, max_seq_len, max_frames):
    """Tokenize one ground-truth trace and locate its LINE spans.

    Returns ``(prompt_ids, trace_ids, spans)``, or ``None`` when the example
    must be skipped. An example is skipped when:

    * tracing produced no frames or reported ``"frames_exceeded"``;
    * the last frame is neither a RETURN nor an EXCEPTION event;
    * prompt plus trace is longer than ``max_seq_len`` tokens;
    * no complete LINE span was found in the trace.

    Each span is a pair ``(i, j)`` where ``trace_ids[i]`` is the
    ``LINE_SEP`` token and ``trace_ids[j]`` is the first ``ACTION_SEP``
    token after it, so ``trace_ids[i + 1:j]`` is the content between the two
    separators (the part a CODI student replaces with a latent block).

    Both the SFT and the CODI builders go through this function, so they
    accept and reject exactly the same examples.
    """
    frames, error = ground_truth_trace(
        code, input_str, align_to_prompt=True, max_frames=max_frames
    )
    if not frames or error == "frames_exceeded":
        return None
    terminal_events = (TraceEvent.RETURN, TraceEvent.EXCEPTION)
    if frames[-1].event not in terminal_events:
        return None

    # Special tokens are never added by the tokenizer itself; BOS is
    # prepended by hand when the tokenizer defines one.
    prefix = [] if tokenizer.bos_token_id is None else [tokenizer.bos_token_id]
    prompt_ids = prefix + tokenizer.encode(
        _prompt_str(code, input_str), add_special_tokens=False
    )
    trace_ids = tokenizer.encode(
        render_frames_to_generation(frames), add_special_tokens=False
    )
    if len(prompt_ids) + len(trace_ids) > max_seq_len:
        return None

    line_sep_id = tokenizer.convert_tokens_to_ids(LINE_SEP)
    action_sep_id = tokenizer.convert_tokens_to_ids(ACTION_SEP)

    spans = []
    cursor = 0
    while True:
        try:
            start = trace_ids.index(line_sep_id, cursor)
            stop = trace_ids.index(action_sep_id, start + 1)
        except ValueError:
            # Either no further LINE_SEP exists, or the last one is never
            # closed by an ACTION_SEP; in both cases scanning is finished.
            break
        spans.append((start, stop))
        cursor = stop + 1

    if not spans:
        return None
    return prompt_ids, trace_ids, spans
|
|
|
|
def build_example(code, input_str, tokenizer, *, max_seq_len, max_frames=-1):
    """Build one SFT pair ``(input_ids, labels)``; ``None`` means skip.

    ``input_ids`` is the prompt followed by the trace. ``labels`` has the
    same length, with every prompt position replaced by ``IGNORE_INDEX`` and
    the trace tokens kept as they are.
    """
    tokenized = _tokenize_trace(
        code, input_str, tokenizer, max_seq_len=max_seq_len, max_frames=max_frames
    )
    if tokenized is None:
        return None
    prompt_ids, trace_ids, _spans = tokenized
    input_ids = prompt_ids + trace_ids
    masked_prompt = [IGNORE_INDEX] * len(prompt_ids)
    labels = masked_prompt + trace_ids
    return input_ids, labels
|
|
|
|
def build_codi_example(code, input_str, tokenizer, *, max_seq_len, max_frames=-1):
    """Build one multi-span CODI example; ``None`` means skip.

    The result is a dict with the keys ``prompt_ids``, ``trace_ids`` and
    ``spans``, holding the three values returned by ``_tokenize_trace``.
    """
    tokenized = _tokenize_trace(
        code, input_str, tokenizer, max_seq_len=max_seq_len, max_frames=max_frames
    )
    if tokenized is None:
        return None
    return dict(zip(("prompt_ids", "trace_ids", "spans"), tokenized))
|
|
|
|
def _load_cache(cache_dir, n_samples):
    """Load precomputed tokenized examples (written by precompute.py).

    Args:
        cache_dir: Directory passed to ``datasets.load_from_disk``.
        n_samples: When positive, only the first ``n_samples`` rows are
            returned; zero or a negative value returns every row.

    Returns:
        A list with one entry per cached row, in stored order.
    """
    from itertools import islice

    from datasets import load_from_disk

    dataset = load_from_disk(cache_dir)
    if n_samples > 0:
        # Stop iterating once enough rows were taken instead of turning the
        # whole cache into a Python list and discarding most of it.
        return list(islice(dataset, n_samples))
    return list(dataset)
|
|
|
|
def build_codi_dataset(
    tokenizer, *, sources=("mbpp", "humaneval", "pyx"), n_samples: int = -1,
    max_seq_len: int = 4096, max_frames: int = -1, cache_dir: str | None = None
) -> list[dict]:
    """Multi-span CODI examples over ``sources``, or from a precomputed cache.

    Each returned dict has the keys ``prompt_ids``, ``trace_ids`` and
    ``spans`` (see ``build_codi_example``).

    Args:
        tokenizer: Tokenizer forwarded to ``build_codi_example``; unused when
            ``cache_dir`` is given.
        sources: Source names forwarded to ``rows_for_sources``.
        n_samples: When positive, only the first ``n_samples`` rows (or cached
            entries) are considered, before any filtering.
        max_seq_len: Upper bound on ``len(prompt_ids) + len(trace_ids)``.
        max_frames: Frame limit forwarded to ``build_codi_example``.
        cache_dir: When truthy, examples are read from this cache instead of
            being built, and only the length filter is applied.

    Returns:
        The examples that could be built; rows that are rejected or that fail
        are left out.
    """
    import logging

    if cache_dir:
        cached = _load_cache(cache_dir, n_samples)
        return [
            e for e in cached
            if len(e["prompt_ids"]) + len(e["trace_ids"]) <= max_seq_len
        ]

    rows = rows_for_sources(sources)
    if n_samples > 0:
        rows = rows[:n_samples]

    logger = logging.getLogger(__name__)
    examples = []
    failed = 0
    for row in rows:
        try:
            example = build_codi_example(
                row["code"], row["input"], tokenizer,
                max_seq_len=max_seq_len, max_frames=max_frames,
            )
        except Exception:
            # Building stays best-effort (one bad row must not abort the whole
            # dataset), but the failure is recorded instead of vanishing.
            failed += 1
            logger.debug(
                "build_codi_example failed for row %r", row.get("id"), exc_info=True
            )
            continue
        if example is not None:
            examples.append(example)

    if failed:
        logger.warning(
            "build_codi_dataset: skipped %d of %d rows because building raised",
            failed, len(rows),
        )
    return examples
|
|
|
|
def build_codi_single_dataset(
    tokenizer, *, sources=("mbpp", "humaneval", "pyx"), n_samples: int = -1,
    max_seq_len: int = 4096, max_frames: int = -1, cache_dir: str | None = None
) -> list[dict]:
    """Single-block CODI examples derived from the multi-span ones.

    Every example from ``build_codi_dataset`` is cut at the last
    ``<|return_sep|>`` token of its trace: ``reasoning_ids`` are the trace
    tokens before that separator and ``answer_ids`` run from the separator to
    the end of the trace. Traces without the separator, or with it as the
    very first token (empty reasoning), are dropped. Since the split is
    computed from the multi-span examples, the same cache can be reused.
    """
    return_sep_id = tokenizer.convert_tokens_to_ids("<|return_sep|>")
    multi_span = build_codi_dataset(
        tokenizer, sources=sources, n_samples=n_samples,
        max_seq_len=max_seq_len, max_frames=max_frames, cache_dir=cache_dir,
    )

    singles = []
    for example in multi_span:
        trace = example["trace_ids"]
        # Walk backwards to the last separator; 0 stands for both "absent"
        # and "at position 0", which are skipped alike.
        cut = next(
            (pos for pos in range(len(trace) - 1, -1, -1) if trace[pos] == return_sep_id),
            0,
        )
        if cut == 0:
            continue
        singles.append({
            "prompt_ids": example["prompt_ids"],
            "reasoning_ids": trace[:cut],
            "answer_ids": trace[cut:],
        })
    return singles
|
|
|
|
def rows_for_sources(sources):
    """Collect validated ``{id, code, input, output}`` rows from all sources.

    Every row of every source is returned, in source order; nothing is held
    back here (train vs. test is decided by which datasets are requested,
    e.g. cruxeval is kept for eval).

    Raises:
        ValueError: A row lacks one of ``id``, ``code``, ``input``, ``output``.
        TypeError: ``code``, ``input`` or ``output`` of a row is not a string.
    """
    from . import sources as _src

    required_keys = ("id", "code", "input", "output")
    text_keys = ("code", "input", "output")

    merged = []
    for name in sources:
        for index, raw in enumerate(_src.load_one(name)):
            missing = [key for key in required_keys if key not in raw]
            if missing:
                raise ValueError(f"{name} row {index} missing keys: {missing}")
            for key in text_keys:
                if not isinstance(raw[key], str):
                    raise TypeError(f"{name} row {index} must use string code/input/output")
            # Work on a copy so the loader's row is left untouched; ids are
            # normalised to strings.
            record = dict(raw)
            record["id"] = str(record["id"])
            merged.append(record)
    return merged
|
|
|
|
def build_dataset(
    tokenizer, *, sources=("mbpp", "humaneval", "pyx"), n_samples: int = -1,
    max_seq_len: int = 8192, max_frames: int = -1, cache_dir: str | None = None
) -> list[tuple[list[int], list[int]]]:
    """SFT ``(input_ids, labels)`` pairs over ``sources``, or from a cache.

    With ``cache_dir`` set, cached entries are read (at most ``n_samples``
    when that is positive) and those longer than ``max_seq_len`` are dropped.
    Otherwise the first ``n_samples`` rows of ``rows_for_sources(sources)``
    (all rows when not positive) go through ``build_example`` and rows it
    rejects are left out.
    """
    pairs = []

    if cache_dir:
        for entry in _load_cache(cache_dir, n_samples):
            if len(entry["input_ids"]) <= max_seq_len:
                pairs.append((entry["input_ids"], entry["labels"]))
        return pairs

    rows = rows_for_sources(sources)
    if n_samples > 0:
        rows = rows[:n_samples]

    for row in rows:
        example = build_example(
            row["code"], row["input"], tokenizer,
            max_seq_len=max_seq_len, max_frames=max_frames,
        )
        if example is not None:
            pairs.append(example)
    return pairs
|
|