| import json |
| import os |
| import uuid |
| from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer |
| from pathlib import Path |
| from urllib.parse import urlparse |
|
|
| from agent import run_agent_mode, tool_search_kb |
|
|
|
|
# Every path below is anchored at this script's directory, not the process CWD.
HERE = Path(__file__).resolve().parent
# web/: front-end files served by the GET routes; out/: generated report_*.md
# files; kb/: the Markdown knowledge base that the API lists and searches.
WEB_DIR, OUT_DIR, KB_DIR = (HERE / sub for sub in ("web", "out", "kb"))
|
|
|
|
def _read_body_json(handler: BaseHTTPRequestHandler):
    """Read the request body of *handler* and decode it as a JSON object.

    Returns:
        ``{}`` when Content-Length is missing, zero, or negative (empty body);
        the decoded ``dict`` when the body is a valid UTF-8 JSON object;
        ``None`` when Content-Length is not an integer, the body is not valid
        UTF-8 JSON, or the decoded JSON value is not an object.
    """
    try:
        length = int(handler.headers.get("Content-Length") or "0")
    except ValueError:
        # A non-numeric Content-Length used to escape as an unhandled exception.
        return None
    raw = handler.rfile.read(length) if length > 0 else b""
    if not raw:
        return {}
    try:
        payload = json.loads(raw.decode("utf-8"))
    except ValueError:
        # Covers UnicodeDecodeError and json.JSONDecodeError (both subclass ValueError).
        return None
    # Request handlers read fields with payload.get(...), so a JSON array,
    # string or number is treated as malformed rather than returned as-is.
    return payload if isinstance(payload, dict) else None
|
|
|
|
class Handler(BaseHTTPRequestHandler):
    """Request handler for the agent demo: static UI files plus a small JSON API.

    GET  ``/``, ``/assets/app.js``, ``/assets/app.css`` -> files from ``WEB_DIR``
    GET  ``/api/health``   -> static service metadata
    GET  ``/api/kb/list``  -> ``*.md`` files found under ``KB_DIR``
    POST ``/api/kb/search`` -> results of ``tool_search_kb``
    POST ``/api/run``      -> result of ``run_agent_mode`` plus the report text
    Any other path gets a plain-text 404.
    """

    server_version = "AgentDemo/1.0"

    # URL path -> (file name inside WEB_DIR, Content-Type header value).
    _STATIC_FILES = {
        "/": ("index.html", "text/html; charset=utf-8"),
        "/assets/app.js": ("app.js", "application/javascript; charset=utf-8"),
        "/assets/app.css": ("app.css", "text/css; charset=utf-8"),
    }

    def log_message(self, format, *args):
        """Silence the default per-request stderr log (log_error routes here too)."""
        return

    def _send(self, status: int, content_type: str, body: bytes):
        """Write a complete response: status line, headers (no caching), then *body*."""
        self.send_response(status)
        self.send_header("Content-Type", content_type)
        self.send_header("Content-Length", str(len(body)))
        self.send_header("Cache-Control", "no-store")
        self.end_headers()
        self.wfile.write(body)

    def _send_json(self, status: int, obj):
        """Serialize *obj* as UTF-8 JSON (non-ASCII kept verbatim) and send it."""
        body = json.dumps(obj, ensure_ascii=False).encode("utf-8")
        self._send(status, "application/json; charset=utf-8", body)

    def _send_not_found(self):
        """Send the plain-text 404 used for unknown routes and missing static files."""
        self._send(404, "text/plain; charset=utf-8", "not found".encode("utf-8"))

    def _read_payload(self):
        """Return the request body as a dict, or send a 400 and return None.

        Anything other than a JSON object (malformed JSON, a bad Content-Length,
        or a JSON array/string/number) is rejected here. The POST handlers call
        ``payload.get(...)`` and would otherwise crash mid-request.
        """
        try:
            payload = _read_body_json(self)
        except ValueError:
            # e.g. a non-numeric Content-Length header
            payload = None
        if not isinstance(payload, dict):
            self._send_json(400, {"ok": False, "error": "请求 JSON 解析失败"})
            return None
        return payload

    @staticmethod
    def _parse_top_k(raw, default: int = 6):
        """Coerce a client-supplied ``top_k`` to a positive int.

        Falsy values (missing, ``0``, ``""``) yield *default*. A value that
        ``int()`` cannot convert (including NaN/Infinity), or that converts to
        something below 1, yields None.
        """
        if not raw:
            return default
        try:
            top_k = int(raw)
        except (TypeError, ValueError, OverflowError):
            return None
        return top_k if top_k >= 1 else None

    @staticmethod
    def _list_kb_items():
        """Return ``{"path", "title"}`` dicts for every ``*.md`` under KB_DIR, sorted by path.

        The title is the first ``#`` heading within a file's first 10 lines,
        falling back to the file name. Files that cannot be read are skipped
        (best effort).
        """
        items = []
        for md_path in sorted(KB_DIR.glob("**/*.md")):
            try:
                lines = md_path.read_text(encoding="utf-8").splitlines()
                rel_path = str(md_path.relative_to(HERE))
            except (OSError, ValueError):
                # OSError: unreadable file; ValueError: invalid UTF-8 or a path outside HERE.
                continue
            title = next(
                (ln.strip().lstrip("#").strip() for ln in lines[:10] if ln.strip().startswith("#")),
                "",
            )
            items.append({"path": rel_path, "title": title or md_path.name})
        return items

    def do_GET(self):
        """Serve the static UI files and the read-only JSON endpoints."""
        path = urlparse(self.path).path or "/"

        static = self._STATIC_FILES.get(path)
        if static is not None:
            file_name, content_type = static
            try:
                body = (WEB_DIR / file_name).read_bytes()
            except FileNotFoundError:
                # A missing asset is a 404, not an exception that drops the connection.
                self._send_not_found()
                return
            self._send(200, content_type, body)
            return

        if path == "/api/health":
            self._send_json(200, {"ok": True, "service": "agent-demo", "version": "1.0"})
            return

        if path == "/api/kb/list":
            self._send_json(200, {"ok": True, "items": self._list_kb_items()})
            return

        self._send_not_found()

    def do_POST(self):
        """Dispatch POST requests to the search and run endpoints."""
        path = urlparse(self.path).path or "/"
        if path == "/api/kb/search":
            self._handle_kb_search()
        elif path == "/api/run":
            self._handle_run()
        else:
            self._send_not_found()

    def _handle_kb_search(self):
        """POST /api/kb/search: validate ``query``/``top_k`` and return tool_search_kb items."""
        payload = self._read_payload()
        if payload is None:
            return
        query = str(payload.get("query") or "").strip()
        if not query:
            self._send_json(400, {"ok": False, "error": "query 不能为空"})
            return
        top_k = self._parse_top_k(payload.get("top_k"))
        if top_k is None:
            self._send_json(400, {"ok": False, "error": "top_k 必须是正整数"})
            return
        res = tool_search_kb(query=query, kb_dir=str(KB_DIR), top_k=top_k)
        if not res.ok:
            self._send_json(500, {"ok": False, "error": res.output})
            return
        try:
            items = json.loads(res.output)
        except (TypeError, ValueError):
            # Unparsable tool output degrades to an empty result list.
            items = []
        self._send_json(200, {"ok": True, "items": items})

    @staticmethod
    def _validate_run_payload(mode, payload):
        """Return a user-facing error message for an invalid /api/run payload, else None.

        ``lead_followup`` requires a ``lead`` object with a non-empty company or
        pain_points. Every other mode requires a non-empty ``goal`` of at most
        600 characters.
        """
        if mode == "lead_followup":
            lead = payload.get("lead") or {}
            if not isinstance(lead, dict):
                return "lead 必须是对象"
            company = str(lead.get("company") or "").strip()
            pain = str(lead.get("pain_points") or "").strip()
            if not company and not pain:
                return "至少填写“公司/组织”或“主要诉求/痛点”"
            return None
        goal = str(payload.get("goal") or "").strip()
        if not goal:
            return "goal 不能为空"
        if len(goal) > 600:
            return "goal 过长(最多 600 字符)"
        return None

    def _handle_run(self):
        """POST /api/run: validate, run the agent, and return its result plus the report text."""
        payload = self._read_payload()
        if payload is None:
            return
        mode = str(payload.get("mode") or "general").strip()
        error = self._validate_run_payload(mode, payload)
        if error is not None:
            self._send_json(400, {"ok": False, "error": error})
            return

        OUT_DIR.mkdir(parents=True, exist_ok=True)
        report_name = f"report_{uuid.uuid4().hex[:12]}.md"
        out_report = str(OUT_DIR / report_name)

        try:
            result = run_agent_mode(mode=mode, payload=payload, kb_dir=str(KB_DIR), out_report=out_report, max_rounds=3)
        except Exception as e:
            # Request boundary: report any agent failure to the client as a 500.
            self._send_json(500, {"ok": False, "error": f"执行失败: {e}"})
            return

        try:
            report_text = Path(out_report).read_text(encoding="utf-8")
        except (OSError, ValueError):
            # A missing or unreadable/undecodable report is returned as an empty string.
            report_text = ""

        self._send_json(
            200,
            {
                "ok": True,
                "mode": result.get("mode") or mode,
                "goal": result.get("goal"),
                "final_answer": result.get("final_answer"),
                "round_logs": result.get("round_logs"),
                "structured": result.get("structured"),
                "artifact_path": out_report,
                "report_text": report_text,
            },
        )
|
|
|
|
def _bind_first_free_port(host: str, base_port: int, attempts: int = 20) -> ThreadingHTTPServer:
    """Return a ThreadingHTTPServer bound to the first free port in ``[base_port, base_port + attempts)``.

    Raises:
        SystemExit: if every port in the range fails to bind; the message
            includes the last OSError encountered.
    """
    last_err = None
    for port in range(base_port, base_port + attempts):
        try:
            return ThreadingHTTPServer((host, port), Handler)
        except OSError as e:
            last_err = e
    raise SystemExit(f"无法绑定端口(从 {base_port} 起尝试 {attempts} 个):{last_err}")


def main():
    """Start the demo HTTP server on all interfaces and serve until interrupted.

    If ``PORT`` is set in the environment, exactly that port is bound (a bind
    failure propagates as OSError). Otherwise up to 20 consecutive ports are
    tried, starting at 7860 when any ``SPACE_*`` environment variable exists
    and at 8000 otherwise.
    """
    host = "0.0.0.0"
    if "PORT" in os.environ:
        httpd = ThreadingHTTPServer((host, int(os.environ["PORT"])), Handler)
    else:
        in_space = any(k.startswith("SPACE_") for k in os.environ)
        httpd = _bind_first_free_port(host, 7860 if in_space else 8000)

    # serve_forever() runs outside the bind loop, so a runtime OSError is not
    # mistaken for "port busy". The with-block closes the listening socket on exit.
    with httpd:
        print(f"本地演示页已启动: http://127.0.0.1:{httpd.server_address[1]}/")
        httpd.serve_forever()
|
|
|
|
# Start the server only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|