| |
|
|
| """ |
| Shared CWM execution-trace representation and parsing. |
| |
| CWM predicts an execution trace as a sequence of *frames*, each consisting of an |
| *observation* (the local-variable state) and an *action* (the executed source |
| line). The on-the-wire format (see PROMPTING_GUIDE.md and demos/cwmdbg.py) is: |
| |
| <|call_sep|>$LOCALS<|action_sep|>$SOURCE<|frame_sep|> |
| <|line_sep|>$LOCALS<|action_sep|>$SOURCE<|frame_sep|> |
| <|return_sep|><|action_sep|>$SOURCE<|arg_sep|>$VALUE<|frame_sep|> |
| <|exception_sep|><|action_sep|>$SOURCE<|arg_sep|>$VALUE<|frame_sep|> |
| |
| `$LOCALS` is a JSON object mapping variable names to *string* values; each value |
| is the JSON encoding of the underlying Python value (e.g. `"5"`, `"\"abc\""`, |
| `"[1, 2]"`). Locals use a diff-based representation: a variable whose value is |
| unchanged since the previous frame in the same scope is rendered as the |
| placeholder string `".."`. `$VALUE` (return/exception frames) is the JSON |
| encoding of the returned/raised value, stored as a JSON string. |
| |
| This module is GPU-free and import-light so it can be unit-tested directly. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| from dataclasses import dataclass, field |
| from enum import Enum |
|
|
| |
| |
| CALL_SEP = "<|call_sep|>" |
| LINE_SEP = "<|line_sep|>" |
| RETURN_SEP = "<|return_sep|>" |
| EXCEPTION_SEP = "<|exception_sep|>" |
| ACTION_SEP = "<|action_sep|>" |
| ARG_SEP = "<|arg_sep|>" |
| FRAME_SEP = "<|frame_sep|>" |
| END_OF_TEXT = "<|end_of_text|>" |
|
|
| DIFF_PLACEHOLDER = ".." |
| _START_MARKER = " # << START_OF_TRACE" |
|
|
|
|
| class TraceEvent(Enum): |
| CALL = "call" |
| LINE = "line" |
| RETURN = "return" |
| EXCEPTION = "exception" |
|
|
|
|
| _EVENT_TOKENS: dict[str, TraceEvent] = { |
| CALL_SEP: TraceEvent.CALL, |
| LINE_SEP: TraceEvent.LINE, |
| RETURN_SEP: TraceEvent.RETURN, |
| EXCEPTION_SEP: TraceEvent.EXCEPTION, |
| } |
| _EVENT_TO_TOKEN: dict[TraceEvent, str] = {v: k for k, v in _EVENT_TOKENS.items()} |
|
|
|
|
| @dataclass |
| class TraceFrame: |
| """A single execution-trace frame. |
| |
| `locals_str` is the raw `$LOCALS` text exactly as it appears between the |
| event token and `<|action_sep|>` (empty string for return/exception |
| frames). `locals` is its parsed form (a dict of name -> JSON-string-value), |
| or None if it failed to parse as a JSON object. `source` is the action line |
| with the START_OF_TRACE marker and trailing newline stripped. |
| """ |
|
|
| event: TraceEvent |
| source: str |
| locals_str: str = "" |
| locals: dict[str, str] | None = None |
| arg: str | None = None |
| malformed: bool = False |
| |
| |
| state_tokens: int = 0 |
| action_tokens: int = 0 |
|
|
| @property |
| def has_locals(self) -> bool: |
| return self.event in (TraceEvent.CALL, TraceEvent.LINE) |
|
|
|
|
| def normalize_source(source: str) -> str: |
| """Strip the trace start marker and trailing newline from a source line.""" |
| return source.rstrip("\n").rstrip(_START_MARKER).rstrip() |
|
|
|
|
| def parse_locals(locals_str: str) -> dict[str, str] | None: |
| """Parse a `$LOCALS` payload into a dict, or None if it is not a JSON object.""" |
| locals_str = locals_str.strip() |
| if locals_str == "": |
| return {} |
| try: |
| obj = json.loads(locals_str) |
| except json.JSONDecodeError: |
| return None |
| if not isinstance(obj, dict): |
| return None |
| |
| return {str(k): v if isinstance(v, str) else json.dumps(v) for k, v in obj.items()} |
|
|
|
|
| def parse_generated_trace(generation: str) -> tuple[list[TraceFrame], bool]: |
| """Parse a full-trace generation string into frames. |
| |
| Returns (frames, well_formed). `well_formed` is True when every frame had a |
| leading event token and an `<|action_sep|>` (and an `<|arg_sep|>` for |
| return/exception frames) and the generation contained no leftover garbage |
| between the last frame and end-of-text. This drives the "Valid Trace Format" |
| metric. Individual frames are still returned even when malformed so that the |
| other metrics can be computed over whatever parsed cleanly. |
| """ |
| |
| if END_OF_TEXT in generation: |
| generation = generation.split(END_OF_TEXT, 1)[0] |
|
|
| frames: list[TraceFrame] = [] |
| well_formed = True |
| segments = generation.split(FRAME_SEP) |
| |
| |
| trailing = segments.pop() if segments else "" |
| if trailing.strip() not in ("",): |
| well_formed = False |
|
|
| for seg in segments: |
| if seg.strip() == "": |
| |
| continue |
| frame, ok = _parse_segment(seg) |
| if frame is None: |
| well_formed = False |
| continue |
| well_formed = well_formed and ok |
| frames.append(frame) |
|
|
| if not frames: |
| well_formed = False |
|
|
| return frames, well_formed |
|
|
|
|
| def _parse_segment(seg: str) -> tuple[TraceFrame | None, bool]: |
| |
| event: TraceEvent | None = None |
| for tok, evt in _EVENT_TOKENS.items(): |
| idx = seg.find(tok) |
| if idx != -1: |
| event = evt |
| seg = seg[idx + len(tok):] |
| break |
| if event is None: |
| return None, False |
|
|
| ok = True |
| if event in (TraceEvent.CALL, TraceEvent.LINE): |
| if ACTION_SEP not in seg: |
| return ( |
| TraceFrame(event=event, source="", malformed=True), |
| False, |
| ) |
| locals_str, source = seg.split(ACTION_SEP, 1) |
| parsed = parse_locals(locals_str) |
| return ( |
| TraceFrame( |
| event=event, |
| source=normalize_source(source), |
| locals_str=locals_str.strip(), |
| locals=parsed, |
| malformed=parsed is None, |
| ), |
| ok, |
| ) |
|
|
| |
| if ACTION_SEP not in seg: |
| return TraceFrame(event=event, source="", malformed=True), False |
| seg = seg.split(ACTION_SEP, 1)[1] |
| if ARG_SEP in seg: |
| source, arg = seg.split(ARG_SEP, 1) |
| arg = _parse_arg(arg) |
| else: |
| source, arg = seg, None |
| ok = False |
| return ( |
| TraceFrame(event=event, source=normalize_source(source), arg=arg), |
| ok, |
| ) |
|
|
|
|
| def render_frames_to_generation(frames: list[TraceFrame]) -> str: |
| """Render frames back to the on-the-wire generation string. |
| |
| Inverse of ``parse_generated_trace`` for well-formed frames. Used by tests |
| (a ground-truth trace rendered this way must round-trip to a perfect score) |
| and to materialize a reference trace string for inspection. |
| """ |
| out: list[str] = [] |
| for f in frames: |
| out.append(_EVENT_TO_TOKEN[f.event]) |
| if f.has_locals: |
| out.append(json.dumps(f.locals if f.locals is not None else {})) |
| out.append(ACTION_SEP) |
| out.append(f.source) |
| if f.event in (TraceEvent.RETURN, TraceEvent.EXCEPTION): |
| out.append(ARG_SEP) |
| out.append(json.dumps(f.arg)) |
| out.append(FRAME_SEP) |
| out.append(END_OF_TEXT) |
| return "".join(out) |
|
|
|
|
| def _parse_arg(arg_str: str) -> str | None: |
| arg_str = arg_str.strip() |
| if arg_str == "": |
| return None |
| try: |
| |
| |
| loaded = json.loads(arg_str) |
| return loaded if isinstance(loaded, str) else arg_str |
| except json.JSONDecodeError: |
| return arg_str |
|
|