Omkar Sreekanth committed on
Commit
6ea5be4
·
1 Parent(s): 354cce5

Add FastAPI backend wrapping ML inference pipeline

Browse files
Files changed (5) hide show
  1. app/__init__.py +0 -0
  2. app/api/__init__.py +0 -0
  3. app/api/routes.py +288 -0
  4. app/main.py +84 -0
  5. app/schemas.py +123 -0
app/__init__.py ADDED
File without changes
app/api/__init__.py ADDED
File without changes
app/api/routes.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FastAPI routes for Agent Trace Anomaly Detection.
3
+
4
+ All ML logic lives in scripts/inference.py (partner's code).
5
+ This file only handles HTTP ↔ inference translation.
6
+
7
+ Endpoints
8
+ ---------
9
+ GET /health – service health & loaded model info
10
+ POST /models/load – load a model (xgboost or distilbert)
11
+ POST /predict – predict anomaly for a single trace
12
+ POST /predict/batch – predict anomalies for multiple traces
13
+ POST /predict/compare – run both models on same trace, compare
14
+ POST /pipeline/train – trigger the full training pipeline
15
+ GET /pipeline/status – check if models dir has trained models
16
+ """
17
+
18
from __future__ import annotations

import json
import logging
import os
import subprocess
import sys
from typing import Any

from fastapi import APIRouter, HTTPException, Query

from app.schemas import (
    PredictRequest,
    PredictBatchRequest,
    PredictResponse,
    PredictBatchResponse,
    CompareResponse,
    ModelLoadRequest,
    HealthResponse,
)
37
+
38
+ logger = logging.getLogger(__name__)
39
+ router = APIRouter()
40
+
41
+
42
# ── In-memory state ──────────────────────────────────────────────────────────

# Module-level singleton holding the currently loaded detector. Simple by
# design, but NOTE(review): each uvicorn worker process gets its own copy, so
# a model loaded via POST /models/load is visible only to the worker that
# handled that request — confirm single-worker deployment or move to app state.
_state: dict[str, Any] = {
    "detector": None,       # TraceAnomalyDetector instance (None until /models/load)
    "model_type": None,     # "xgboost" or "distilbert"
    "model_dir": "models",  # path to saved models
}
49
+
50
+
51
def _get_detector():
    """Return the currently loaded detector, or fail with HTTP 409 if none is loaded."""
    detector = _state["detector"]
    if detector is not None:
        return detector
    raise HTTPException(
        status_code=409,
        detail="No model loaded. POST /models/load first.",
    )
59
+
60
+
61
+ def _check_available_models(model_dir: str) -> list[str]:
62
+ """Check which trained models exist on disk."""
63
+ available = []
64
+ if os.path.exists(os.path.join(model_dir, "xgboost_model.joblib")):
65
+ available.append("xgboost")
66
+ if os.path.exists(os.path.join(model_dir, "distilbert_trace", "trace_config.json")):
67
+ available.append("distilbert")
68
+ if os.path.exists(os.path.join(model_dir, "naive_baseline.joblib")):
69
+ available.append("naive_baseline")
70
+ return available
71
+
72
+
73
+ # ── Health ───────────────────────────────────────────────────────────────────
74
+
75
+ @router.get("/health", response_model=HealthResponse, tags=["System"])
76
+ def health_check():
77
+ model_dir = _state["model_dir"]
78
+ return HealthResponse(
79
+ status="ok",
80
+ loaded_model=_state["model_type"],
81
+ available_models=_check_available_models(model_dir),
82
+ model_dir=model_dir,
83
+ )
84
+
85
+
86
+ # ── Model Loading ───────────────────────────────────────────────────────────
87
+
88
+ @router.post("/models/load", tags=["Models"])
89
+ def load_model(req: ModelLoadRequest):
90
+ """
91
+ Load a trained model into memory for inference.
92
+ Models must be trained first via the pipeline (setup.py).
93
+ """
94
+ from scripts.inference import TraceAnomalyDetector
95
+
96
+ available = _check_available_models(req.model_dir)
97
+ if req.model_type not in available:
98
+ raise HTTPException(
99
+ status_code=404,
100
+ detail=(
101
+ f"Model '{req.model_type}' not found in '{req.model_dir}/'. "
102
+ f"Available models: {available}. "
103
+ f"Run the training pipeline first: python setup.py"
104
+ ),
105
+ )
106
+
107
+ try:
108
+ detector = TraceAnomalyDetector(
109
+ model_dir=req.model_dir,
110
+ model_type=req.model_type,
111
+ )
112
+ _state["detector"] = detector
113
+ _state["model_type"] = req.model_type
114
+ _state["model_dir"] = req.model_dir
115
+
116
+ return {
117
+ "message": f"Model '{req.model_type}' loaded successfully",
118
+ "model_type": req.model_type,
119
+ "model_dir": req.model_dir,
120
+ }
121
+ except Exception as exc:
122
+ logger.exception("Failed to load model")
123
+ raise HTTPException(status_code=500, detail=f"Failed to load model: {exc}")
124
+
125
+
126
+ # ── Prediction ───────────────────────────────────────────────────────────────
127
+
128
+ @router.post("/predict", response_model=PredictResponse, tags=["Prediction"])
129
+ def predict(req: PredictRequest):
130
+ """
131
+ Predict whether a single agent trace is anomalous.
132
+
133
+ Pass conversations in ShareGPT/ToolBench format:
134
+ [{"from": "user", "value": "..."}, {"from": "assistant", "value": "..."}, ...]
135
+ """
136
+ detector = _get_detector()
137
+
138
+ try:
139
+ result = detector.predict(req.conversations)
140
+ except Exception as exc:
141
+ logger.exception("Prediction failed")
142
+ raise HTTPException(status_code=500, detail=f"Prediction error: {exc}")
143
+
144
+ return PredictResponse(
145
+ is_anomalous=result["is_anomalous"],
146
+ confidence=result["confidence"],
147
+ label=result["label"],
148
+ anomaly_signals=result.get("anomaly_signals", []),
149
+ model_used=_state["model_type"],
150
+ features=result.get("features"),
151
+ )
152
+
153
+
154
+ @router.post("/predict/batch", response_model=PredictBatchResponse, tags=["Prediction"])
155
+ def predict_batch(req: PredictBatchRequest):
156
+ """Predict anomalies for multiple traces at once."""
157
+ detector = _get_detector()
158
+
159
+ try:
160
+ results = detector.predict_batch(req.traces)
161
+ except Exception as exc:
162
+ logger.exception("Batch prediction failed")
163
+ raise HTTPException(status_code=500, detail=f"Batch prediction error: {exc}")
164
+
165
+ predictions = [
166
+ PredictResponse(
167
+ is_anomalous=r["is_anomalous"],
168
+ confidence=r["confidence"],
169
+ label=r["label"],
170
+ anomaly_signals=r.get("anomaly_signals", []),
171
+ model_used=_state["model_type"],
172
+ features=r.get("features"),
173
+ )
174
+ for r in results
175
+ ]
176
+
177
+ return PredictBatchResponse(
178
+ predictions=predictions,
179
+ anomaly_count=sum(1 for p in predictions if p.is_anomalous),
180
+ total=len(predictions),
181
+ )
182
+
183
+
184
+ @router.post("/predict/compare", response_model=CompareResponse, tags=["Prediction"])
185
+ def predict_compare(req: PredictRequest):
186
+ """
187
+ Run both XGBoost and DistilBERT on the same trace and compare results.
188
+ Both models must be trained and available in the models directory.
189
+ """
190
+ from scripts.inference import TraceAnomalyDetector
191
+
192
+ model_dir = _state["model_dir"]
193
+ available = _check_available_models(model_dir)
194
+ results = {}
195
+
196
+ for model_type in ["xgboost", "distilbert"]:
197
+ if model_type in available:
198
+ try:
199
+ det = TraceAnomalyDetector(model_dir=model_dir, model_type=model_type)
200
+ r = det.predict(req.conversations)
201
+ results[model_type] = PredictResponse(
202
+ is_anomalous=r["is_anomalous"],
203
+ confidence=r["confidence"],
204
+ label=r["label"],
205
+ anomaly_signals=r.get("anomaly_signals", []),
206
+ model_used=model_type,
207
+ features=r.get("features"),
208
+ )
209
+ except Exception as exc:
210
+ logger.warning("Compare: %s failed: %s", model_type, exc)
211
+
212
+ if not results:
213
+ raise HTTPException(
214
+ status_code=404,
215
+ detail=f"No trained models found in '{model_dir}/'. Run the training pipeline first.",
216
+ )
217
+
218
+ xgb = results.get("xgboost")
219
+ bert = results.get("distilbert")
220
+ agreement = True
221
+ if xgb and bert:
222
+ agreement = xgb.label == bert.label
223
+
224
+ return CompareResponse(
225
+ xgboost=xgb,
226
+ distilbert=bert,
227
+ agreement=agreement,
228
+ )
229
+
230
+
231
+ # ── Training Pipeline ────────────────────────────────────────────────────────
232
+
233
+ @router.post("/pipeline/train", tags=["Pipeline"])
234
+ def trigger_training(
235
+ max_samples: int | None = Query(None, ge=100, description="Cap on dataset rows"),
236
+ model: str = Query("all", description="Which model to train: all, naive, classical, deep"),
237
+ ):
238
+ """
239
+ Trigger the full training pipeline (setup.py).
240
+
241
+ This runs data download β†’ feature extraction β†’ model training.
242
+ May take several minutes depending on dataset size and model choice.
243
+ """
244
+ cmd = ["python", "setup.py"]
245
+ if max_samples:
246
+ cmd.extend(["--max_samples", str(max_samples)])
247
+ if model != "all":
248
+ cmd.extend(["--step", "train", "--model", model])
249
+
250
+ try:
251
+ result = subprocess.run(
252
+ cmd, capture_output=True, text=True, timeout=1800, # 30 min max
253
+ )
254
+ return {
255
+ "message": "Training pipeline completed" if result.returncode == 0 else "Pipeline failed",
256
+ "returncode": result.returncode,
257
+ "stdout": result.stdout[-3000:] if result.stdout else "", # last 3000 chars
258
+ "stderr": result.stderr[-1000:] if result.stderr else "",
259
+ }
260
+ except subprocess.TimeoutExpired:
261
+ raise HTTPException(status_code=504, detail="Training pipeline timed out (30 min limit)")
262
+ except FileNotFoundError:
263
+ raise HTTPException(
264
+ status_code=500,
265
+ detail="setup.py not found. Make sure you're running from the OffRails project root.",
266
+ )
267
+
268
+
269
+ @router.get("/pipeline/status", tags=["Pipeline"])
270
+ def pipeline_status():
271
+ """Check what trained models and data files are available."""
272
+ model_dir = _state["model_dir"]
273
+ data_dir = "data/processed"
274
+
275
+ data_files = {}
276
+ if os.path.isdir(data_dir):
277
+ for f in os.listdir(data_dir):
278
+ path = os.path.join(data_dir, f)
279
+ data_files[f] = {
280
+ "size_mb": round(os.path.getsize(path) / 1_048_576, 2),
281
+ }
282
+
283
+ return {
284
+ "available_models": _check_available_models(model_dir),
285
+ "loaded_model": _state["model_type"],
286
+ "model_dir": model_dir,
287
+ "data_files": data_files,
288
+ }
app/main.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Agent Trace Anomaly Detection β€” FastAPI Backend
3
+
4
+ This is the API layer that wraps the ML pipeline built in scripts/.
5
+ All model training, feature extraction, and inference logic lives
6
+ in the partner's code (scripts/inference.py). This file just serves it.
7
+
8
+ Run from the OffRails project root:
9
+ uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
10
+
11
+ Interactive docs:
12
+ http://localhost:8000/docs
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import os
18
+ import sys
19
+ import logging
20
+
21
+ from fastapi import FastAPI
22
+ from fastapi.middleware.cors import CORSMiddleware
23
+
24
+ # ── Make partner's scripts/ importable ───────────────────────────────────────
25
+ # inference.py does `from model import ...` and `from build_features import ...`
26
+ # so we need scripts/ on sys.path.
27
+ SCRIPTS_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "scripts")
28
+ if SCRIPTS_DIR not in sys.path:
29
+ sys.path.insert(0, SCRIPTS_DIR)
30
+
31
+ from app.api.routes import router
32
+
33
+ # ── Logging ──────────────────────────────────────────────────────────────────
34
+
35
+ logging.basicConfig(
36
+ level=logging.INFO,
37
+ format="%(asctime)s %(levelname)-8s %(name)s β€” %(message)s",
38
+ datefmt="%H:%M:%S",
39
+ )
40
+
41
+ # ── App ──────────────────────────────────────────────────────────────────────
42
+
43
+ app = FastAPI(
44
+ title="Agent Trace Anomaly Detection API",
45
+ description=(
46
+ "Detects anomalous agent execution traces β€” unnecessary tool calls, "
47
+ "circular reasoning, and goal drift.\n\n"
48
+ "**ML models** (XGBoost, DistilBERT) are trained via the pipeline in `scripts/`.\n"
49
+ "**This API** serves predictions from those trained models.\n\n"
50
+ "## Workflow\n"
51
+ "1. Train models: `python setup.py` (or `POST /pipeline/train`)\n"
52
+ "2. Load a model: `POST /models/load`\n"
53
+ "3. Predict: `POST /predict`\n"
54
+ "4. Compare models: `POST /predict/compare`\n"
55
+ ),
56
+ version="1.0.0",
57
+ )
58
+
59
+ # Allow Gradio / any frontend to call the API
60
+ app.add_middleware(
61
+ CORSMiddleware,
62
+ allow_origins=["*"],
63
+ allow_credentials=True,
64
+ allow_methods=["*"],
65
+ allow_headers=["*"],
66
+ )
67
+
68
+ app.include_router(router)
69
+
70
+
71
# ── Root ─────────────────────────────────────────────────────────────────────

@app.get("/", include_in_schema=False)
def root():
    """Plain-JSON landing page pointing newcomers at the docs and the workflow."""
    workflow = [
        "1. Train models: python setup.py",
        "2. POST /models/load (load xgboost or distilbert)",
        "3. POST /predict (classify a trace)",
        "4. POST /predict/compare (run both models)",
    ]
    return {
        "service": "Agent Trace Anomaly Detection API",
        "docs": "/docs",
        "workflow": workflow,
    }
app/schemas.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Pydantic schemas for the Agent Trace Anomaly Detection API.
3
+
4
+ Matches the interface defined in scripts/inference.py:
5
+ TraceAnomalyDetector.predict() returns:
6
+ is_anomalous: bool
7
+ confidence: float
8
+ label: int (0=normal, 1=anomalous)
9
+ anomaly_signals: list[str]
10
+ features: dict (xgboost only)
11
+ """
12
+
13
+ from pydantic import BaseModel, Field
14
+ from typing import Optional
15
+
16
+
17
+ # ── Request Schemas ──────────────────────────────────────────────────────────
18
+
19
class TraceMessage(BaseModel):
    """
    A single message in an agent execution trace.

    Accepts both ToolBench/ShareGPT format ('from'/'value')
    and OpenAI format ('role'/'content').
    """
    # 'from' is a Python keyword, so the ShareGPT key is mapped onto `role`
    # via the alias; populate_by_name also allows the OpenAI-style "role" key.
    role: Optional[str] = Field(None, alias="from", description="Message role (OpenAI format)")
    value: Optional[str] = Field(None, description="Message content (ShareGPT format)")
    content: Optional[str] = Field(None, description="Message content (OpenAI format)")

    model_config = {"populate_by_name": True}

    def to_dict(self) -> dict:
        """
        Normalize to the ShareGPT-style dict inference.py expects.

        When both `value` and `content` are set, `content` wins (assigned
        last). The original trailing `if "from" not in d and self.role:`
        re-check was unreachable (if `role` is truthy the first branch has
        already set "from") and has been removed.
        """
        d: dict = {}
        if self.role is not None:
            d["from"] = self.role
        if self.value is not None:
            d["value"] = self.value
        if self.content is not None:
            d["value"] = self.content  # map content → value for ToolBench compat
        return d
43
+
44
+
45
class PredictRequest(BaseModel):
    """Request body for single-trace anomaly prediction."""

    # Kept as raw dicts (not list[TraceMessage]) — routes.py forwards
    # `req.conversations` to the detector unmodified.
    conversations: list[dict] = Field(
        ...,
        description=(
            "List of message dicts in ShareGPT/ToolBench format. "
            "Each dict should have 'from' (role) and 'value' (content) keys."
        ),
    )

    # Example payload rendered in the interactive OpenAPI docs (/docs).
    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "conversations": [
                        {"from": "user", "value": "Find me flights from NYC to London"},
                        {"from": "assistant", "value": "I'll search for flights using the travel API."},
                        {"from": "function", "value": '{"flights": [{"price": 450}]}'},
                        {"from": "assistant", "value": "I found flights starting at $450."},
                    ]
                }
            ]
        }
    }
69
+
70
+
71
class PredictBatchRequest(BaseModel):
    """Request body for batch prediction on multiple traces."""

    # Outer list = traces; inner list = one trace's messages (same shape as
    # PredictRequest.conversations).
    traces: list[list[dict]] = Field(
        ..., description="List of traces, each trace is a list of message dicts"
    )
76
+
77
+
78
class ModelLoadRequest(BaseModel):
    """Request to load a specific model type."""

    # Fields beginning with "model_" collide with pydantic v2's protected
    # "model_" namespace and emit a UserWarning at class creation; clearing
    # the protected namespaces silences it — these are plain request fields.
    model_config = {"protected_namespaces": ()}

    model_type: str = Field(
        "xgboost",
        description="Model to load: 'xgboost' or 'distilbert'",
        pattern="^(xgboost|distilbert)$",
    )
    model_dir: str = Field("models", description="Path to saved models directory")
86
+
87
+
88
+ # ── Response Schemas ─────────────────────────────────────────────────────────
89
+
90
class PredictResponse(BaseModel):
    """Response from the anomaly detector — mirrors TraceAnomalyDetector.predict() output."""

    # `model_used` starts with "model_", which falls in pydantic v2's
    # protected "model_" namespace and triggers a UserWarning at class
    # creation; clearing the namespaces silences it without schema changes.
    model_config = {"protected_namespaces": ()}

    is_anomalous: bool = Field(..., description="True if the trace is predicted anomalous")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Probability of anomaly")
    label: int = Field(..., description="0 = normal, 1 = anomalous")
    anomaly_signals: list[str] = Field(
        default_factory=list,
        description="Human-readable explanations of detected anomaly patterns",
    )
    model_used: str = Field(..., description="Which model produced this prediction")
    features: Optional[dict] = Field(
        None, description="Extracted feature values (xgboost only)"
    )
103
+
104
+
105
class PredictBatchResponse(BaseModel):
    """Response for batch predictions."""

    predictions: list[PredictResponse]  # one entry per input trace, same order
    anomaly_count: int                  # number of predictions with is_anomalous=True
    total: int                          # len(predictions)
110
+
111
+
112
class CompareResponse(BaseModel):
    """Side-by-side prediction from both models on the same trace."""

    # Either side may be None when that model is untrained or failed to run.
    xgboost: Optional[PredictResponse] = None
    distilbert: Optional[PredictResponse] = None
    agreement: bool = Field(..., description="Whether both models agree on the label")
117
+
118
+
119
class HealthResponse(BaseModel):
    """Response for GET /health."""

    status: str                         # "ok" when the service is up
    loaded_model: Optional[str] = None  # model type currently in memory, if any
    available_models: list[str]         # trained models found on disk
    model_dir: str                      # directory scanned for model artifacts