# Copyright (c) Meta Platforms, Inc. and affiliates.
"""
Shared CWM execution-trace representation and parsing.
CWM predicts an execution trace as a sequence of *frames*, each consisting of an
*observation* (the local-variable state) and an *action* (the executed source
line). The on-the-wire format (see PROMPTING_GUIDE.md and demos/cwmdbg.py) is:
<|call_sep|>$LOCALS<|action_sep|>$SOURCE<|frame_sep|>
<|line_sep|>$LOCALS<|action_sep|>$SOURCE<|frame_sep|>
<|return_sep|><|action_sep|>$SOURCE<|arg_sep|>$VALUE<|frame_sep|>
<|exception_sep|><|action_sep|>$SOURCE<|arg_sep|>$VALUE<|frame_sep|>
`$LOCALS` is a JSON object mapping variable names to *string* values; each value
is the JSON encoding of the underlying Python value (e.g. `"5"`, `"\"abc\""`,
`"[1, 2]"`). Locals use a diff-based representation: a variable whose value is
unchanged since the previous frame in the same scope is rendered as the
placeholder string `".."`. `$VALUE` (return/exception frames) is the JSON
encoding of the returned/raised value, stored as a JSON string.
This module is GPU-free and import-light so it can be unit-tested directly.
"""
from __future__ import annotations
import json
from dataclasses import dataclass, field
from enum import Enum
# Literal piece strings as they appear when a generation is decoded with
# cut_at_stop_tokens=False (matches CWMInstructTokenizer.*_ID constants).
# The first four open a frame and identify its event kind.
CALL_SEP = "<|call_sep|>"
LINE_SEP = "<|line_sep|>"
RETURN_SEP = "<|return_sep|>"
EXCEPTION_SEP = "<|exception_sep|>"
# Separates the observation ($LOCALS) from the action ($SOURCE) inside a frame.
ACTION_SEP = "<|action_sep|>"
# Separates $SOURCE from $VALUE in return/exception frames.
ARG_SEP = "<|arg_sep|>"
# Terminates every frame.
FRAME_SEP = "<|frame_sep|>"
# Everything after this piece is ignored by the parser.
END_OF_TEXT = "<|end_of_text|>"
# Value used in $LOCALS for a variable that is unchanged since the previous
# frame in the same scope (diff-based representation).
DIFF_PLACEHOLDER = ".."
# Trace start marker that normalize_source() removes from a source line.
_START_MARKER = " # << START_OF_TRACE"
class TraceEvent(Enum):
    """Kind of a trace frame; one member per frame-opening separator token."""

    # CALL and LINE frames carry a $LOCALS observation.
    CALL = "call"
    LINE = "line"
    # RETURN and EXCEPTION frames carry a $VALUE argument after the source.
    RETURN = "return"
    EXCEPTION = "exception"
# Frame-opening separator piece -> event kind. Insertion order is the order
# in which the parser probes a segment for these pieces.
_EVENT_TOKENS: dict[str, TraceEvent] = {
    CALL_SEP: TraceEvent.CALL,
    LINE_SEP: TraceEvent.LINE,
    RETURN_SEP: TraceEvent.RETURN,
    EXCEPTION_SEP: TraceEvent.EXCEPTION,
}

# Reverse lookup (event kind -> separator piece), used when rendering frames.
_EVENT_TO_TOKEN: dict[TraceEvent, str] = dict(
    zip(_EVENT_TOKENS.values(), _EVENT_TOKENS.keys())
)
@dataclass
class TraceFrame:
    """One frame of an execution trace: an observation plus an action.

    Attributes:
        event: Kind of frame (call / line / return / exception).
        source: The action line, with the START_OF_TRACE marker and the
            trailing newline already stripped.
        locals_str: Raw `$LOCALS` text exactly as found between the event
            token and `<|action_sep|>`; empty for return/exception frames.
        locals: Parsed `$LOCALS` (name -> JSON-string value), or None when
            the payload did not parse as a JSON object.
        arg: The `$VALUE` of a return/exception frame, if any.
        malformed: Set by the parser when the frame could not be fully parsed.
        state_tokens: Token count of the observation.
        action_tokens: Token count of the action.
    """

    event: TraceEvent
    source: str
    locals_str: str = ""
    locals: dict[str, str] | None = None
    arg: str | None = None
    malformed: bool = False
    # The two counters below are filled in only when a tokenizer is available;
    # they feed the "Avg State/Action Length (Token)" statistics rows of
    # Table 9.
    state_tokens: int = 0
    action_tokens: int = 0

    @property
    def has_locals(self) -> bool:
        """Whether this frame kind carries a `$LOCALS` observation."""
        kinds_with_state = (TraceEvent.CALL, TraceEvent.LINE)
        return self.event in kinds_with_state
def normalize_source(source: str) -> str:
    """Strip the trace start marker and trailing whitespace from a source line.

    Args:
        source: The raw action text of a frame.

    Returns:
        `source` without trailing whitespace/newlines and, if the line ends
        with the START_OF_TRACE marker, without that marker (and without the
        whitespace that preceded it).
    """
    text = source.rstrip()
    # The marker must be removed as an exact suffix. `str.rstrip(chars)`
    # treats its argument as a *set of characters*, which would also eat
    # legitimate trailing characters of the source line (e.g. the "A" in
    # "return A").
    marker = _START_MARKER.strip()
    if marker and text.endswith(marker):
        text = text[: -len(marker)].rstrip()
    return text
def parse_locals(locals_str: str) -> dict[str, str] | None:
    """Decode a `$LOCALS` payload.

    Args:
        locals_str: Raw text of the observation; surrounding whitespace is
            ignored.

    Returns:
        An empty dict for a blank payload, a name -> string-value dict for a
        JSON object, or None when the payload is not valid JSON or is valid
        JSON but not an object.
    """
    payload = locals_str.strip()
    if not payload:
        return {}

    try:
        decoded = json.loads(payload)
    except json.JSONDecodeError:
        return None

    if not isinstance(decoded, dict):
        return None

    # Values are expected to be JSON strings already; anything else is
    # re-encoded so every value in the result is a string.
    result: dict[str, str] = {}
    for name, value in decoded.items():
        result[str(name)] = value if isinstance(value, str) else json.dumps(value)
    return result
def parse_generated_trace(generation: str) -> tuple[list[TraceFrame], bool]:
    """Split a full-trace generation into frames and judge its format.

    Text from the first `<|end_of_text|>` onward is discarded, and the rest is
    cut at every `<|frame_sep|>`. Whitespace-only pieces are skipped silently.

    Returns:
        `(frames, well_formed)`. `well_formed` (the "Valid Trace Format"
        metric) is False when any of the following holds:

        * non-whitespace text follows the last `<|frame_sep|>`;
        * a piece yields no frame at all, or a frame flagged as not ok by
          `_parse_segment`;
        * no frame was produced.

        Frames that were produced are returned even when the trace is not
        well formed, so other metrics can still be computed on them.
    """
    body = generation.partition(END_OF_TEXT)[0]

    # A clean trace ends with frame_sep, leaving an empty tail after the split.
    *chunks, tail = body.split(FRAME_SEP)
    format_ok = not tail.strip()

    parsed: list[TraceFrame] = []
    for chunk in chunks:
        if not chunk.strip():
            # Stray empty piece; neither a frame nor an error.
            continue
        frame, frame_ok = _parse_segment(chunk)
        if frame is None:
            format_ok = False
            continue
        if not frame_ok:
            format_ok = False
        parsed.append(frame)

    return parsed, format_ok and bool(parsed)
def _parse_segment(seg: str) -> tuple[TraceFrame | None, bool]:
    """Parse the text of one frame (without its trailing `<|frame_sep|>`).

    Args:
        seg: One piece of the generation, as obtained by splitting on
            `<|frame_sep|>`.

    Returns:
        `(frame, ok)`. `frame` is None when the segment contains no event
        token. `ok` is False when there is no event token, no
        `<|action_sep|>`, or (return/exception frames only) no `<|arg_sep|>`.
        A call/line frame whose `$LOCALS` fails to parse is returned with
        `malformed=True` but still counts as ok.
    """
    # Pick the event token that occurs *earliest* in the segment. Probing the
    # tokens in a fixed order and taking the first one found anywhere would
    # misclassify a frame whose payload happens to contain the text of another
    # event token.
    first_idx = -1
    first_tok = ""
    for tok in _EVENT_TOKENS:
        idx = seg.find(tok)
        if idx != -1 and (first_idx == -1 or idx < first_idx):
            first_idx, first_tok = idx, tok
    if first_idx == -1:
        return None, False

    event = _EVENT_TOKENS[first_tok]
    # Anything before the event token is dropped.
    rest = seg[first_idx + len(first_tok):]

    # Every frame kind needs an action separator.
    if ACTION_SEP not in rest:
        return TraceFrame(event=event, source="", malformed=True), False
    state, action = rest.split(ACTION_SEP, 1)

    if event in (TraceEvent.CALL, TraceEvent.LINE):
        parsed = parse_locals(state)
        frame = TraceFrame(
            event=event,
            source=normalize_source(action),
            locals_str=state.strip(),
            locals=parsed,
            malformed=parsed is None,
        )
        return frame, True

    # RETURN / EXCEPTION: whatever precedes action_sep is ignored, and the
    # action is expected to be "$SOURCE<|arg_sep|>$VALUE".
    if ARG_SEP not in action:
        return TraceFrame(event=event, source=normalize_source(action), arg=None), False
    source, raw_arg = action.split(ARG_SEP, 1)
    return (
        TraceFrame(event=event, source=normalize_source(source), arg=_parse_arg(raw_arg)),
        True,
    )
def render_frames_to_generation(frames: list[TraceFrame]) -> str:
    """Serialize frames into the on-the-wire generation string.

    This is the inverse of ``parse_generated_trace`` for well-formed frames.
    Tests use it (a ground-truth trace rendered this way must round-trip to a
    perfect score), and it can materialize a reference trace for inspection.

    Each frame is rendered as: event token, JSON-encoded locals (call/line
    frames only; ``{}`` when ``locals`` is None), ``<|action_sep|>``, the
    source, then for return/exception frames ``<|arg_sep|>`` and the
    JSON-encoded ``arg``, and finally ``<|frame_sep|>``. The result ends with
    ``<|end_of_text|>``.
    """
    rendered: list[str] = []
    for frame in frames:
        parts = [_EVENT_TO_TOKEN[frame.event]]
        if frame.has_locals:
            state = frame.locals if frame.locals is not None else {}
            parts.append(json.dumps(state))
        parts += [ACTION_SEP, frame.source]
        if frame.event in (TraceEvent.RETURN, TraceEvent.EXCEPTION):
            parts += [ARG_SEP, json.dumps(frame.arg)]
        parts.append(FRAME_SEP)
        rendered.append("".join(parts))
    return "".join(rendered) + END_OF_TEXT
def _parse_arg(arg_str: str) -> str | None:
    """Decode the `$VALUE` payload of a return/exception frame.

    Returns:
        None for a blank payload. If the payload is a JSON string, that
        string with one level of JSON encoding removed (e.g. '"x9ja"' or
        '17'). Otherwise (non-string JSON or invalid JSON) the stripped
        payload text unchanged.
    """
    text = arg_str.strip()
    if not text:
        return None
    try:
        # The frame stores json.dumps(value_string); undo exactly one level.
        decoded = json.loads(text)
    except json.JSONDecodeError:
        return text
    if isinstance(decoded, str):
        return decoded
    return text
|