"""
agents.py β€” Prefect @flow orchestrating two concurrent LiteLLM tasks.
Flow topology:
run_analysis_flow(claim, platform, rag_context)
β”‚
β”œβ”€β”€ misinformation_task() ←── Groq / mixtral-8x7b-32768
β”‚ verdict: green | yellow | red
β”‚
└── hallucination_task() ←── Anthropic Claude Haiku (AI platforms only)
verdict: purple | green
β”‚
└── merge_results() β†’ AnalysisResult (higher severity wins)
Severity order: red > purple > yellow > green
Why LiteLLM as abstraction:
- Single .completion() call works across Groq, Anthropic, OpenAI, Ollama
- Automatic retry with provider-level fallbacks
- No code change to swap providers
"""
from __future__ import annotations
import asyncio
import json
import os
import re
from dataclasses import dataclass
from typing import Literal
import litellm
import structlog
from pydantic import BaseModel, Field, field_validator
from tenacity import retry, stop_after_attempt, wait_exponential
from rag_pipeline import RagContext
log = structlog.get_logger(__name__)
# Silence LiteLLM's verbose logs unless explicitly enabled
litellm.set_verbose = os.getenv("LITELLM_VERBOSE", "false").lower() == "true"
# ── Color severity ordering ───────────────────────────────────────────────────
SEVERITY: dict[str, int] = {"green": 0, "yellow": 1, "purple": 2, "red": 3}
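# Illustrative (sketch): the merge step later picks the winner by this map,
# e.g. max(("yellow", "purple"), key=SEVERITY.__getitem__) == "purple".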
COLOR_TYPE = Literal["green", "yellow", "red", "purple"]
# AI-interface platforms that trigger the hallucination agent
AI_PLATFORMS = {"chatgpt", "claude", "gemini", "openai", "ai_chat", "bard", "copilot"}
# ── Result models ─────────────────────────────────────────────────────────────
class AgentOutput(BaseModel):
    color: COLOR_TYPE
    confidence: int = Field(ge=0, le=100)
    verdict: str = Field(max_length=120)
    explanation: str = Field(max_length=600)
    sources: list[str] = Field(default_factory=list, max_length=5)

    @field_validator("color", mode="before")
    @classmethod
    def normalize_color(cls, v: str) -> str:
        """Coerce LLM output to a valid color string."""
        v = str(v).lower().strip()
        if v not in SEVERITY:
            return "yellow"
        return v

    @field_validator("confidence", mode="before")
    @classmethod
    def clamp_confidence(cls, v) -> int:
        return max(0, min(100, int(v)))

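# Illustrative (not exercised at import time): the "before" validators coerce
# messy LLM output, e.g. color="RED " -> "red", confidence="87" -> 87, and an
# unknown color degrades to "yellow" rather than raising ValidationError.
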
@dataclass
class AnalysisFlowResult:
    color: str
    confidence: int
    verdict: str
    explanation: str
    sources: list[str]

# ── System prompts ────────────────────────────────────────────────────────────
MISINFORMATION_SYSTEM = """You are a veteran fact-checking analyst. Given a claim and retrieved evidence, determine whether the claim is true, misleading, or false.
You MUST output ONLY valid JSON matching this exact schema:
{"color": "green"|"yellow"|"red", "confidence": 0-100, "verdict": "<10-word label>", "explanation": "<2-3 sentences>", "sources": ["<url1>", "<url2>", "<url3>"]}
Color logic:
- green: Widely corroborated, verified, factually sound
- yellow: Breaking/unverified, weak evidence, contested but not proven false
- red: Debunked by multiple independent sources, intentional deceit, or contradicts established consensus
Base your confidence on the quality and quantity of retrieved evidence. Use only URLs from the evidence; never fabricate URLs."""
HALLUCINATION_SYSTEM = """You are an AI output auditor specializing in detecting LLM hallucinations. Analyze AI-generated text for:
1. Fabricated citations (URLs, paper titles, author names that don't match real publications)
2. Statistical impossibilities (numbers that cannot logically be correct)
3. Internal contradictions (statements that contradict each other in the same passage)
4. Knowledge cutoff violations (claiming events that postdate the model's training)
You MUST output ONLY valid JSON:
{"color": "purple"|"green", "confidence": 0-100, "verdict": "<10-word label>", "explanation": "<2-3 sentences describing the specific hallucination type>", "sources": []}
purple = hallucination detected with high probability
green = no hallucination detected"""
# ── LiteLLM call wrapper ──────────────────────────────────────────────────────
@retry(
    stop=stop_after_attempt(2),
    wait=wait_exponential(multiplier=0.1, min=0.1, max=1.0),
    # When retries are exhausted, resolve to None instead of raising
    # RetryError, preserving the "text or None" contract for callers.
    retry_error_callback=lambda retry_state: None,
)
async def _call_llm(
    model: str,
    system: str,
    user_content: str,
    max_tokens: int = 400,
) -> str | None:
    """
    Thin async wrapper around litellm.acompletion.
    Returns the response text, or None once all retry attempts fail.
    """
    try:
        response = await litellm.acompletion(
            model=model,
            messages=[
                {"role": "system", "content": system},
                {"role": "user", "content": user_content},
            ],
            temperature=0.1,
            max_tokens=max_tokens,
            response_format={"type": "json_object"},
        )
        return response.choices[0].message.content
    except Exception as exc:
        log.warning("llm.call_failed", model=model, error=str(exc))
        # Re-raise so tenacity actually retries; swallowing the exception
        # here would mean the @retry decorator never sees a failure.
        raise

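# Illustrative usage (sketch):
#   raw = await _call_llm("groq/llama3-8b-8192", MISINFORMATION_SYSTEM, "CLAIM: ...")
# `raw` is the model's JSON string, or None once every retry attempt has failed.
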
def _parse_agent_output(raw: str | None, fallback_color: COLOR_TYPE) -> AgentOutput:
    """Parse LLM JSON response with graceful fallback."""
    if raw is None:
        return AgentOutput(
            color=fallback_color,
            confidence=40,
            verdict="Analysis unavailable",
            explanation="LLM service temporarily unavailable. Result based on heuristics only.",
            sources=[],
        )
    try:
        # Strip any accidental markdown fences
        cleaned = re.sub(r"```(?:json)?|```", "", raw).strip()
        data = json.loads(cleaned)
        return AgentOutput.model_validate(data)
    except Exception as exc:
        log.warning("agent.parse_error", error=str(exc), raw=raw[:200])
        return AgentOutput(
            color=fallback_color,
            confidence=35,
            verdict="Parse error",
            explanation=f"Could not parse agent response. Raw snippet: {raw[:100]}",
            sources=[],
        )

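# Illustrative (sketch): a fenced reply such as '```json\n{"color": "red", ...}\n```'
# is stripped to bare JSON before model_validate; anything still unparseable
# degrades to the fallback color instead of raising.
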
# ── Individual tasks ──────────────────────────────────────────────────────────
async def misinformation_task(
    claim_text: str,
    rag_context: RagContext,
) -> AgentOutput:
    """
    Uses mixtral-8x7b-32768 via Groq for high-throughput misinformation detection.
    Falls back to llama3-8b-8192 if the Mixtral call fails, and to
    openai/gpt-4o-mini when no GROQ_API_KEY is configured.
    """
    # Build a concise evidence summary from the top-3 RAG docs
    evidence_lines = []
    for i, doc in enumerate(rag_context.retrieved_docs[:3], 1):
        evidence_lines.append(
            f"{i}. [{doc.domain}] (score:{doc.score:.2f}) {doc.text[:180]}\n   URL: {doc.source_url}"
        )
    evidence_block = "\n".join(evidence_lines) if evidence_lines else "No retrieved evidence."

    user_content = (
        f"CLAIM: {claim_text}\n\n"
        f"TRUST SCORE: {rag_context.trust_score:.2f} "
        f"(community_note={rag_context.community_note}, "
        f"corroborations={rag_context.corroboration_count})\n\n"
        f"RETRIEVED EVIDENCE:\n{evidence_block}"
    )

    # Prefer Groq's Mixtral; fall back to OpenAI when no Groq key is present
    groq_key = os.getenv("GROQ_API_KEY", "")
    model = "groq/mixtral-8x7b-32768" if groq_key else "openai/gpt-4o-mini"

    raw = await _call_llm(model=model, system=MISINFORMATION_SYSTEM, user_content=user_content)

    # If the primary model fails, try the secondary Groq model
    if raw is None and groq_key:
        raw = await _call_llm(
            model="groq/llama3-8b-8192",
            system=MISINFORMATION_SYSTEM,
            user_content=user_content,
        )

    output = _parse_agent_output(raw, fallback_color="yellow")

    # Override: community notes are strong red signals
    if rag_context.community_note and output.color != "red":
        output.color = "red"
        output.confidence = max(output.confidence, 75)
        output.explanation = f"⚠ Active Community Note. {output.explanation}"

    # Override: low trust score combined with no corroboration → yellow floor
    if rag_context.trust_score < 0.3 and rag_context.corroboration_count == 0:
        if output.color == "green":
            output.color = "yellow"
        output.confidence = min(output.confidence, 55)

    log.info(
        "misinformation_task.done",
        color=output.color,
        confidence=output.confidence,
        model=model,
    )
    return output

async def hallucination_task(claim_text: str) -> AgentOutput:
    """
    Runs only for AI chat platform sources.
    Uses Claude Haiku for superior hallucination pattern recognition.
    Falls back to Groq llama3 if the Anthropic key is absent.
    """
    anthropic_key = os.getenv("ANTHROPIC_API_KEY", "")
    model = "claude-haiku-4-5-20251001" if anthropic_key else "groq/llama3-8b-8192"

    raw = await _call_llm(
        model=model,
        system=HALLUCINATION_SYSTEM,
        user_content=f"Audit this AI-generated text for hallucinations:\n\n{claim_text}",
        max_tokens=300,
    )
    output = _parse_agent_output(raw, fallback_color="purple")
    log.info(
        "hallucination_task.done",
        color=output.color,
        confidence=output.confidence,
        model=model,
    )
    return output

def _merge_results(
    misinfo: AgentOutput,
    hallucination: AgentOutput | None,
) -> AnalysisFlowResult:
    """
    Severity-based merge: pick the higher-severity color.
    Purple (AI hallucination) and red (misinformation) are the two highest
    severities but represent different categories; red outranks purple, so
    red wins if both fire.
    """
    if hallucination is None:
        winner = misinfo
    else:
        winner = misinfo if SEVERITY[misinfo.color] >= SEVERITY[hallucination.color] else hallucination
    return AnalysisFlowResult(
        color=winner.color,
        confidence=winner.confidence,
        verdict=winner.verdict,
        explanation=winner.explanation,
        sources=winner.sources,
    )

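# Illustrative merge outcomes (sketch):
#   misinfo=yellow (1), hallucination=purple (2) -> purple
#   misinfo=red    (3), hallucination=purple (2) -> red
#   equal severities resolve toward the misinformation verdict (the >= above).
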
# ── Main flow (replaces Prefect decorator for HF compatibility) ───────────────
async def run_analysis_flow(
    claim_text: str,
    claim_hash: str,
    platform: str,
    rag_context: RagContext,
) -> AnalysisFlowResult:
    """
    Orchestrates the concurrent agent tasks.

    On Hugging Face Spaces, Prefect's scheduler is replaced by asyncio.gather
    for zero-dependency concurrent execution. For a production Prefect
    deployment, wrap each inner call with the @task decorator.
    """
    is_ai_platform = platform.lower() in AI_PLATFORMS

    if is_ai_platform:
        # Run both tasks concurrently
        misinfo_coro = misinformation_task(claim_text, rag_context)
        halluc_coro = hallucination_task(claim_text)
        misinfo_result, halluc_result = await asyncio.gather(
            misinfo_coro, halluc_coro
        )
    else:
        misinfo_result = await misinformation_task(claim_text, rag_context)
        halluc_result = None

    merged = _merge_results(misinfo_result, halluc_result)
    log.info(
        "flow.complete",
        claim_hash=claim_hash[:8],
        color=merged.color,
        confidence=merged.confidence,
        platform=platform,
        ai_platform=is_ai_platform,
    )
    return merged
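

# Minimal local smoke test, a sketch rather than part of the deployed flow.
# It assumes RagContext can be constructed with keyword arguments for the
# fields referenced above (retrieved_docs, trust_score, community_note,
# corroboration_count); adjust to the actual rag_pipeline signature.
if __name__ == "__main__":
    demo_ctx = RagContext(
        retrieved_docs=[],
        trust_score=0.5,
        community_note=False,
        corroboration_count=1,
    )
    result = asyncio.run(
        run_analysis_flow(
            claim_text="The moon landing was staged.",
            claim_hash="deadbeefcafef00d",
            platform="chatgpt",
            rag_context=demo_ctx,
        )
    )
    print(result)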