codi-trace / code /data /trace_format.py
sirui6011's picture
add code/ loader snapshot
aedd6ab verified
Raw
History Blame Contribute Delete
7.91 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
"""
Shared CWM execution-trace representation and parsing.
CWM predicts an execution trace as a sequence of *frames*, each consisting of an
*observation* (the local-variable state) and an *action* (the executed source
line). The on-the-wire format (see PROMPTING_GUIDE.md and demos/cwmdbg.py) is:
<|call_sep|>$LOCALS<|action_sep|>$SOURCE<|frame_sep|>
<|line_sep|>$LOCALS<|action_sep|>$SOURCE<|frame_sep|>
<|return_sep|><|action_sep|>$SOURCE<|arg_sep|>$VALUE<|frame_sep|>
<|exception_sep|><|action_sep|>$SOURCE<|arg_sep|>$VALUE<|frame_sep|>
`$LOCALS` is a JSON object mapping variable names to *string* values; each value
is the JSON encoding of the underlying Python value (e.g. `"5"`, `"\"abc\""`,
`"[1, 2]"`). Locals use a diff-based representation: a variable whose value is
unchanged since the previous frame in the same scope is rendered as the
placeholder string `".."`. `$VALUE` (return/exception frames) is the JSON
encoding of the returned/raised value, stored as a JSON string.
This module is GPU-free and import-light so it can be unit-tested directly.
"""
from __future__ import annotations
import json
from dataclasses import dataclass, field
from enum import Enum
# Literal piece strings as they appear when a generation is decoded with
# cut_at_stop_tokens=False (matches CWMInstructTokenizer.*_ID constants).
# Event tokens: one of these opens each frame and selects its TraceEvent
# (see _EVENT_TOKENS).
CALL_SEP = "<|call_sep|>"
LINE_SEP = "<|line_sep|>"
RETURN_SEP = "<|return_sep|>"
EXCEPTION_SEP = "<|exception_sep|>"
# Separates the frame's observation ($LOCALS, empty for return/exception
# frames) from its action ($SOURCE).
ACTION_SEP = "<|action_sep|>"
# Separates $SOURCE from $VALUE in return/exception frames.
ARG_SEP = "<|arg_sep|>"
# Terminates every frame.
FRAME_SEP = "<|frame_sep|>"
# parse_generated_trace discards this token and everything after it.
END_OF_TEXT = "<|end_of_text|>"
# Value used in $LOCALS for a variable unchanged since the previous frame in
# the same scope.
DIFF_PLACEHOLDER = ".."
# Marker that normalize_source removes from an action line.
_START_MARKER = " # << START_OF_TRACE"
class TraceEvent(Enum):
    """Kind of an execution-trace frame.

    CALL and LINE frames carry a ``$LOCALS`` observation; RETURN and
    EXCEPTION frames carry a ``$VALUE`` argument instead.
    """

    # Frames with a locals observation.
    CALL = "call"
    LINE = "line"
    # Frames with a returned / raised value.
    RETURN = "return"
    EXCEPTION = "exception"
# Frame kind -> literal token that opens a frame of that kind.
_EVENT_TO_TOKEN: dict[TraceEvent, str] = {
    TraceEvent.CALL: CALL_SEP,
    TraceEvent.LINE: LINE_SEP,
    TraceEvent.RETURN: RETURN_SEP,
    TraceEvent.EXCEPTION: EXCEPTION_SEP,
}
# Inverse lookup used while parsing: opening token -> frame kind. The
# insertion order (call, line, return, exception) mirrors the mapping above.
_EVENT_TOKENS: dict[str, TraceEvent] = {
    token: event for event, token in _EVENT_TO_TOKEN.items()
}
@dataclass
class TraceFrame:
    """One frame of an execution trace: an observation plus an action.

    Frames are produced by the parser in this module even when the
    underlying text is damaged; ``malformed`` records that case.
    """

    # Frame kind (call / line / return / exception).
    event: TraceEvent
    # Action line, with the START_OF_TRACE marker and trailing newline removed.
    source: str
    # Raw `$LOCALS` text found between the event token and `<|action_sep|>`;
    # stays "" for return/exception frames.
    locals_str: str = ""
    # Parsed `$LOCALS` (name -> JSON-string value); None when the payload is
    # not a JSON object.
    locals: dict[str, str] | None = None
    # `$VALUE` of a return/exception frame, when one was parsed.
    arg: str | None = None
    # Set by the parser when `<|action_sep|>` is missing or `$LOCALS` does not
    # parse.
    malformed: bool = False
    # Token counts (filled when a tokenizer is available); used for the
    # "Avg State/Action Length (Token)" statistics rows of Table 9.
    state_tokens: int = 0
    action_tokens: int = 0

    @property
    def has_locals(self) -> bool:
        """True for the frame kinds that carry a `$LOCALS` observation."""
        kinds_with_state = (TraceEvent.CALL, TraceEvent.LINE)
        return self.event in kinds_with_state
def normalize_source(source: str) -> str:
    """Strip the trace start marker and trailing whitespace from a source line.

    The marker is removed only when it appears verbatim as a suffix.
    ``str.rstrip(chars)`` is deliberately not used for this: it treats its
    argument as a *set of characters*, so it would also eat legitimate
    trailing characters of the source (e.g. ``"return A"`` -> ``"return"``).

    Returns the source line without trailing newline/whitespace and without
    the START_OF_TRACE marker.
    """
    # Drop the newline (and any stray whitespace) so the marker, if present,
    # is the true suffix of the string.
    text = source.rstrip()
    if text.endswith(_START_MARKER):
        text = text[: -len(_START_MARKER)]
    # The marker may have been preceded by extra padding.
    return text.rstrip()
def parse_locals(locals_str: str) -> dict[str, str] | None:
    """Parse a `$LOCALS` payload into a dict, or None if it is not a JSON object.

    An empty (or whitespace-only) payload yields an empty dict. Any payload
    that cannot be decoded -- invalid JSON, or JSON nested so deeply that the
    decoder hits the recursion limit -- yields None, as does valid JSON whose
    top-level value is not an object.
    """
    payload = locals_str.strip()
    if not payload:
        return {}
    try:
        obj = json.loads(payload)
    except (json.JSONDecodeError, RecursionError):
        # RecursionError: pathological generations such as "[[[[..." must be
        # reported as unparseable rather than abort the whole parse.
        return None
    if not isinstance(obj, dict):
        return None
    # Values are always JSON strings; coerce defensively.
    return {
        str(name): value if isinstance(value, str) else json.dumps(value)
        for name, value in obj.items()
    }
def parse_generated_trace(generation: str) -> tuple[list[TraceFrame], bool]:
    """Split a full-trace generation into frames and judge its format.

    Returns ``(frames, well_formed)``. The text is cut at the first
    end-of-text token and split on ``<|frame_sep|>``; whitespace-only
    segments are skipped and every other segment is handed to
    ``_parse_segment``. ``well_formed`` (the "Valid Trace Format" metric) is
    True only when at least one frame was produced, every parsed segment was
    reported as valid, and nothing but whitespace follows the last
    ``<|frame_sep|>``. Frames from malformed segments are still returned
    whenever an event token could be identified, so that other metrics can
    use whatever parsed.
    """
    # Anything from end-of-text onwards is irrelevant.
    body, _, _ = generation.partition(END_OF_TEXT)

    # str.split always yields at least one piece; the last piece is whatever
    # follows the final frame_sep and must be blank in a clean trace.
    *frame_segments, tail = body.split(FRAME_SEP)
    well_formed = tail.strip() == ""

    frames: list[TraceFrame] = []
    for segment in frame_segments:
        if not segment.strip():
            # Stray empty segment: ignored without affecting validity.
            continue
        frame, segment_ok = _parse_segment(segment)
        if frame is None:
            well_formed = False
        else:
            frames.append(frame)
            well_formed = well_formed and segment_ok

    # A generation with no frames at all is never well formed.
    return frames, bool(frames) and well_formed
def _parse_segment(seg: str) -> tuple[TraceFrame | None, bool]:
    """Parse one frame segment (the text that preceded a ``<|frame_sep|>``).

    Returns ``(frame, ok)``:

    * ``frame`` is None when the segment contains no event token.
    * ``ok`` is False when ``<|action_sep|>`` is missing, or when a
      return/exception frame has no ``<|arg_sep|>``; otherwise True.

    Any text in front of the event token is discarded.
    """
    # Identify the event token that occurs *earliest* in the segment. Going by
    # position (rather than by the iteration order of _EVENT_TOKENS) matters
    # when a segment holds more than one event token, e.g. because a
    # frame_sep is missing: the frame must be read from its real start.
    event: TraceEvent | None = None
    first_idx = len(seg)
    payload_start = 0
    for tok, evt in _EVENT_TOKENS.items():
        idx = seg.find(tok)
        if idx != -1 and idx < first_idx:
            first_idx = idx
            event = evt
            payload_start = idx + len(tok)
    if event is None:
        return None, False
    seg = seg[payload_start:]

    # Every frame kind needs an action separator.
    if ACTION_SEP not in seg:
        return TraceFrame(event=event, source="", malformed=True), False

    if event in (TraceEvent.CALL, TraceEvent.LINE):
        locals_str, source = seg.split(ACTION_SEP, 1)
        parsed = parse_locals(locals_str)
        frame = TraceFrame(
            event=event,
            source=normalize_source(source),
            locals_str=locals_str.strip(),
            locals=parsed,
            # Unparseable locals flag the frame but do not affect `ok`.
            malformed=parsed is None,
        )
        return frame, True

    # RETURN / EXCEPTION: whatever precedes action_sep is ignored.
    seg = seg.split(ACTION_SEP, 1)[1]
    if ARG_SEP in seg:
        source, raw_arg = seg.split(ARG_SEP, 1)
        arg = _parse_arg(raw_arg)
        ok = True
    else:
        source, arg = seg, None
        ok = False
    return TraceFrame(event=event, source=normalize_source(source), arg=arg), ok
def render_frames_to_generation(frames: list[TraceFrame]) -> str:
    """Serialize frames into the on-the-wire generation string.

    This is the inverse of ``parse_generated_trace`` for well-formed frames.
    Tests rely on it (a rendered ground-truth trace must round-trip to a
    perfect score), and it is handy for materializing a reference trace
    string for inspection. The result always ends with the end-of-text token.
    """
    pieces: list[str] = []
    for frame in frames:
        pieces.append(_EVENT_TO_TOKEN[frame.event])
        if frame.has_locals:
            # Frames whose locals failed to parse are rendered as an empty object.
            observation = {} if frame.locals is None else frame.locals
            pieces.append(json.dumps(observation))
        pieces += [ACTION_SEP, frame.source]
        if frame.event in (TraceEvent.RETURN, TraceEvent.EXCEPTION):
            pieces += [ARG_SEP, json.dumps(frame.arg)]
        pieces.append(FRAME_SEP)
    pieces.append(END_OF_TEXT)
    return "".join(pieces)
def _parse_arg(arg_str: str) -> str | None:
arg_str = arg_str.strip()
if arg_str == "":
return None
try:
# The frame stores json.dumps(value_string); unwrap one level so `arg`
# is the source-literal value string (e.g. '"x9ja"' or '17').
loaded = json.loads(arg_str)
return loaded if isinstance(loaded, str) else arg_str
except json.JSONDecodeError:
return arg_str