| |
| import argparse |
| import json |
| import re |
| from pathlib import Path |
| from typing import Any, Dict, List, Optional, Tuple |
|
|
| import streamlit as st |
|
|
# Badge styling per chat role: the label shown in the badge, its text color,
# and its background color. Each label is the upper-cased role name.
ROLE_STYLE = {
    role: {"label": role.upper(), "color": text_color, "bg": background}
    for role, text_color, background in (
        ("system", "#4B5563", "#F3F4F6"),
        ("user", "#1D4ED8", "#DBEAFE"),
        ("assistant", "#065F46", "#D1FAE5"),
        ("tool", "#7C2D12", "#FFEDD5"),
    )
}
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the viewer's command-line options.

    Returns:
        A namespace whose ``dir`` attribute holds the value of ``--dir``
        (an empty string when the option is not given).
    """
    cli = argparse.ArgumentParser(
        description="Streamlit viewer for eval-agent trajectory files.",
    )
    cli.add_argument(
        "--dir",
        type=str,
        default="",
        help="Directory containing agent run files.",
    )
    namespace = cli.parse_args()
    return namespace
|
|
|
|
def file_sort_key(path: Path) -> Tuple[int, int, str]:
    """Build the sort key used to order files in a run directory.

    Files named ``gen_<number>_<suffix>`` sort by generation number first,
    then by a fixed rank for the known suffixes (task message, result,
    trajectory messages; anything else ranks 99), then by file name.
    Files that do not follow the ``gen_`` pattern sort after all others.

    Args:
        path: File whose ``name`` is inspected.

    Returns:
        A ``(generation, rank, name)`` tuple.
    """
    filename = path.name
    parsed = re.match(r"gen_(\d+)_(.*)$", filename)
    if parsed is None:
        # Unrecognized names get a huge sentinel so they land at the end.
        return (10**9, 10**9, filename)

    generation_text, kind = parsed.groups()
    known_kinds = {
        "task_message.txt": 0,
        "result.json": 1,
        "trajectory_messages.json": 2,
    }
    rank = known_kinds[kind] if kind in known_kinds else 99
    return (int(generation_text), rank, filename)
|
|
|
|
def try_load_json(path: Path) -> Optional[Any]:
    """Load a UTF-8 JSON file on a best-effort basis.

    Args:
        path: File to read.

    Returns:
        The decoded JSON value, or ``None`` when the file cannot be opened
        or parsed. A file containing JSON ``null`` also yields ``None``, so
        callers cannot tell that case apart from a failure.
    """
    try:
        with open(path, "r", encoding="utf-8") as handle:
            parsed = json.load(handle)
    except Exception:
        # Deliberately broad: any read/decode problem means "no data".
        return None
    return parsed
|
|
|
|
def extract_text_from_message(message: Dict[str, Any]) -> str:
    """Return the human-readable text carried by a message dict.

    Two shapes of ``message["content"]`` are understood:

    * a plain string, returned stripped;
    * a list of parts, where every dict part with ``type == "text"`` and a
      non-empty string ``text`` contributes one line; the lines are joined
      with newlines and the result is stripped.

    Args:
        message: Message dict; only its ``content`` key is read.

    Returns:
        The extracted text, or an empty string when ``content`` is missing,
        has another type, or holds no text parts.
    """
    content = message.get("content")

    # Plain-string content was previously ignored and shown as "no text".
    if isinstance(content, str):
        return content.strip()
    if not isinstance(content, list):
        return ""

    text_parts: List[str] = [
        part["text"]
        for part in content
        if isinstance(part, dict)
        and part.get("type") == "text"
        and isinstance(part.get("text"), str)
        and part["text"]
    ]
    return "\n".join(text_parts).strip()
|
|
|
|
def trajectory_summary(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Build one overview row per message.

    Args:
        messages: Message dicts in trajectory order.

    Returns:
        A list of dicts with the keys ``idx`` (position in ``messages``),
        ``role`` (``"unknown"`` when absent), ``tool_calls`` (length of the
        ``tool_calls`` list, 0 when it is not a list), ``chars`` (length of
        the extracted text) and ``preview`` (first 120 characters, followed
        by ``...`` when the text is longer).
    """

    def summarize(position: int, message: Dict[str, Any]) -> Dict[str, Any]:
        # One table row for a single message.
        body = extract_text_from_message(message)
        if len(body) > 120:
            snippet = body[:120] + "..."
        else:
            snippet = body
        calls = message.get("tool_calls")
        return {
            "idx": position,
            "role": message.get("role", "unknown"),
            "tool_calls": len(calls) if isinstance(calls, list) else 0,
            "chars": len(body),
            "preview": snippet,
        }

    return [summarize(position, message) for position, message in enumerate(messages)]
|
|
|
|
def render_trajectory(messages: List[Dict[str, Any]]):
    """Render a trajectory as an overview table plus a per-message timeline.

    The overview is the table produced by ``trajectory_summary``. The
    timeline shows, for each message, a colored role badge, a toggle that
    hides or shows the message, the extracted text, and optionally the raw
    message dict.

    Args:
        messages: Message dicts in trajectory order.
    """
    # Function-scope import keeps this block self-contained.
    import html

    st.subheader("Trajectory Overview")
    rows = trajectory_summary(messages)
    st.dataframe(rows, width="stretch")

    st.subheader("Full Message Timeline")
    show_raw = st.checkbox("Show raw dict under each message", value=False)

    for idx, msg in enumerate(messages):
        role = str(msg.get("role", "unknown"))
        # Roles without a predefined style get a neutral badge labeled with
        # the upper-cased role name.
        style = ROLE_STYLE.get(
            role,
            {"label": role.upper(), "color": "#111827", "bg": "#F9FAFB"},
        )
        text = extract_text_from_message(msg)
        tool_calls = msg.get("tool_calls")
        tool_call_count = len(tool_calls) if isinstance(tool_calls, list) else 0

        title = f"{style['label']} #{idx}"
        if tool_call_count > 0:
            title += f" | tool_calls={tool_call_count}"

        # The title embeds the role taken from the file and the body is raw
        # message text. Both are rendered with unsafe_allow_html=True, so
        # they are escaped to keep "<", ">" and "&" from being read as markup.
        safe_title = html.escape(title)
        st.markdown(
            (
                f"<div style='margin:8px 0 4px 0;'>"
                f"<span style='background:{style['bg']}; color:{style['color']};"
                " padding:4px 10px; border-radius:999px; font-weight:700;'>"
                f"{safe_title}</span></div>"
            ),
            unsafe_allow_html=True,
        )

        # The widget key includes the index so every toggle is unique.
        show_msg = st.toggle(f"Show message #{idx}", value=True, key=f"show_msg_{idx}")
        if not show_msg:
            continue

        if text:
            safe_text = html.escape(text)
            st.markdown(
                (
                    f"<div style='border-left:4px solid {style['color']}; padding:8px 12px;"
                    f" background:{style['bg']}; border-radius:6px; white-space:pre-wrap;'>"
                    f"{safe_text}</div>"
                ),
                unsafe_allow_html=True,
            )
        else:
            st.caption("<no text content>")

        if show_raw:
            st.json(msg)
|
|
|
|
def main():
    """Run the viewer page.

    Reads ``--dir`` from the command line, lets the user override it in the
    sidebar, lists the files in that directory ordered by ``file_sort_key``
    and renders the selected one: trajectory files through
    ``render_trajectory``, other JSON files as JSON, everything else as
    plain text.
    """
    cli_args = parse_args()

    st.set_page_config(page_title="Eval Agent Trajectory Viewer", layout="wide")
    st.title("Eval Agent Trajectory Viewer")

    run_dir_input = st.sidebar.text_input("Run directory", value=cli_args.dir or "")
    if not run_dir_input:
        st.info("Pass `--dir` or set the directory in the sidebar.")
        return

    run_dir = Path(run_dir_input).expanduser()
    if not (run_dir.exists() and run_dir.is_dir()):
        st.error(f"Directory not found: {run_dir_input}")
        return

    files = [entry for entry in run_dir.iterdir() if entry.is_file()]
    files.sort(key=file_sort_key)
    if not files:
        st.warning("No files found in this directory.")
        return

    selected_name = st.sidebar.selectbox(
        "Select file",
        options=[entry.name for entry in files],
        index=0,
    )
    selected_path = run_dir / selected_name

    st.caption(f"Selected: `{selected_path}`")
    st.caption(f"Size: {selected_path.stat().st_size:,} bytes")

    # Anything that is not JSON is shown verbatim.
    if not selected_name.endswith(".json"):
        raw = selected_path.read_text(encoding="utf-8", errors="replace")
        st.code(raw, language="text")
        return

    data = try_load_json(selected_path)

    # Trajectory files get the dedicated message-timeline view.
    if selected_name.endswith("_trajectory_messages.json"):
        if not isinstance(data, list):
            st.error("Trajectory file is not a JSON list.")
            return
        # Non-dict entries are dropped before rendering.
        msg_list = [entry for entry in data if isinstance(entry, dict)]
        st.success(f"Loaded {len(msg_list)} message dicts.")
        render_trajectory(msg_list)
        return

    # Generic JSON: fall back to the raw text when parsing failed.
    if data is None:
        st.error("Failed to parse JSON.")
        raw = selected_path.read_text(encoding="utf-8", errors="replace")
        st.code(raw, language="json")
    else:
        st.json(data)


if __name__ == "__main__":
    main()
|
|