|
|
|
|
|
""" |
|
|
CyberForge ML Inference Module |
|
|
Backend integration for mlService.js |
|
|
""" |
|
|
|
|
|
import json |
|
|
import time |
|
|
import joblib |
|
|
import numpy as np |
|
|
from pathlib import Path |
|
|
from typing import Dict, List, Any, Optional |
|
|
|
|
|
class CyberForgeInference:
    """
    ML inference service for CyberForge backend.

    Compatible with mlService.js API contract.

    Expected layout of ``models_dir``::

        models_dir/
            manifest.json              (optional catalogue of models)
            <model_name>/model.pkl     (required, joblib-serialized)
            <model_name>/scaler.pkl    (optional, joblib-serialized)

    Loaded models are cached in ``self.loaded_models`` for the lifetime of
    the instance.
    """

    def __init__(self, models_dir: str):
        self.models_dir = Path(models_dir)
        # model_name -> {"model": estimator, "scaler": transformer or None}
        self.loaded_models: Dict[str, Dict[str, Any]] = {}
        self.manifest = self._load_manifest()

    def _load_manifest(self) -> Dict:
        """Read ``manifest.json`` from the models directory.

        Returns:
            The parsed manifest, or ``{"models": {}}`` when the file is absent.
        """
        manifest_path = self.models_dir / "manifest.json"
        if manifest_path.exists():
            # JSON is UTF-8 by spec; don't depend on the platform locale.
            with open(manifest_path, encoding="utf-8") as f:
                return json.load(f)
        return {"models": {}}

    @staticmethod
    def _is_safe_model_name(model_name: str) -> bool:
        """Return True if *model_name* is a single, ordinary path component."""
        if model_name in ("", ".", ".."):
            return False
        # A name containing a separator (or an absolute/drive path) has a
        # final component that differs from the whole string.
        return Path(model_name).name == model_name

    def load_model(self, model_name: str) -> bool:
        """Load a model (and its optional scaler) into memory.

        Args:
            model_name: Name of a subdirectory of ``models_dir``. It must be a
                single path component; names with path separators, ``.``,
                ``..`` or absolute paths are rejected so that no file outside
                ``models_dir`` can be loaded.

        Returns:
            True if the model is cached in ``self.loaded_models`` (already or
            now), False if the name is rejected or ``model.pkl`` is missing.
        """
        if model_name in self.loaded_models:
            return True

        # create_api passes the request's model_name straight through to
        # here, and joblib.load unpickles (can execute code), so the name
        # must not be able to escape models_dir (e.g. "../../tmp/x").
        if not self._is_safe_model_name(model_name):
            return False

        model_dir = self.models_dir / model_name
        model_path = model_dir / "model.pkl"
        scaler_path = model_dir / "scaler.pkl"

        if not model_path.exists():
            return False

        # SECURITY: joblib.load is pickle-based; only trusted files may be
        # placed under models_dir.
        self.loaded_models[model_name] = {
            "model": joblib.load(model_path),
            "scaler": joblib.load(scaler_path) if scaler_path.exists() else None,
        }
        return True

    @staticmethod
    def _feature_matrix(model: Any, scaler: Any, features: Dict) -> np.ndarray:
        """Build the single-row input matrix for *features*.

        If the first stage (the scaler when present, else the model) exposes
        ``feature_names_in_`` (scikit-learn sets it when fitted on named
        columns) and *features* has exactly those keys, the values are
        ordered to match training. Otherwise the dict's insertion order is
        used, as before.
        """
        first_stage = scaler if scaler is not None else model
        expected = getattr(first_stage, "feature_names_in_", None)
        if expected is not None:
            names = [str(name) for name in expected]
            if len(names) == len(features) and set(names) == set(features):
                return np.array([[features[name] for name in names]])
        return np.array([list(features.values())])

    def predict(self, model_name: str, features: Dict) -> Dict:
        """
        Make a prediction.

        Args:
            model_name: Name of the model to use
            features: Feature dictionary (feature name -> value)

        Returns:
            Response matching mlService.js contract (prediction, confidence,
            risk_level, model_name, model_version, inference_time_ms), or
            ``{"error": "Model not found: <name>"}`` if the model cannot be
            loaded.

        Raises:
            Whatever the scaler or model raise for malformed input (for
            scikit-learn estimators this is typically ValueError).
        """
        if not self.load_model(model_name):
            return {"error": f"Model not found: {model_name}"}

        model_data = self.loaded_models[model_name]
        model = model_data["model"]
        scaler = model_data["scaler"]

        X = self._feature_matrix(model, scaler, features)

        # Explicit None check rather than relying on the object's truthiness.
        if scaler is not None:
            X = scaler.transform(X)

        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # which follows wall-clock adjustments. The measured span covers the
        # model calls (predict, plus predict_proba when available).
        start_time = time.perf_counter()
        prediction = int(model.predict(X)[0])

        # Fallback confidence when the model cannot report probabilities.
        confidence = 0.5
        if hasattr(model, "predict_proba"):
            proba = model.predict_proba(X)[0]
            confidence = float(max(proba))
        inference_time = (time.perf_counter() - start_time) * 1000

        # NOTE(review): risk is derived from the confidence of *whichever*
        # class was predicted, so a confidently benign prediction (if class 0
        # means benign - confirm) is also reported as "critical". Consider
        # basing this on the probability of the threat class instead.
        risk_level = (
            "critical" if confidence >= 0.9 else
            "high" if confidence >= 0.7 else
            "medium" if confidence >= 0.5 else
            "low" if confidence >= 0.3 else "info"
        )

        return {
            "prediction": prediction,
            "confidence": confidence,
            "risk_level": risk_level,
            "model_name": model_name,
            "model_version": "1.0.0",
            "inference_time_ms": inference_time,
        }

    def batch_predict(self, model_name: str, features_list: List[Dict]) -> List[Dict]:
        """Run ``predict`` on each feature dict, preserving input order."""
        return [self.predict(model_name, f) for f in features_list]

    def list_models(self) -> List[str]:
        """List model names declared in the manifest."""
        return list(self.manifest.get("models", {}).keys())

    def get_model_info(self, model_name: str) -> Dict:
        """Return the manifest entry for *model_name*, or ``{}`` if absent."""
        return self.manifest.get("models", {}).get(model_name, {})
|
|
|
|
|
|
|
|
|
|
|
def create_api(models_dir: str):
    """Create a FastAPI app that serves the models in *models_dir*.

    Args:
        models_dir: Directory passed to ``CyberForgeInference``.

    Returns:
        The FastAPI application, or None when fastapi/pydantic cannot be
        imported (they are optional for this module).

    Endpoints:
        POST /predict             -> ``CyberForgeInference.predict`` result;
                                     404 if the model is unknown, 422 if the
                                     model rejects the feature payload.
        GET  /models              -> ``{"models": [...]}`` from the manifest.
        GET  /models/{model_name} -> the model's manifest entry; 404 if the
                                     model is not listed.
    """
    try:
        from fastapi import FastAPI, HTTPException
        from pydantic import BaseModel
    except ImportError:
        return None

    app = FastAPI(title="CyberForge ML API", version="1.0.0")
    inference = CyberForgeInference(models_dir)

    class PredictRequest(BaseModel):
        model_name: str
        features: Dict

    # Plain `def`, not `async def`: model loading/inference is blocking work.
    # FastAPI runs sync path operations in a worker threadpool, so a slow
    # prediction no longer blocks the event loop for every other request.
    @app.post("/predict")
    def predict(request: PredictRequest):
        try:
            result = inference.predict(request.model_name, request.features)
        except ValueError as exc:
            # Most likely the feature payload does not match what the
            # scaler/model expect (wrong count, non-numeric values, ...):
            # report a client error instead of an opaque 500.
            raise HTTPException(status_code=422, detail=str(exc)) from exc
        if "error" in result:
            raise HTTPException(status_code=404, detail=result["error"])
        return result

    @app.get("/models")
    async def list_models():
        return {"models": inference.list_models()}

    @app.get("/models/{model_name}")
    async def get_model_info(model_name: str):
        # Membership test instead of truthiness, so a model listed in the
        # manifest with empty metadata is still reported rather than 404'd.
        if model_name not in inference.list_models():
            raise HTTPException(status_code=404, detail="Model not found")
        return inference.get_model_info(model_name)

    return app
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
import sys |
|
|
models_dir = sys.argv[1] if len(sys.argv) > 1 else "." |
|
|
|
|
|
inference = CyberForgeInference(models_dir) |
|
|
print(f"Available models: {inference.list_models()}") |
|
|
|