| """ |
| classifier.py β two LLM calls for the two-stage flow. |
| |
| extract_schema(study_context, preview) β SchemaSummary (no label) |
| pick_label (schema, study_context) β (label, reasoning) |
| |
| The orchestrator (statlens_run.py) drives both. Each call is independent β |
| the user can edit the schema between them, and the label decision uses the |
| edited schema. |
| """ |
| from __future__ import annotations |
|
|
| import json |
| import re |
| from dataclasses import dataclass, field |
|
|
| import httpx |
|
|
| from .prompts import ( |
| LABELS, |
| build_extract_messages, |
| build_label_messages, |
| ) |
| from .schema_spec import coerce_and_validate, default_schema |
|
|
|
|
@dataclass
class ExtractResult:
    """Output of stage 1: just the schema, no label."""
    schema: dict  # coerced/validated schema dict (see schema_spec.coerce_and_validate)
    raw: str  # verbatim LLM response text, kept for debugging/audit
    schema_warnings: list[str] = field(default_factory=list)  # validation/parse warnings
|
|
|
|
@dataclass
class LabelResult:
    """Output of stage 3: a label picked from a (confirmed) schema."""
    label: str  # chosen label; canonicalized to the LABELS spelling when possible
    reasoning: str  # model's free-text justification for the label
    raw: str  # verbatim LLM response text
    valid: bool  # True iff `label` is one of LABELS
|
|
|
|
| def _extract_json_object(text: str) -> dict | None: |
| """Find the largest balanced JSON object in a blob of text.""" |
| try: |
| return json.loads(text.strip()) |
| except Exception: |
| pass |
| starts = [i for i, c in enumerate(text) if c == "{"] |
| for s in starts: |
| depth = 0 |
| for i in range(s, len(text)): |
| if text[i] == "{": |
| depth += 1 |
| elif text[i] == "}": |
| depth -= 1 |
| if depth == 0: |
| cand = text[s:i+1] |
| try: |
| return json.loads(cand) |
| except Exception: |
| break |
| matches = re.findall(r"\{[^{}]*\}", text, re.DOTALL) |
| for m in sorted(matches, key=len, reverse=True): |
| try: |
| return json.loads(m) |
| except Exception: |
| continue |
| return None |
|
|
|
|
def _post_chat(endpoint: str, messages: list[dict], model: str,
               api_key: str, timeout: float, max_tokens: int) -> str:
    """POST one chat-completion request and return the assistant's text.

    Deterministic sampling (temperature 0); raises httpx.HTTPStatusError
    on a non-2xx response.
    """
    payload = {
        "model": model,
        "messages": messages,
        "temperature": 0.0,
        "max_tokens": max_tokens,
    }
    auth_headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    with httpx.Client(timeout=timeout) as client:
        resp = client.post(
            endpoint.rstrip("/") + "/chat/completions",
            json=payload,
            headers=auth_headers,
        )
        resp.raise_for_status()
        return resp.json()["choices"][0]["message"]["content"]
|
|
|
|
| |
| |
| |
def extract_schema(
    study_context: str,
    preview: str,
    endpoint: str,
    model: str = "statlens",
    api_key: str = "dummy",
    timeout: float = 120.0,
) -> ExtractResult:
    """Stage 1: ask the LLM for a structured SchemaSummary (no label here).

    Falls back to `default_schema()` with a warning when the model's output
    cannot be parsed as JSON; otherwise the parsed object is coerced and
    validated before being returned.
    """
    raw = _post_chat(
        endpoint,
        build_extract_messages(study_context, preview),
        model,
        api_key,
        timeout,
        max_tokens=1200,
    )

    parsed = _extract_json_object(raw)
    if parsed is None:
        return ExtractResult(
            schema=default_schema(),
            raw=raw,
            schema_warnings=["LLM output was not valid JSON; using defaults"],
        )

    # Some models wrap the summary as {"schema": {...}} — unwrap before validating.
    nested = parsed.get("schema")
    if isinstance(nested, dict):
        parsed = nested

    schema, warnings = coerce_and_validate(parsed)
    return ExtractResult(schema=schema, raw=raw, schema_warnings=warnings)
|
|
|
|
| |
| |
| |
def pick_label(
    schema: dict,
    study_context: str,
    endpoint: str,
    model: str = "statlens",
    api_key: str = "dummy",
    timeout: float = 60.0,
) -> LabelResult:
    """Stage 3: ask the LLM to pick a label given the (confirmed) schema.

    The schema is rendered as natural-language bullets by the prompt
    builder. A label that differs from LABELS only by case is rescued and
    canonicalized; anything else comes back with valid=False.
    """
    messages = build_label_messages(study_context, schema)
    raw = _post_chat(endpoint, messages, model, api_key, timeout, max_tokens=400)

    parsed = _extract_json_object(raw)
    if parsed is None:
        return LabelResult(
            label="none_of_these",
            reasoning="(failed to parse JSON from model)",
            raw=raw, valid=False,
        )

    label = str(parsed.get("label", "")).strip()
    reasoning = str(parsed.get("reasoning", "")).strip()

    if label in LABELS:
        return LabelResult(label=label, reasoning=reasoning, raw=raw, valid=True)

    # Case-insensitive rescue: map e.g. "Two_Sample" onto the official spelling.
    canonical = next((c for c in LABELS if c.lower() == label.lower()), None)
    if canonical is not None:
        return LabelResult(label=canonical, reasoning=reasoning, raw=raw, valid=True)

    return LabelResult(label=label, reasoning=reasoning, raw=raw, valid=False)
|
|
|
|
| if __name__ == "__main__": |
| import argparse |
| from pathlib import Path |
| from .raw_preview import build_raw_preview |
|
|
| ap = argparse.ArgumentParser() |
| ap.add_argument("--context", required=True) |
| ap.add_argument("--tsv", required=True) |
| ap.add_argument("--endpoint", required=True) |
| ap.add_argument("--model", default="statlens") |
| ap.add_argument("--api-key", default="dummy") |
| args = ap.parse_args() |
|
|
| ctx = Path(args.context).read_text() |
| preview = build_raw_preview(Path(args.tsv)) |
|
|
| print("=== Stage 1: extract schema ===") |
| er = extract_schema(ctx, preview, args.endpoint, args.model, args.api_key) |
| print(f" warnings : {er.schema_warnings}") |
| print(json.dumps(er.schema, indent=2)) |
|
|
| print("\n=== Stage 3: pick label (using LLM-extracted schema verbatim) ===") |
| lr = pick_label(er.schema, ctx, args.endpoint, args.model, args.api_key) |
| print(f" label : {lr.label} (valid={lr.valid})") |
| print(f" reasoning: {lr.reasoning}") |
|
|