# anomalyOS — src/llm.py
# Provenance: uploaded via huggingface_hub by CaffeinatedCoding (commit e72f783, verified)
# src/llm.py
# Groq LLM call with tenacity retry
# Single call per inference β€” not a multi-step chain
# Non-blocking: queued as FastAPI BackgroundTask, polled via /report/{id}
import os
import json
import time
import uuid
import httpx
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type
)
# Groq's OpenAI-compatible chat-completions endpoint.
GROQ_API_URL = "https://api.groq.com/openai/v1/chat/completions"
# Model used for every report (one call per inference, no chaining).
GROQ_MODEL = "llama-3.3-70b-versatile"
# Hard cap on generated tokens — reports are meant to be 3-5 sentences.
MAX_TOKENS = 512
# In-memory report store: report_id → {status, report}
# FastAPI polls this via GET /report/{report_id}
# NOTE: process-local and unbounded except for cleanup_old_reports();
# not shared across workers.
_report_store: dict[str, dict] = {}
class LLMAPIError(Exception):
    """Raised when the Groq call fails: missing key, bad status, timeout, or empty reply."""
def _build_prompt(category: str,
anomaly_score: float,
similar_cases: list,
graph_context: dict) -> list:
"""
Build LLM messages list.
Strictly grounded β€” model must cite case IDs, cannot use outside knowledge.
One call per inference. Context = retrieved cases + graph context.
"""
system = (
"You are an industrial quality control assistant. "
"Answer ONLY based on the retrieved cases and graph context provided. "
"Do not use outside knowledge. "
"Always cite the Case ID when referencing a case. "
"Be concise β€” 3 to 5 sentences maximum."
)
# Build context block from retrieved similar cases
context_lines = []
for i, case in enumerate(similar_cases[:5]):
context_lines.append(
f"[Case {i+1}: category={case.get('category')}, "
f"defect={case.get('defect_type')}, "
f"similarity={case.get('similarity_score', 0):.3f}]"
)
# Add graph context
root_causes = graph_context.get("root_causes", [])
remediations = graph_context.get("remediations", [])
if root_causes:
context_lines.append(f"Root causes: {', '.join(root_causes)}")
if remediations:
context_lines.append(f"Remediations: {', '.join(remediations)}")
context_str = "\n".join(context_lines) if context_lines else "No context available."
user_msg = (
f"CONTEXT:\n{context_str}\n\n"
f"QUERY: Image anomaly score {anomaly_score:.3f}. "
f"Category: {category}. "
f"Describe the likely defect, root cause, and recommended action."
f"\n\nREPORT:"
)
return [
{"role": "system", "content": system},
{"role": "user", "content": user_msg}
]
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=2, max=8),
    retry=retry_if_exception_type(LLMAPIError),
    reraise=True
)
def _call_groq(messages: list) -> str:
    """
    Make one Groq chat-completion call and return the reply text.

    Tenacity retries up to 3 attempts with 2s/4s/8s exponential backoff
    on any LLMAPIError; after the last attempt the error is re-raised.

    Args:
        messages: OpenAI-style chat messages (see _build_prompt).

    Returns:
        The stripped completion text.

    Raises:
        LLMAPIError: missing API key, rate limit, non-200 status, timeout,
            transport failure, malformed payload, or empty completion.
    """
    api_key = os.environ.get("GROQ_API_KEY")
    if not api_key:
        raise LLMAPIError("GROQ_API_KEY not set in environment")
    try:
        with httpx.Client(timeout=30.0) as client:
            response = client.post(
                GROQ_API_URL,
                headers={
                    "Authorization": f"Bearer {api_key}",
                    "Content-Type": "application/json"
                },
                json={
                    "model": GROQ_MODEL,
                    "messages": messages,
                    "max_tokens": MAX_TOKENS,
                    "temperature": 0.3  # low temp = factual, grounded
                }
            )
            if response.status_code == 429:
                raise LLMAPIError("Groq rate limit hit")
            if response.status_code != 200:
                raise LLMAPIError(f"Groq API error {response.status_code}: "
                                  f"{response.text[:200]}")
            try:
                data = response.json()
                content = data["choices"][0]["message"]["content"].strip()
            except (KeyError, IndexError, TypeError, ValueError) as e:
                # A 200 with an unexpected payload shape must surface as
                # LLMAPIError, not a bare KeyError, so tenacity retries it.
                raise LLMAPIError(f"Groq returned malformed response: {e}") from e
            if not content:
                raise LLMAPIError("Groq returned empty response")
            return content
    except httpx.TimeoutException as e:
        # Chain the cause so logs show the underlying transport error.
        raise LLMAPIError("Groq API timeout") from e
    except httpx.RequestError as e:
        raise LLMAPIError(f"Groq request failed: {e}") from e
def queue_report(category: str,
                 anomaly_score: float,
                 similar_cases: list,
                 graph_context: dict) -> str:
    """
    Register a pending report and hand back its id immediately.

    The actual generation runs later in generate_report (scheduled as a
    FastAPI BackgroundTask); the frontend polls GET /report/{report_id}
    every 500ms until the status flips to "ready".
    """
    new_id = str(uuid.uuid4())
    _report_store[new_id] = {"status": "pending", "report": None}
    return new_id
def generate_report(report_id: str,
                    category: str,
                    anomaly_score: float,
                    similar_cases: list,
                    graph_context: dict):
    """
    Background worker body, run as a FastAPI BackgroundTask.

    Builds the grounded prompt, makes the single Groq call, and writes the
    outcome into _report_store[report_id]. Every path ends with status
    "ready" — failures store a human-readable fallback instead of an error
    state, so the polling frontend never spins forever.
    """
    try:
        prompt = _build_prompt(category, anomaly_score,
                               similar_cases, graph_context)
        _report_store[report_id] = {
            "status": "ready",
            "report": _call_groq(prompt),
        }
    except LLMAPIError as err:
        # Degrade gracefully: the retrieved evidence is already on screen.
        _report_store[report_id] = {
            "status": "ready",
            "report": (
                "LLM temporarily unavailable. "
                "Retrieved cases and graph context are shown above. "
                f"(Error: {str(err)[:100]})"
            ),
        }
    except Exception:
        # Last-ditch catch-all so the report id never stays "pending".
        _report_store[report_id] = {
            "status": "ready",
            "report": "Could not generate report. Please retry.",
        }
def get_report(report_id: str) -> dict:
    """
    Look up the current state of a queued report (polled by the frontend).

    Returns {"status": "pending"|"ready", "report": ...} for known ids,
    and {"status": "not_found", "report": None} for unknown ones.
    """
    missing = {"status": "not_found", "report": None}
    return _report_store.get(report_id, missing)
def cleanup_old_reports(max_age_seconds: int = 3600, max_entries: int = 500):
    """
    Bound the in-memory _report_store so it cannot grow without limit.

    NOTE(review): despite the parameter name, eviction is count-based, not
    age-based — stored entries carry no timestamp, so max_age_seconds is
    currently unused and kept only for call-site compatibility. When the
    store exceeds max_entries, the oldest half is dropped.

    Args:
        max_age_seconds: unused (see note above).
        max_entries: size threshold that triggers eviction (default 500,
            matching the previous hard-coded value).
    """
    if len(_report_store) > max_entries:
        # Dicts preserve insertion order, so the first keys are the oldest.
        oldest = list(_report_store.keys())[:max_entries // 2]
        for key in oldest:
            del _report_store[key]