File size: 7,908 Bytes
aedd6ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
# Copyright (c) Meta Platforms, Inc. and affiliates.

"""
Shared CWM execution-trace representation and parsing.

CWM predicts an execution trace as a sequence of *frames*, each consisting of an
*observation* (the local-variable state) and an *action* (the executed source
line). The on-the-wire format (see PROMPTING_GUIDE.md and demos/cwmdbg.py) is:

    <|call_sep|>$LOCALS<|action_sep|>$SOURCE<|frame_sep|>
    <|line_sep|>$LOCALS<|action_sep|>$SOURCE<|frame_sep|>
    <|return_sep|><|action_sep|>$SOURCE<|arg_sep|>$VALUE<|frame_sep|>
    <|exception_sep|><|action_sep|>$SOURCE<|arg_sep|>$VALUE<|frame_sep|>

`$LOCALS` is a JSON object mapping variable names to *string* values; each value
is the JSON encoding of the underlying Python value (e.g. `"5"`, `"\"abc\""`,
`"[1, 2]"`). Locals use a diff-based representation: a variable whose value is
unchanged since the previous frame in the same scope is rendered as the
placeholder string `".."`. `$VALUE` (return/exception frames) is the JSON
encoding of the returned/raised value, stored as a JSON string.

This module is GPU-free and import-light so it can be unit-tested directly.
"""

from __future__ import annotations

import json
from dataclasses import dataclass, field
from enum import Enum

# Literal piece strings as they appear when a generation is decoded with
# cut_at_stop_tokens=False (matches CWMInstructTokenizer.*_ID constants).
# The four *event* tokens each open a frame and identify its kind.
CALL_SEP = "<|call_sep|>"
LINE_SEP = "<|line_sep|>"
RETURN_SEP = "<|return_sep|>"
EXCEPTION_SEP = "<|exception_sep|>"
# Separates the observation ($LOCALS) from the action ($SOURCE) inside a frame.
ACTION_SEP = "<|action_sep|>"
# Separates $SOURCE from $VALUE in return/exception frames.
ARG_SEP = "<|arg_sep|>"
# Terminates every frame.
FRAME_SEP = "<|frame_sep|>"
# End of the generation; parse_generated_trace() ignores everything after it.
END_OF_TEXT = "<|end_of_text|>"

# Value used in $LOCALS for a variable that is unchanged since the previous
# frame in the same scope (diff-based representation).
DIFF_PLACEHOLDER = ".."
# Trace-start marker that normalize_source() strips from the end of an
# action's source line.
_START_MARKER = "  # << START_OF_TRACE"


class TraceEvent(Enum):
    """Kind of an execution-trace frame.

    Each member is paired with exactly one event separator token through the
    module-level `_EVENT_TOKENS` / `_EVENT_TO_TOKEN` tables.
    """

    CALL = "call"
    LINE = "line"
    RETURN = "return"
    EXCEPTION = "exception"


# Event separator token -> frame kind; used when parsing a segment to decide
# which kind of frame it holds.
_EVENT_TOKENS: dict[str, TraceEvent] = {
    CALL_SEP: TraceEvent.CALL,
    LINE_SEP: TraceEvent.LINE,
    RETURN_SEP: TraceEvent.RETURN,
    EXCEPTION_SEP: TraceEvent.EXCEPTION,
}
# Inverse table (frame kind -> event separator token); used when rendering
# frames back to a generation string.
_EVENT_TO_TOKEN: dict[TraceEvent, str] = {v: k for k, v in _EVENT_TOKENS.items()}


@dataclass
class TraceFrame:
    """One frame of an execution trace: an observation plus an action.

    Call/line frames carry a locals observation; return/exception frames
    carry an argument value instead. Field semantics are documented inline.
    """

    # Kind of frame (call, line, return or exception).
    event: TraceEvent
    # The action line, with the START_OF_TRACE marker and trailing newline
    # stripped.
    source: str
    # Raw `$LOCALS` text exactly as it appears between the event token and
    # `<|action_sep|>`; empty string for return/exception frames.
    locals_str: str = ""
    # Parsed form of `locals_str` (name -> JSON-string value), or None if it
    # failed to parse as a JSON object (also the default when never set).
    locals: dict[str, str] | None = None
    # Value attached to a return/exception frame; None when absent.
    arg: str | None = None
    # Set by the parser when the frame lacked `<|action_sep|>` or its locals
    # payload could not be parsed.
    malformed: bool = False
    # Token counts (filled when a tokenizer is available); used for the
    # "Avg State/Action Length (Token)" statistics rows of Table 9.
    state_tokens: int = 0
    action_tokens: int = 0

    @property
    def has_locals(self) -> bool:
        """True for call and line frames, the two kinds that carry `$LOCALS`."""
        kind = self.event
        return kind is TraceEvent.CALL or kind is TraceEvent.LINE


def normalize_source(source: str) -> str:
    """Strip the trace start marker and trailing whitespace from a source line.

    The marker is removed only when it is an actual *suffix* of the line.
    (`str.rstrip(chars)` treats its argument as a set of characters, so using
    it with the marker would also eat legitimate trailing characters such as
    the `A` in `return A` or the `E` in `x = CACHE`.)

    Args:
        source: The raw action text of a frame; may end with newlines and/or
            the START_OF_TRACE marker comment.

    Returns:
        The source line without the marker and without trailing whitespace.
        Leading whitespace (indentation) is preserved.
    """
    text = source.rstrip()
    # Match on the marker without its surrounding padding so the check does
    # not depend on how many spaces precede the comment; the padding itself is
    # removed by the final rstrip().
    marker = _START_MARKER.strip()
    if text.endswith(marker):
        text = text[: -len(marker)].rstrip()
    return text


def parse_locals(locals_str: str) -> dict[str, str] | None:
    """Parse a `$LOCALS` payload into a dict, or None if it is not a JSON object."""
    locals_str = locals_str.strip()
    if locals_str == "":
        return {}
    try:
        obj = json.loads(locals_str)
    except json.JSONDecodeError:
        return None
    if not isinstance(obj, dict):
        return None
    # Values are always JSON strings; coerce defensively.
    return {str(k): v if isinstance(v, str) else json.dumps(v) for k, v in obj.items()}


def parse_generated_trace(generation: str) -> tuple[list[TraceFrame], bool]:
    """Split a full-trace generation into frames and judge its format.

    Args:
        generation: Decoded model output, possibly containing text after
            `<|end_of_text|>` (which is ignored).

    Returns:
        `(frames, well_formed)`. `well_formed` is True only if at least one
        frame was parsed, every non-blank segment yielded a frame that its
        parser reported as structurally complete (event token, action
        separator, and an argument separator for return/exception frames), and
        nothing but whitespace follows the final `<|frame_sep|>`. It backs the
        "Valid Trace Format" metric. Frames that parsed are returned even when
        the trace as a whole is not well formed, so other metrics can still be
        computed over them.
    """
    # Only the text in front of the first end-of-text token matters.
    body_text = generation.partition(END_OF_TEXT)[0]

    # Whatever follows the last frame separator must be blank in a clean
    # trace (the model emits frame_sep and then end_of_text).
    *chunks, tail = body_text.split(FRAME_SEP)
    well_formed = not tail.strip()

    frames: list[TraceFrame] = []
    for chunk in chunks:
        if not chunk.strip():
            # Blank segment: tolerated and skipped.
            continue
        frame, complete = _parse_segment(chunk)
        if frame is None:
            well_formed = False
        else:
            well_formed = well_formed and complete
            frames.append(frame)

    # A generation without a single frame is never a valid trace.
    return frames, well_formed and bool(frames)


def _parse_segment(seg: str) -> tuple[TraceFrame | None, bool]:
    """Parse one `<|frame_sep|>`-delimited segment into a frame.

    The frame kind is taken from the event token that occurs *earliest* in
    the segment; any text before that token is discarded.

    Args:
        seg: Text of a single segment (without the trailing frame separator).

    Returns:
        `(frame, ok)` where:

        * `(None, False)` if the segment contains no event token;
        * a frame with `malformed=True` and `ok=False` if `<|action_sep|>` is
          missing;
        * for call/line frames, a frame whose `malformed` flag is set when the
          locals payload does not parse, with `ok=True`;
        * for return/exception frames, `ok=False` (and `arg=None`) when
          `<|arg_sep|>` is missing, otherwise `ok=True`.
    """
    # Locate the event token with the smallest offset. Iterating the table and
    # stopping at the first *table entry* that matches would depend on dict
    # order rather than on position, mis-typing segments that happen to contain
    # more than one event token (e.g. after a dropped frame separator).
    first_idx = -1
    first_tok: str | None = None
    for tok in _EVENT_TOKENS:
        idx = seg.find(tok)
        if idx != -1 and (first_tok is None or idx < first_idx):
            first_idx, first_tok = idx, tok
    if first_tok is None:
        return None, False

    event = _EVENT_TOKENS[first_tok]
    seg = seg[first_idx + len(first_tok):]

    # Every frame kind needs an action separator.
    if ACTION_SEP not in seg:
        return TraceFrame(event=event, source="", malformed=True), False

    if event in (TraceEvent.CALL, TraceEvent.LINE):
        locals_str, source = seg.split(ACTION_SEP, 1)
        parsed = parse_locals(locals_str)
        frame = TraceFrame(
            event=event,
            source=normalize_source(source),
            locals_str=locals_str.strip(),
            locals=parsed,
            malformed=parsed is None,
        )
        return frame, True

    # RETURN / EXCEPTION: anything before the action separator is ignored.
    rest = seg.split(ACTION_SEP, 1)[1]
    if ARG_SEP in rest:
        source, raw_arg = rest.split(ARG_SEP, 1)
        arg = _parse_arg(raw_arg)
        ok = True
    else:
        # Missing argument separator: keep the source, flag the format.
        source, arg = rest, None
        ok = False
    return (
        TraceFrame(event=event, source=normalize_source(source), arg=arg),
        ok,
    )


def render_frames_to_generation(frames: list[TraceFrame]) -> str:
    """Serialize frames to the on-the-wire generation string.

    This is the inverse of ``parse_generated_trace`` for well-formed frames.
    Tests rely on it (a ground-truth trace rendered this way must round-trip
    to a perfect score), and it is also handy for materializing a reference
    trace string for inspection.

    Args:
        frames: Frames to render, in order.

    Returns:
        The concatenated frames followed by the end-of-text token.
    """
    rendered: list[str] = []
    for frame in frames:
        event_token = _EVENT_TO_TOKEN[frame.event]

        # Observation: only call/line frames carry locals; missing locals are
        # written as an empty JSON object.
        state = ""
        if frame.has_locals:
            state = json.dumps({} if frame.locals is None else frame.locals)

        # Return/exception frames append the JSON-encoded argument.
        suffix = ""
        if frame.event in (TraceEvent.RETURN, TraceEvent.EXCEPTION):
            suffix = ARG_SEP + json.dumps(frame.arg)

        rendered.append(
            "".join((event_token, state, ACTION_SEP, frame.source, suffix, FRAME_SEP))
        )
    return "".join(rendered) + END_OF_TEXT


def _parse_arg(arg_str: str) -> str | None:
    arg_str = arg_str.strip()
    if arg_str == "":
        return None
    try:
        # The frame stores json.dumps(value_string); unwrap one level so `arg`
        # is the source-literal value string (e.g. '"x9ja"' or '17').
        loaded = json.loads(arg_str)
        return loaded if isinstance(loaded, str) else arg_str
    except json.JSONDecodeError:
        return arg_str