Spaces:
Sleeping
Sleeping
| """ | |
| Combined FastAPI + Gradio application for the Invoice Exception Handler. | |
| Serves both the HTTP API endpoints (for the OpenEnv validator) and an | |
| interactive Gradio UI (for judges and exploration) on port 7860. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import threading | |
| from typing import Any, Dict, Optional | |
| import gradio as gr | |
| import uvicorn | |
| from fastapi import FastAPI | |
| from fastapi.responses import JSONResponse | |
| from env import InvoiceExceptionEnv, Action, ActionType, ALL_TASKS | |
| # --------------------------------------------------------------------------- | |
| # Shared environment instance used by the HTTP handlers; env_lock is held around each access to it | |
| # --------------------------------------------------------------------------- | |
| env = InvoiceExceptionEnv(seed=42) | |
| env_lock = threading.Lock() | |
| # --------------------------------------------------------------------------- | |
| # FastAPI server (the Gradio UI is mounted onto this app further down) | |
| # --------------------------------------------------------------------------- | |
| api = FastAPI(title="Invoice Exception Handler OpenEnv", version="1.0.0") | |
@api.post("/reset")
async def http_reset(body: Optional[Dict[str, Any]] = None) -> JSONResponse:
    """Reset the shared API environment and return the first observation.

    Args:
        body: Optional JSON object. When it contains ``task_id`` that value
            is forwarded to ``env.reset``; otherwise ``None`` is forwarded.

    Returns:
        The observation returned by ``env.reset``, dumped in JSON mode.
    """
    # A missing request body behaves exactly like an empty JSON object.
    task_id = (body or {}).get("task_id")
    with env_lock:
        obs = env.reset(task_id)
        return JSONResponse(obs.model_dump(mode="json"))
@api.post("/step")
async def http_step(body: Optional[Dict[str, Any]] = None) -> JSONResponse:
    """Execute one action against the shared API environment.

    Args:
        body: JSON object describing the action; it is passed to
            ``env.step`` unchanged. A missing body is treated as ``{}``.

    Returns:
        The step result returned by ``env.step``, dumped in JSON mode.
    """
    action = {} if body is None else body
    with env_lock:
        result = env.step(action)
        return JSONResponse(result.model_dump(mode="json"))
@api.get("/state")
async def http_state() -> JSONResponse:
    """Return the current state of the shared API environment.

    Calls ``env.state()`` under ``env_lock`` and returns it dumped in JSON
    mode; no step is taken.
    """
    with env_lock:
        return JSONResponse(env.state().model_dump(mode="json"))
@api.post("/grade")
async def http_grade() -> JSONResponse:
    """Grade the current episode of the shared API environment.

    Returns whatever ``env.grade()`` produces, wrapped in a JSON response.
    """
    with env_lock:
        return JSONResponse(env.grade())
@api.get("/tasks")
async def http_tasks() -> JSONResponse:
    """List the available tasks by returning ``ALL_TASKS`` as JSON."""
    return JSONResponse(ALL_TASKS)
@api.get("/health")
async def health() -> JSONResponse:
    """Health check endpoint; always reports status ``ok`` and the version."""
    return JSONResponse({"status": "ok", "version": "1.0.0"})
| # --------------------------------------------------------------------------- | |
| # Gradio UI — environment for interactive play | |
| # --------------------------------------------------------------------------- | |
| # Module-level environment and action history for the Gradio UI (separate from the API env); both are globals shared by every UI callback rather than created per browser session | |
| ui_env = InvoiceExceptionEnv(seed=42) | |
| ui_history: list = [] | |
def reset_task(task_name: str) -> tuple:
    """Start a fresh UI episode for the task picked in the dropdown.

    Clears the module-level action history, resets ``ui_env`` and returns a
    7-tuple: flag markdown, available checks, available rules, knowledge-base
    markdown, status line, and two empty strings (used by the caller to blank
    the history and grade widgets). Unknown labels fall back to task 1.
    """
    global ui_history
    ui_history = []

    label_to_id = {
        "Task 1 — Price Variance (Easy)": "task1_price_variance",
        "Task 2 — Duplicate Tax (Medium)": "task2_duplicate_tax",
        "Task 3 — Compound Fraud (Hard)": "task3_compound_fraud",
    }
    observation = ui_env.reset(label_to_id.get(task_name, "task1_price_variance"))

    flag_md = (
        f"**{observation.exception_flag.flag_code}**: "
        f"{observation.exception_flag.flag_description}"
    )
    checks_csv = ", ".join(observation.available_checks)
    rules_csv = ", ".join(observation.available_rules)
    kb_bullets = [f"- {entry}" for entry in observation.knowledge_base]
    status_line = (
        f"Step: {observation.step_number} | "
        f"Status: {observation.case_status.value} | "
        f"Reward: {observation.cumulative_reward:.2f}"
    )
    return (
        flag_md,
        checks_csv,
        rules_csv,
        "\n".join(kb_bullets),
        status_line,
        "",
        "",
    )
def execute_action(action_type: str, param1: str, param2: str, param3: str) -> tuple:
    """Run one action on the UI environment and format the outcome.

    The three free-text inputs are mapped onto the parameter names that the
    chosen action type expects; an unrecognised action type is sent with
    empty params. Returns a 5-tuple of reward markdown, status line, history
    text, action info as JSON, and grade markdown (empty until the episode
    is done). If anything raises while stepping or formatting, the error is
    returned in the first slot and the history is left as it was.
    """
    # One builder per action type; only the selected one is evaluated.
    param_builders = {
        "inspect_field": lambda: {"document": param1, "field": param2},
        "cross_check": lambda: {"field": param1, "doc_a": param2, "doc_b": param3},
        "run_check": lambda: {"check_name": param1},
        "query_supplier": lambda: {"question": param1, "channel": param2 or "phone"},
        "query_internal": lambda: {"department": param1, "question": param2},
        "apply_rule": lambda: {"rule_id": param1},
        "make_decision": lambda: {"decision": param1, "reason": param2},
        "route_to": lambda: {"team": param1, "notes": param2},
        "close_case": lambda: {"summary": param1},
    }
    builder = param_builders.get(action_type)
    payload: Dict[str, Any] = builder() if builder is not None else {}

    try:
        outcome = ui_env.step({"type": action_type, "params": payload})
        reward_md = f"**Reward:** {outcome.reward:+.2f}"
        info_json = json.dumps(outcome.info, indent=2, default=str)
        observation = outcome.observation
        status_line = (
            f"Step: {observation.step_number} | "
            f"Status: {observation.case_status.value} | "
            f"Reward: {observation.cumulative_reward:.2f} | "
            f"Done: {outcome.done}"
        )
        ui_history.append(
            f"Step {observation.step_number}: {action_type}({param1}) → {outcome.reward:+.2f}"
        )
        history_blob = "\n".join(ui_history)

        if outcome.done:
            # Episode finished: show the headline score, then every sub-score.
            scores = ui_env.grade()
            headline = f"**Final Grade: {scores['score']:.4f}**"
            breakdown = [
                f"- {name}: {value}"
                for name, value in scores.items()
                if name != "score"
            ]
            grade_md = "\n".join([headline, "", *breakdown])
        else:
            grade_md = ""

        return reward_md, status_line, history_blob, info_json, grade_md
    except Exception as exc:
        return f"**Error:** {exc!s}", "", "\n".join(ui_history), "", ""
def run_demo(task_name: str) -> str:
    """Replay a scripted action sequence for a task and report each step.

    A private ``InvoiceExceptionEnv`` is created for the run, so the shared
    UI and API environments are not touched. Returns a markdown report with
    one entry per executed step followed by the grader's scores. Unknown
    labels fall back to task 1.
    """
    label_to_id = {
        "Task 1 — Price Variance (Easy)": "task1_price_variance",
        "Task 2 — Duplicate Tax (Medium)": "task2_duplicate_tax",
        "Task 3 — Compound Fraud (Hard)": "task3_compound_fraud",
    }
    task_id = label_to_id.get(task_name, "task1_price_variance")

    # Scripted action sequences, keyed by task id.
    scripted = {
        "task1_price_variance": [
            Action.run_check("po_match"),
            Action.run_check("tolerance_rule"),
            Action.cross_check("unit_price", "invoice", "po"),
            Action.run_check("grn_match"),
            Action.query_supplier("Why do prices differ from PO?", "email"),
            Action.query_internal("procurement", "Did you approve the price increase?"),
            Action.apply_rule("tolerance_exception_approval"),
            Action.make_decision("approve", "Price increase verbally approved by procurement. PO amendment pending."),
            Action.route_to("procurement", "Please raise PO amendment for the price variance."),
            Action.close_case("Invoice approved. Procurement confirmed verbal approval. PO amendment requested."),
        ],
        "task2_duplicate_tax": [
            Action.run_check("duplicate_detection"),
            Action.inspect_field("invoice", "invoice_number"),
            Action.run_check("tax_calculation_verify"),
            Action.cross_check("tax_amount", "invoice", "payment_history"),
            Action.query_internal("finance", "Can you confirm the overpayment on INV-2024-819?"),
            Action.query_supplier("Please clarify the relationship between INV-2024-891 and INV-2024-819.", "phone"),
            Action.apply_rule("partial_approval"),
            Action.apply_rule("credit_note_request"),
            Action.make_decision("partial_approve", "Duplicate detected. Tax error on original. Approve only 3,240 INR correction."),
            Action.route_to("finance", "Process 3,240 INR tax correction entry."),
            Action.close_case("Duplicate invoice with tax correction. Partial approval for delta only."),
        ],
        "task3_compound_fraud": [
            Action.inspect_field("invoice", "bank_account"),
            Action.run_check("bank_account_verification"),
            Action.run_check("email_domain_verification"),
            Action.inspect_field("invoice", "supplier_gstin"),
            Action.run_check("gst_verification"),
            Action.inspect_field("grn", "items_received"),
            Action.run_check("grn_match"),
            Action.run_check("price_check"),
            Action.query_supplier("Please confirm your bank details and recent invoices.", "phone"),
            Action.query_internal("security", "Suspected BEC attack — lookalike domain detected."),
            Action.apply_rule("fraud_hold"),
            Action.make_decision("reject", "Four fraud signals: bank BEC, GSTIN mismatch, quantity mismatch, price inflation."),
            Action.route_to("legal", "Initiate supplier audit and fraud investigation."),
            Action.route_to("security", "BEC investigation — lookalike domain techcore-solutions.com."),
            Action.close_case("Fraud detected. Invoice rejected. Legal and security notified."),
        ],
    }

    sandbox = InvoiceExceptionEnv(seed=42)
    first_obs = sandbox.reset(task_id)
    plan = scripted.get(task_id, [])

    report = [
        f"# Demo: {task_name}",
        f"**Flag:** {first_obs.exception_flag.flag_description}",
        "",
    ]

    step_no = 0
    for move in plan:
        step_no += 1
        try:
            outcome = sandbox.step(move)
            call_repr = f"{move.type.value}({json.dumps(move.params)})"
            report.append(f"**Step {step_no}:** `{call_repr}`")
            report.append(
                f" Reward: {outcome.reward:+.2f} | Cumulative: {outcome.observation.cumulative_reward:.2f}"
            )

            # Prefer the nested result's detail/value; otherwise fall back to
            # a top-level detail. Either one is truncated to 120 characters.
            info = outcome.info
            if info.get("result"):
                nested = info["result"]
                detail = nested.get("detail", nested.get("value", ""))
                if detail:
                    report.append(f" → {str(detail)[:120]}")
            elif info.get("detail"):
                report.append(f" → {str(info['detail'])[:120]}")
            report.append("")

            if outcome.done:
                break
        except Exception as exc:
            # A failing step is reported inline and the replay continues.
            report.append(f" Error: {exc}")
            report.append("")

    scores = sandbox.grade()
    report.append("---")
    report.append(f"## Final Score: {scores['score']:.4f}")
    report.extend(
        f"- {name}: {value}"
        for name, value in scores.items()
        if name not in ("score", "signals_found")
    )
    # signals_found is always listed last, after the other sub-scores.
    if "signals_found" in scores:
        report.append(f"- signals_found: {scores['signals_found']}")
    return "\n".join(report)
def build_gradio_ui() -> gr.Blocks:
    """Build the three-tab Gradio interface.

    Tabs:
        * Manual Play - reset a task and submit actions one at a time
          (wired to ``reset_task`` and ``execute_action``).
        * Agent Demo - run the scripted sequence for a task (``run_demo``).
        * API Reference - static markdown describing actions, rewards,
          HTTP endpoints and grader sub-scores.

    Returns:
        The assembled ``gr.Blocks`` application (not launched).
    """
    # Single source of truth for the task labels shown in both dropdowns;
    # these strings must match the keys that reset_task/run_demo look up.
    task_labels = [
        "Task 1 — Price Variance (Easy)",
        "Task 2 — Duplicate Tax (Medium)",
        "Task 3 — Compound Fraud (Hard)",
    ]

    with gr.Blocks(
        title="Invoice Exception Handler — OpenEnv",
        theme=gr.themes.Soft(),
    ) as demo:
        gr.Markdown("# 🧾 Invoice Exception Handler — OpenEnv")
        gr.Markdown("An AI agent learning environment for accounts payable exception handling.")
        with gr.Tabs():
            # ----- Tab 1: Manual Play -----
            with gr.TabItem("🎮 Manual Play"):
                with gr.Row():
                    task_dropdown = gr.Dropdown(
                        choices=task_labels,
                        value=task_labels[0],
                        label="Select Task",
                    )
                    reset_btn = gr.Button("🔄 Reset", variant="primary")
                flag_display = gr.Markdown(label="Exception Flag")
                with gr.Row():
                    checks_display = gr.Textbox(label="Available Checks", interactive=False)
                    rules_display = gr.Textbox(label="Available Rules", interactive=False)
                kb_display = gr.Markdown(label="Knowledge Base")
                status_display = gr.Textbox(label="Status", interactive=False)
                gr.Markdown("### Take an Action")
                with gr.Row():
                    action_type_input = gr.Dropdown(
                        choices=[at.value for at in ActionType],
                        value="run_check",
                        label="Action Type",
                    )
                    # Labels list, in action order, what execute_action reads
                    # from each box so users know which slot takes which value.
                    param1_input = gr.Textbox(
                        label="Param 1 (document / field / check_name / question / department / rule_id / decision / team / summary)"
                    )
                    param2_input = gr.Textbox(
                        label="Param 2 (field / doc_a / channel / question / reason / notes)"
                    )
                    param3_input = gr.Textbox(label="Param 3 (doc_b, if cross_check)")
                action_btn = gr.Button("▶️ Execute Action", variant="primary")
                reward_display = gr.Markdown(label="Reward")
                action_info = gr.Textbox(label="Action Info (JSON)", lines=4, interactive=False)
                history_display = gr.Textbox(label="Action History", lines=8, interactive=False)
                grade_display = gr.Markdown(label="Grade (shown when episode ends)")
                # reset_task returns 7 values, one per output below (the last
                # two blank the history and grade widgets).
                reset_btn.click(
                    reset_task,
                    inputs=[task_dropdown],
                    outputs=[flag_display, checks_display, rules_display,
                             kb_display, status_display, history_display, grade_display],
                )
                # execute_action returns 5 values, one per output below.
                action_btn.click(
                    execute_action,
                    inputs=[action_type_input, param1_input, param2_input, param3_input],
                    outputs=[reward_display, status_display, history_display,
                             action_info, grade_display],
                )
            # ----- Tab 2: Agent Demo -----
            with gr.TabItem("🤖 Agent Demo"):
                gr.Markdown("Watch a hardcoded optimal agent solve each task step by step.")
                demo_task = gr.Dropdown(
                    choices=task_labels,
                    value=task_labels[0],
                    label="Select Task",
                )
                demo_btn = gr.Button("▶️ Run Demo", variant="primary")
                demo_output = gr.Markdown()
                demo_btn.click(run_demo, inputs=[demo_task], outputs=[demo_output])
            # ----- Tab 3: API Reference -----
            with gr.TabItem("📖 API Reference"):
                gr.Markdown("""
## Action Types
| Action | Params | Description |
|--------|--------|-------------|
| `inspect_field` | `document, field` | Look at a specific field in a document |
| `cross_check` | `field, doc_a, doc_b` | Compare a field between two documents |
| `run_check` | `check_name` | Run a named validation check |
| `query_supplier` | `question, channel` | Ask the supplier (channel: phone or email) |
| `query_internal` | `department, question` | Ask an internal team |
| `apply_rule` | `rule_id` | Apply a business policy rule |
| `make_decision` | `decision, reason` | approve / reject / hold / partial_approve |
| `route_to` | `team, notes` | Escalate to a team |
| `close_case` | `summary` | Close with an audit trail summary |
## Reward Ranges
| Event | Reward |
|-------|--------|
| Inspecting a key field | +0.01 to +0.14 |
| Cross-check finds mismatch | +0.12 to +0.15 |
| Running a diagnostic check | +0.08 to +0.18 |
| Correct decision | +0.18 to +0.28 |
| Wrong decision on fraud | −0.35 to −0.40 |
| Contacting supplier via email (fraud) | −0.15 |
| Repeat action | −0.02 to −0.05 |
| SLA breach | −0.10 |
## HTTP API
```
POST /reset — Body: {"task_id": "task1_price_variance"} → EnvironmentState
POST /step — Body: {"type": "run_check", "params": {"check_name": "..."}} → StepResult
GET /state → EnvironmentState
POST /grade → {"score": 0.85, ...}
GET /tasks → ["task1_price_variance", ...]
GET /health → {"status": "ok"}
```
## Grader Sub-Scores
Each task grader returns:
- **score** — overall 0.0–1.0
- **diagnosis_score** — did the agent find the root cause?
- **investigation_score** — did the agent gather evidence properly?
- **decision_score** — was the decision correct?
- **routing_score** — was the case sent to the right team?
- **closure_score** — was the case closed with a summary?
- **efficiency_score** — bonus for not wasting steps
""")
    return demo
| # --------------------------------------------------------------------------- | |
| # Main — build the Gradio UI, mount it on the FastAPI app at "/", and expose the result as `app` | |
| # --------------------------------------------------------------------------- | |
| gradio_app = build_gradio_ui() | |
| app = gr.mount_gradio_app(api, gradio_app, path="/") | |
if __name__ == "__main__":
    import signal
    import sys

    def handle_sigint(sig, frame):
        """Graceful shutdown on Ctrl+C."""
        print("\nShutting down gracefully...")
        sys.exit(0)

    signal.signal(signal.SIGINT, handle_sigint)
    try:
        # Serve the already-built ASGI object rather than the "app:app"
        # import string: the string form makes uvicorn import this file a
        # second time (re-creating both environments and the Gradio UI) and
        # only works when the file is named app.py. An import string is only
        # required for reload/workers, neither of which is used here.
        uvicorn.run(app, host="0.0.0.0", port=7860, reload=False)
    except (KeyboardInterrupt, SystemExit):
        # Raised by handle_sigint / Ctrl+C; exit quietly.
        pass