# pdf-trainer-worker: backend/worker/openai_classifier.py
# Commit 7fd3f6f (Avinashnalla7): fix: restore worker code + entrypoint
from __future__ import annotations

import base64
import json
import math
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

from openai import OpenAI
# ----------------------------
# Known templates (mirror your main system)
# ----------------------------
# Registry of document templates the classifier is allowed to return.
# - The whole list is serialized into the user prompt (json.dumps in
#   classify_with_openai), so every field below is visible to the model.
# - "template_id" values (plus "UNKNOWN") are the only ids classify_with_openai
#   returns; _normalize_template_id maps model output onto them and also
#   accepts a template's "name".
# - "keywords_all" / "keywords_any" are not evaluated anywhere in this module;
#   they only reach the model as hints. Presumably "all must appear" vs.
#   "any is supporting evidence" -- NOTE(review): confirm against the main system.
# - Some keywords are shared across templates (e.g. "sales order" appears in
#   T2, T3 and T7), so they are weak signals on their own.
KNOWN_TEMPLATES: List[Dict[str, Any]] = [
    {
        "template_id": "T1_IFACTOR_DELIVERED_ORDER",
        "name": "I-FACTOR Delivered Order Form",
        "keywords_all": ["delivered order form"],
        "keywords_any": ["i-factor", "cerapedics", "product information", "stickers", "bill to", "delivered to"],
    },
    {
        "template_id": "T2_SEASPINE_DELIVERED_GOODS_FORM",
        "name": "SeaSpine Delivered Goods Form",
        "keywords_all": ["delivered goods form"],
        "keywords_any": ["seaspine", "isotis", "handling fee", "sales order", "invoice"],
    },
    {
        "template_id": "T3_ASTURA_SALES_ORDER_FORM",
        "name": "Astura Sales Order Form",
        "keywords_all": [],
        "keywords_any": ["astura", "dc141", "ca200", "cbba", "sales order"],
    },
    {
        "template_id": "T4_MEDICAL_ESTIMATION_OF_CHARGES",
        "name": "Medical Estimation of Charges",
        "keywords_all": [],
        "keywords_any": ["estimation of charges", "good faith estimate", "patient responsibility", "insurance"],
    },
    {
        "template_id": "T5_CLINICAL_PROGRESS_NOTE_POSTOP",
        "name": "Clinical Progress Note Postop",
        "keywords_all": [],
        "keywords_any": ["clinical progress note", "progress note", "post-op", "assessment", "plan"],
    },
    {
        "template_id": "T6_CUSTOMER_CHARGE_SHEET_SPINE",
        "name": "Customer Charge Sheet Spine",
        "keywords_all": [],
        "keywords_any": ["customer charge sheet", "charge sheet", "spine", "qty", "unit price", "total"],
    },
    {
        "template_id": "T7_SALES_ORDER_ZIMMER",
        "name": "Zimmer Sales Order",
        "keywords_all": [],
        "keywords_any": ["zimmer", "zimmer biomet", "biomet", "sales order", "purchase order", "po number"],
    },
]
# ----------------------------
# Public API (EXPLICIT key/model)
# ----------------------------
def classify_with_openai(
    image_paths: List[str],
    *,
    api_key: str,
    model: str,
    max_pages: int = 2,
) -> Dict[str, Any]:
    """Classify rendered PDF pages against KNOWN_TEMPLATES using an OpenAI model.

    Args:
        image_paths: paths to PNG renders of the document's pages, in order.
            Only the first ``max_pages`` are attached to the request; if
            ``max_pages`` <= 0, one page is still sent.
        api_key: OpenAI API key (surrounding whitespace is ignored).
        model: chat-completions model name (surrounding whitespace is ignored).
        max_pages: maximum number of pages to attach.

    Returns:
        {
            "template_id": a known template_id or "UNKNOWN",
            "confidence": float clamped to [0, 1], capped at 0.6 for UNKNOWN,
            "reason": the model's short explanation (at most 500 chars),
            "trainer_schema": {}   # reserved for later
        }

    Raises:
        RuntimeError: if ``api_key`` or ``model`` is blank.

    Guarantees: never reads environment variables, never guesses API keys,
    and only ever returns template_ids from KNOWN_TEMPLATES (or "UNKNOWN").
    """
    key = (api_key or "").strip()
    model_name = (model or "").strip()
    if not key:
        raise RuntimeError("classify_with_openai: api_key is empty")
    if not model_name:
        raise RuntimeError("classify_with_openai: model is empty")
    if not image_paths:
        return {
            "template_id": "UNKNOWN",
            "confidence": 0.0,
            "reason": "No rendered images provided.",
            "trainer_schema": {},
        }

    # Send only the leading pages: small request, deterministic input.
    page_limit = max_pages if max_pages > 0 else 1
    encoded_pages = [_png_file_to_b64(Path(p)) for p in image_paths[:page_limit]]

    client = OpenAI(api_key=key)

    system_prompt = "\n".join(
        [
            "You are a strict document template classifier.",
            "You will be shown PNG images of PDF pages (scanned forms).",
            "Your job is to decide which known template matches.",
            "",
            "Hard rules:",
            "1) Output VALID JSON only. No markdown. No extra text.",
            "2) Choose ONE template_id from the provided list OR return template_id='UNKNOWN'.",
            "3) If uncertain, return UNKNOWN.",
            "4) Use printed headers, vendor branding, and distinctive layout cues.",
            "5) confidence must be 0..1.",
            "",  # keeps the trailing newline after rule 5
        ]
    )

    request_payload = {
        "known_templates": KNOWN_TEMPLATES,
        "output_schema": {
            "template_id": "string (one of known template_ids) OR 'UNKNOWN'",
            "confidence": "number 0..1",
            "reason": "short string",
        },
    }
    instructions = (
        "Classify the attached document images against known_templates.\n"
        "Return JSON matching output_schema.\n\n"
        + json.dumps(request_payload, indent=2)
    )

    # Multi-modal user message: the text instructions first, then one image part per page.
    user_content: List[Dict[str, Any]] = [{"type": "text", "text": instructions}]
    user_content.extend(
        {"type": "image_url", "image_url": {"url": "data:image/png;base64," + page}}
        for page in encoded_pages
    )

    completion = client.chat.completions.create(
        model=model_name,
        temperature=0.0,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_content},
        ],
    )
    reply = (completion.choices[0].message.content or "").strip()
    answer = _parse_json_object(reply)

    raw_id = str(answer.get("template_id") or "").strip()
    confidence = max(0.0, min(1.0, _to_float(answer.get("confidence"), default=0.0)))
    reason = str(answer.get("reason") or "").strip()

    # Only known template ids (or UNKNOWN) may leave this function.
    template_id = _normalize_template_id(raw_id)

    # An UNKNOWN verdict should not carry high confidence.
    if template_id == "UNKNOWN" and confidence > 0.6:
        confidence = 0.6

    return {
        "template_id": template_id,
        "confidence": confidence,
        "reason": reason[:500],
        "trainer_schema": {},
    }
# ----------------------------
# Legacy wrapper (ENV-based) - optional; only for callers not yet passing key/model explicitly
# ----------------------------
def classify_with_openai_from_env(image_paths: List[str]) -> Dict[str, Any]:
    """Backwards-compatible wrapper that reads credentials from the environment.

    Resolution order:
      - API key: OPENAI_API_KEY_TEST, then OPENAI_API_KEY.
      - Model:   OPENAI_MODEL, then "gpt-4o-mini".
    Each variable is stripped *before* falling back, so a variable that is set
    but blank (e.g. "  ") no longer shadows the next candidate.

    Delegates to classify_with_openai (the single implementation). Use this only
    from old code that hasn't been updated to pass api_key/model explicitly.

    Raises:
        RuntimeError: if neither API key variable holds a non-blank value.
    """
    import os

    def _env(name: str) -> str:
        # Missing and whitespace-only values both normalize to "".
        return (os.getenv(name) or "").strip()

    api_key = _env("OPENAI_API_KEY_TEST") or _env("OPENAI_API_KEY")
    if not api_key:
        raise RuntimeError("Missing OPENAI_API_KEY_TEST (or OPENAI_API_KEY)")
    model = _env("OPENAI_MODEL") or "gpt-4o-mini"
    # IMPORTANT: call the explicit version (one implementation only)
    return classify_with_openai(
        image_paths,
        api_key=api_key,
        model=model,
    )
# ----------------------------
# Helpers
# ----------------------------
def _normalize_template_id(template_id: str) -> str:
tid = (template_id or "").strip()
if not tid:
return "UNKNOWN"
known_ids = {t["template_id"] for t in KNOWN_TEMPLATES}
if tid in known_ids:
return tid
# common garbage patterns (model returns name instead of id, etc.)
low = tid.lower()
for t in KNOWN_TEMPLATES:
if t["name"].lower() == low:
return t["template_id"]
return "UNKNOWN"
def _png_file_to_b64(path: Path) -> str:
data = path.read_bytes()
return base64.b64encode(data).decode("utf-8")
_JSON_BLOCK_RE = re.compile(r"\{.*\}", re.DOTALL)
def _parse_json_object(text: str) -> Dict[str, Any]:
"""
Extract and parse the first {...} JSON object from model output.
Handles:
- pure JSON
- JSON embedded in text
- fenced code blocks (we strip fences)
"""
if not text:
return {}
s = text.strip()
# Strip ```json fences if present
s = _strip_code_fences(s)
# Fast path: starts with "{"
if s.startswith("{"):
try:
return json.loads(s)
except Exception:
pass
# Try to find a JSON-looking block
m = _JSON_BLOCK_RE.search(s)
if not m:
return {}
chunk = m.group(0)
try:
return json.loads(chunk)
except Exception:
# last attempt: remove trailing commas (common model mistake)
cleaned = _remove_trailing_commas(chunk)
try:
return json.loads(cleaned)
except Exception:
return {}
def _strip_code_fences(s: str) -> str:
# remove leading ```json / ``` and trailing ```
if s.startswith("```"):
s = re.sub(r"^```[a-zA-Z0-9]*\s*", "", s)
s = re.sub(r"\s*```$", "", s)
return s.strip()
def _remove_trailing_commas(s: str) -> str:
# naive but effective: remove ",}" and ",]" patterns repeatedly
prev = None
cur = s
while prev != cur:
prev = cur
cur = re.sub(r",\s*}", "}", cur)
cur = re.sub(r",\s*]", "]", cur)
return cur
def _to_float(x: Any, default: float = 0.0) -> float:
try:
return float(x)
except Exception:
return default
# ----------------------------
# Optional: quick self-check (manual)
# ----------------------------
def _debug_summarize_result(res: Dict[str, Any]) -> str:
return f"template_id={res.get('template_id')} conf={res.get('confidence')} reason={str(res.get('reason') or '')[:80]}"
def _validate_known_templates() -> Tuple[bool, str]:
    """Sanity-check the template_id values in KNOWN_TEMPLATES.

    Returns:
        (False, message) if any entry has a missing/falsy "template_id" (checked
        first) or if a template_id appears more than once; otherwise (True, "ok").
    """
    template_ids = [entry.get("template_id") for entry in KNOWN_TEMPLATES]
    if not all(template_ids):
        return False, "One or more templates missing template_id"
    has_duplicates = len(template_ids) != len(set(template_ids))
    if has_duplicates:
        return False, "Duplicate template_id in KNOWN_TEMPLATES"
    return True, "ok"