# Source: readCtrl_lambda/code/fine_tune_sft_dpo/support_claim_api.py
# Author: mshahidul — "Initial commit of readCtrl code without large models" (commit 030876e)
import os
# Pin GPU selection before torch/CUDA is imported: order devices by PCI bus id
# and expose only physical GPU 3 to this process.
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
# NOTE(review): because executable statements precede it, this string is NOT
# picked up as the module docstring — it is just an expression statement.
"""
FastAPI service for support claim checking using HHEM model.
This service provides an API endpoint to check if subclaims are supported by context.
"""
import os  # NOTE(review): duplicate of the import above; harmless but redundant.
import sys
from typing import List, Dict, Any
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import warnings
# Silence library warnings (e.g. transformers/torch deprecation chatter) globally.
warnings.filterwarnings("ignore")
# Optional heavy dependencies: the service degrades gracefully when torch /
# transformers are missing (endpoints report "invalid" instead of crashing).
try:
    import torch
    from transformers import AutoModelForSequenceClassification
    _HHEM_AVAILABLE = True
except ImportError:
    # Keep the names defined so later references don't raise NameError.
    torch = None
    AutoModelForSequenceClassification = None
    _HHEM_AVAILABLE = False
# --- HHEM (vectara/hallucination_evaluation_model) for support checking ---
# Model id is overridable via the HHEM_MODEL_NAME environment variable.
HHEM_MODEL_NAME = os.getenv("HHEM_MODEL_NAME", "vectara/hallucination_evaluation_model")
# Lazily-loaded process-wide singleton; populated by load_hhem_model().
_HHEM_MODEL = None
def load_hhem_model(model_name: str = None):
    """Load and cache the HHEM model for subclaim verification.

    HHEM scores (premise=generated text, hypothesis=subclaim) pairs.

    Args:
        model_name: Optional Hugging Face model id; when None, falls back to
            HHEM_MODEL_NAME (env-configurable).

    Returns:
        The cached model instance, in eval mode.

    Raises:
        RuntimeError: If torch/transformers failed to import.
    """
    global _HHEM_MODEL
    if not _HHEM_AVAILABLE:
        raise RuntimeError("torch and transformers are required for HHEM support checking")
    # Reuse the singleton so every request shares one loaded model.
    if _HHEM_MODEL is not None:
        return _HHEM_MODEL
    name = model_name or HHEM_MODEL_NAME
    # trust_remote_code: the HHEM repo ships custom model code (including the
    # predict() helper used by verify_subclaims_in_text).
    _HHEM_MODEL = AutoModelForSequenceClassification.from_pretrained(
        name,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    _HHEM_MODEL.eval()
    return _HHEM_MODEL
def verify_subclaims_in_text(
    model,
    generated_text: str,
    subclaims: List[str],
    threshold: float = 0.5,
    batch_size: int = 32,
) -> List[Dict[str, Any]]:
    """
    Verify how much information from subclaims exists in generated text.

    HHEM convention: premise=generated text, hypothesis=subclaim. Each subclaim
    gets a score and a PASS/FAIL status based on `threshold`.

    Args:
        model: Object exposing predict(list[tuple[str, str]]) -> sequence of
            scores (floats or 0-d tensors).
        generated_text: The premise text each subclaim is checked against.
        subclaims: Hypothesis statements to verify.
        threshold: Score strictly above this counts as PASS.
        batch_size: Pairs scored per predict() call; values < 1 are clamped
            to 1 (the original code raised ValueError from range(..., step=0)).

    Returns:
        One dict per subclaim: subclaim, score (rounded to 4 dp),
        status ("PASS"/"FAIL"), exists_in_text (bool).
    """
    # Clamp to avoid a zero/negative step in range() below.
    step = max(1, batch_size)
    pairs = [(generated_text, claim) for claim in subclaims]
    results: List[Dict[str, Any]] = []
    for start in range(0, len(pairs), step):
        batch_scores = model.predict(pairs[start : start + step])
        for offset, score in enumerate(batch_scores):
            # Scores may be tensors (have .item()) or plain numbers.
            value = score.item() if hasattr(score, "item") else float(score)
            supported = value > threshold
            results.append({
                "subclaim": subclaims[start + offset],
                "score": round(value, 4),
                "status": "PASS" if supported else "FAIL",
                "exists_in_text": supported,
            })
    return results
# FastAPI application instance; routes are registered via decorators below.
app = FastAPI(title="Support Claim Checking API", version="1.0.0")
class SupportCheckRequest(BaseModel):
    """Request model for support claim checking."""
    # Premise text that every subclaim is checked against.
    context: str
    # Hypothesis statements to verify.
    subclaims: List[str]
    # HHEM score strictly above this counts a subclaim as supported.
    threshold: float = 0.5
    # Number of (context, subclaim) pairs scored per model call.
    batch_size: int = 32
class SupportCheckResponse(BaseModel):
    """Response model for support claim checking."""
    # One label per subclaim: "supported" | "not_supported" | "invalid"
    labels: List[str]
    # Detailed per-subclaim results with scores (may be empty on fallback paths).
    details: List[Dict[str, Any]]
@app.get("/health")
async def health_check():
    """Report service liveness plus HHEM dependency and model-load state."""
    payload = {"status": "healthy"}
    payload["hhem_available"] = _HHEM_AVAILABLE
    payload["model_loaded"] = _HHEM_MODEL is not None
    return payload
@app.post("/check_support", response_model=SupportCheckResponse)
async def check_support(request: SupportCheckRequest):
    """
    Check if subclaims are supported by the context.

    Args:
        request: SupportCheckRequest containing context, subclaims, threshold, and batch_size

    Returns:
        SupportCheckResponse with labels (one per subclaim) and detailed results

    Raises:
        HTTPException: 500 when model loading or scoring fails.
    """
    # No subclaims -> nothing to label.
    if not request.subclaims:
        return SupportCheckResponse(labels=[], details=[])
    # Without context (nothing to check against) or without the HHEM
    # dependencies we cannot score anything; return one "invalid" label per
    # subclaim so callers always get len(labels) == len(subclaims).
    # (The original empty-context path returned labels=[], which broke that
    # invariant when subclaims were supplied.)
    if not request.context or not _HHEM_AVAILABLE:
        return SupportCheckResponse(
            labels=["invalid"] * len(request.subclaims),
            details=[],
        )
    try:
        model = load_hhem_model()
        results = verify_subclaims_in_text(
            model,
            request.context,
            request.subclaims,
            threshold=request.threshold,
            batch_size=request.batch_size,
        )
        # Map PASS -> "supported", FAIL -> "not_supported" to match existing reward logic
        labels = ["supported" if r["status"] == "PASS" else "not_supported" for r in results]
        return SupportCheckResponse(labels=labels, details=results)
    except Exception as exc:
        # Chain the cause so the original traceback is preserved in logs.
        raise HTTPException(
            status_code=500,
            detail=f"HHEM support check failed: {str(exc)}",
        ) from exc
if __name__ == "__main__":
    # Launch the API with uvicorn; host/port are configurable via env vars.
    import uvicorn

    uvicorn.run(
        app,
        host=os.getenv("SUPPORT_API_HOST", "0.0.0.0"),
        port=int(os.getenv("SUPPORT_API_PORT", "8091")),
    )