| | |
| | from __future__ import annotations |
| |
|
| | import asyncio |
| | import mimetypes |
| | import os |
| | import sys |
| | import json |
| | import re |
| | import time |
| | import uuid |
| | import math |
| | import logging |
| | import shutil |
| | from pathlib import Path |
| | from dataclasses import dataclass, field |
| | from typing import Any, Dict, List, Optional, Tuple, Set |
| | from contextlib import asynccontextmanager |
| | from starlette.websockets import WebSocketState, WebSocketDisconnect |
| | try: |
| | import tomllib |
| | except ModuleNotFoundError: |
| | import tomli as tomllib |
| | import traceback |
| |
|
| | try: |
| | from uvicorn.protocols.utils import ClientDisconnected |
| | except Exception: |
| | ClientDisconnected = None |
| |
|
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| | import anyio |
| | from fastapi import FastAPI, APIRouter, UploadFile, File, Form, HTTPException, WebSocket, WebSocketDisconnect, Request |
| | from fastapi.responses import FileResponse, JSONResponse, Response |
| | from fastapi.staticfiles import StaticFiles |
| |
|
| | from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage, AIMessage, ToolMessage |
| |
|
| | |
| | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) |
| | SRC_DIR = os.path.join(ROOT_DIR, "src") |
| | if SRC_DIR not in sys.path: |
| | sys.path.insert(0, SRC_DIR) |
| |
|
| | from open_storyline.agent import build_agent, ClientContext |
| | from open_storyline.utils.prompts import get_prompt |
| | from open_storyline.utils.media_handler import scan_media_dir |
| | from open_storyline.config import load_settings, default_config_path |
| | from open_storyline.config import Settings |
| | from open_storyline.storage.agent_memory import ArtifactStore |
| | from open_storyline.mcp.hooks.node_interceptors import ToolInterceptor |
| | from open_storyline.mcp.hooks.chat_middleware import set_mcp_log_sink, reset_mcp_log_sink |
| |
|
# ---- Static web assets served by this app ----
WEB_DIR = os.path.join(ROOT_DIR, "web")
STATIC_DIR = os.path.join(WEB_DIR, "static")
INDEX_HTML = os.path.join(WEB_DIR, "index.html")
NODE_MAP_HTML = os.path.join(WEB_DIR, "node_map/node_map.html")
NODE_MAP_DIR = os.path.join(WEB_DIR, "node_map")

# Server-side cache root under the project directory.
SERVER_CACHE_DIR = os.path.join(ROOT_DIR, '.storyline' , ".server_cache")

# Streaming copy chunk size (1 MiB) used when saving uploads.
CHUNK_SIZE = 1024 * 1024

# When True, each session gets its own media subdirectory (see resolve_media_dir).
USE_SESSION_SUBDIR = True

# Sentinel model key meaning "the user supplies a custom model config".
CUSTOM_MODEL_KEY = "__custom__"

# Default model credentials read from the environment at import time.
DEFAULT_LLM_API_KEY = os.getenv("DEEPSEEK_API_KEY")
DEFAULT_LLM_API_URL = os.getenv("DEEPSEEK_API_URL")
DEFAULT_LLM_API_NAME = os.getenv("DEEPSEEK_API_NAME", "deepseek-chat")
# NOTE(review): the VLM defaults read GLM_V4_6_* variables here, while the env
# fallback (_env_fallback_for_model) and the prints below use QWEN3_VL_8B_* —
# confirm which variable family is actually intended for the default VLM.
DEFAULT_VLM_API_KEY = os.getenv("GLM_V4_6_API_KEY")
DEFAULT_VLM_API_URL = os.getenv("GLM_V4_6_API_URL")
DEFAULT_VLM_API_NAME = os.getenv("GLM_V4_6_API_NAME", "qwen3-vl-8b-instruct")
# Startup diagnostics: only report key presence, never the secret values.
print("DEEPSEEK_API_KEY exists:", bool(os.getenv("DEEPSEEK_API_KEY")))
print("QWEN3_VL_8B_API_KEY exists:", bool(os.getenv("QWEN3_VL_8B_API_KEY")))
print("DEEPSEEK_API_URL:", repr(os.getenv("DEEPSEEK_API_URL")))
print("QWEN3_VL_8B_API_URL:", repr(os.getenv("QWEN3_VL_8B_API_URL")))
| |
|
def debug_traceback_print(cfg: Settings):
    """Print the active exception traceback, but only when developer mode is on."""
    if not cfg.developer.developer_mode:
        return
    traceback.print_exc()
| |
|
| | def _s(x: Any) -> str: |
| | return str(x or "").strip() |
| |
|
| | def _norm_url(u: Any) -> str: |
| | u = _s(u) |
| | return u.rstrip("/") if u else "" |
| |
|
| | def _env_fallback_for_model(model_name: str) -> Tuple[str, str]: |
| | """ |
| | - deepseek* -> DEEPSEEK_API_URL / DEEPSEEK_API_KEY |
| | - qwen3* -> QWEN3_VL_8B_API_URL / QWEN3_VL_8B_API_KEY |
| | """ |
| | m = _s(model_name).lower() |
| | if "deepseek" in m: |
| | return (_s(os.getenv("DEEPSEEK_API_URL")), _s(os.getenv("DEEPSEEK_API_KEY"))) |
| | if m.startswith("qwen3-vl-8b-instruct") or "qwen3-vl-8b-instruct" in m: |
| | return (_s(os.getenv("QWEN3_VL_8B_API_URL")), _s(os.getenv("QWEN3_VL_8B_API_KEY"))) |
| | return ("", "") |
| |
|
def _resolve_default_model_override(cfg: Settings, model_name: str) -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
    """
    Build a chat-model override dict for one of the built-in default models.

    Resolution order:
      1. [developer.chat_models_config."<model_name>"] in config.toml
      2. environment-variable fallback (see _env_fallback_for_model)

    Returns (override, None) on success, or (None, error_message) when no
    usable base_url/api_key can be found for the model.
    """
    model_name = _s(model_name)
    if not model_name:
        return None, "default model name is empty"

    # Per-model block from config; tolerate missing or malformed sections.
    model_cfg: Dict[str, Any] = {}
    try:
        model_cfg = (cfg.developer.chat_models_config.get(model_name) or {}) if getattr(cfg, "developer", None) else {}
    except Exception:
        model_cfg = {}

    if not isinstance(model_cfg, dict):
        model_cfg = {}

    base_url = _norm_url(model_cfg.get("base_url"))
    api_key = _s(model_cfg.get("api_key"))

    # Fill whichever of base_url/api_key is still missing from the environment.
    if not base_url or not api_key:
        env_url, env_key = _env_fallback_for_model(model_name)
        if not base_url:
            base_url = _norm_url(env_url)
        if not api_key:
            api_key = _s(env_key)

    override: Dict[str, Any] = {"model": model_name}
    if base_url:
        override["base_url"] = base_url
    if api_key:
        override["api_key"] = api_key

    # Optional tuning knobs are copied through only when explicitly set.
    for k in ("timeout", "temperature", "max_retries", "top_p", "max_tokens"):
        if k in model_cfg and model_cfg.get(k) not in (None, ""):
            override[k] = model_cfg.get(k)

    if not override.get("base_url") or not override.get("api_key"):
        return None, (
            f"cannot find base_url/api_key of default model: {model_name}. "
            f"please fill in base_url/api_key of [developer.chat_models_config.\"{model_name}\" in config.toml]"
            f"or set environment variables(DEEPSEEK_API_URL/DEEPSEEK_API_KEY / QWEN3_VL_8B_API_URL/QWEN3_VL_8B_API_KEY)。"
        )

    return override, None
| |
|
| | def _stable_dict_key(d: Optional[Dict[str, Any]]) -> str: |
| | try: |
| | return json.dumps(d or {}, sort_keys=True, ensure_ascii=False) |
| | except Exception: |
| | return str(d or {}) |
| |
|
def _parse_service_config(service_cfg: Any) -> Tuple[
    Optional[Dict[str, Any]],
    Optional[Dict[str, Any]],
    Dict[str, Any],
    Dict[str, Any],
    Optional[str]]:
    """
    Parse a per-session service_config payload.

    Returns (custom_llm, custom_vlm, tts_cfg, pexels_cfg, err):
      - custom_llm / custom_vlm: {"model","base_url","api_key"} or None
        (supplying only one of llm/vlm is allowed)
      - tts_cfg / pexels_cfg: dicts, possibly empty
      - err: validation error message, or None on success
    """
    if not isinstance(service_cfg, dict):
        return None, None, {}, {}, None

    custom_llm = None
    custom_vlm = None
    custom_models = service_cfg.get("custom_models")

    if custom_models is not None:
        if not isinstance(custom_models, dict):
            return None, None, {}, {}, "service_config.custom_models 必须是对象"

        def _pick(m: Any, label: str) -> Tuple[Optional[Dict[str, str]], Optional[str]]:
            # Validate one custom model entry; None means "not provided".
            if m is None:
                return None, None
            if not isinstance(m, dict):
                return None, f"service_config.custom_models.{label} 必须是对象"

            model = _s(m.get("model"))
            base_url = _norm_url(m.get("base_url"))
            api_key = _s(m.get("api_key"))

            if not (model and base_url and api_key):
                return None, f"自定义 {label.upper()} 配置不完整:请填写 model/base_url/api_key"
            if not (base_url.startswith("http://") or base_url.startswith("https://")):
                return None, f"自定义 {label.upper()} 的 base_url 必须以 http(s) 开头"
            return {"model": model, "base_url": base_url, "api_key": api_key}, None

        custom_llm, err1 = _pick(custom_models.get("llm"), "llm")
        if err1:
            return None, None, {}, {}, err1

        custom_vlm, err2 = _pick(custom_models.get("vlm"), "vlm")
        if err2:
            return None, None, {}, {}, err2

    # TTS: keep only the selected provider's sub-block.
    tts_cfg: Dict[str, Any] = {}
    tts = service_cfg.get("tts")
    if isinstance(tts, dict):
        provider = (tts.get("provider") or "").strip().lower()
        if provider:
            provider_block = tts.get(provider)
            tts_cfg = {"provider": provider, provider: provider_block}

    # Pexels search config: either nested under "pexels" or flat keys.
    # NOTE: the original code repeated several `x or x` expressions with
    # identical operands; the duplicated (dead) operands have been removed
    # here — behavior is unchanged.
    pexels_cfg: Dict[str, Any] = {}
    search_media = service_cfg.get("search_media")
    if isinstance(search_media, dict):
        p = search_media.get("pexels")
        if isinstance(p, dict):
            mode = _s(p.get("mode")).lower()
            if mode not in ("default", "custom"):
                mode = "default"
            api_key = _s(p.get("api_key") or p.get("pexels_api_key"))
            pexels_cfg = {"mode": mode, "api_key": api_key}
        else:
            mode = _s(search_media.get("mode") or search_media.get("pexels_mode")).lower()
            if mode not in ("default", "custom"):
                mode = "default"
            api_key = _s(search_media.get("pexels_api_key"))
            pexels_cfg = {"mode": mode, "api_key": api_key}

    return custom_llm, custom_vlm, tts_cfg, pexels_cfg, None
| |
|
def is_developer_mode(cfg: Settings) -> bool:
    """True when the config enables developer mode; False on any access error."""
    try:
        flag = cfg.developer.developer_mode
    except Exception:
        return False
    return bool(flag)
| |
|
| | def _abs(p: str) -> str: |
| | return os.path.abspath(os.path.expanduser(p)) |
| |
|
| |
|
def resolve_media_dir(cfg_media_dir: str, session_id: str) -> str:
    """
    Map the configured media directory to a per-session directory.

    With USE_SESSION_SUBDIR enabled, <parent>/<leaf> becomes
    <parent>/<session_id>/<leaf>; otherwise the configured dir is returned as-is.
    """
    root = _abs(cfg_media_dir).rstrip("/\\")
    if not USE_SESSION_SUBDIR:
        return root
    parent = os.path.dirname(root)
    leaf = os.path.basename(root)
    return os.path.join(parent, session_id, leaf)
| |
|
| |
|
def sanitize_filename(name: str) -> str:
    """Strip directory components and NUL bytes; never return an empty name."""
    base = os.path.basename(name or "").replace("\x00", "")
    return base if base else "unnamed"
| |
|
| |
|
def detect_media_kind(filename: str) -> str:
    """Classify a filename as 'image', 'video' or 'unknown' by its extension."""
    ext = os.path.splitext(filename)[1].lower()
    kinds = {
        "image": {".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp"},
        "video": {".mp4", ".mov", ".avi", ".mkv", ".webm"},
    }
    for kind, extensions in kinds.items():
        if ext in extensions:
            return kind
    return "unknown"
| |
|
# Matches media_NNNN store filenames; group(1) is the numeric sequence.
_MEDIA_RE = re.compile(r"^media_(\d+)", re.IGNORECASE)
| |
|
def make_media_store_filename(seq: int, ext: str) -> str:
    """Build the on-disk media filename, e.g. media_0001.mp4."""
    suffix = (ext or "").lower()
    if suffix and not suffix.startswith("."):
        suffix = f".{suffix}"
    return f"{MEDIA_PREFIX}{seq:0{MEDIA_SEQ_WIDTH}d}{suffix}"
| |
|
def parse_media_seq(filename: str) -> Optional[int]:
    """Extract the numeric sequence from a media_NNNN filename, else None."""
    base = os.path.basename(filename or "")
    match = _MEDIA_RE.match(base)
    if not match:
        return None
    digits = match.group(1)
    try:
        seq = int(digits)
    except Exception:
        return None
    return seq
| |
|
def safe_save_path_no_overwrite(media_dir: str, filename: str) -> str:
    """
    Return a path under media_dir that does not collide with an existing file,
    appending " (2)", " (3)", ... before the extension when needed.
    """
    # Sanitize inline: drop directories and NUL bytes, never empty.
    base = os.path.basename(filename or "").replace("\x00", "") or "unnamed"
    stem, ext = os.path.splitext(base)
    candidate = os.path.join(media_dir, base)
    counter = 2
    while os.path.exists(candidate):
        candidate = os.path.join(media_dir, f"{stem} ({counter}){ext}")
        counter += 1
    return candidate
| |
|
| |
|
def ensure_thumbs_dir(media_dir: str) -> str:
    """Create (if missing) and return the .thumbs directory under media_dir."""
    thumbs = os.path.join(media_dir, ".thumbs")
    os.makedirs(thumbs, exist_ok=True)
    return thumbs
| |
|
def ensure_uploads_dir(media_dir: str) -> str:
    """Create (if missing) and return the .uploads staging dir under media_dir."""
    uploads = os.path.join(media_dir, ".uploads")
    os.makedirs(uploads, exist_ok=True)
    return uploads
| |
|
def guess_media_type(path: str) -> str:
    """MIME type guessed from the path, defaulting to application/octet-stream."""
    guessed, _encoding = mimetypes.guess_type(path)
    return guessed if guessed else "application/octet-stream"
| |
|
| |
|
| | def _is_under_dir(path: str, root: str) -> bool: |
| | try: |
| | path = os.path.abspath(path) |
| | root = os.path.abspath(root) |
| | return os.path.commonpath([path, root]) == root |
| | except Exception: |
| | return False |
| |
|
| |
|
def video_placeholder_svg_bytes() -> bytes:
    """UTF-8 SVG bytes: a 320x320 gray gradient tile with a play-button glyph,
    served when a video has no generated thumbnail."""
    svg = """<svg xmlns="http://www.w3.org/2000/svg" width="320" height="320" viewBox="0 0 320 320">
<defs>
<linearGradient id="g" x1="0" x2="1" y1="0" y2="1">
<stop stop-color="#f2f2f2" offset="0"/>
<stop stop-color="#e6e6e6" offset="1"/>
</linearGradient>
</defs>
<rect x="0" y="0" width="320" height="320" fill="url(#g)"/>
<rect x="22" y="22" width="276" height="276" rx="22" fill="rgba(0,0,0,0.06)"/>
<polygon points="140,120 140,200 210,160" fill="rgba(0,0,0,0.55)"/>
</svg>"""
    return svg.encode("utf-8")
| |
|
| |
|
def make_image_thumbnail_sync(src_path: str, dst_path: str, max_size: Tuple[int, int] = (320, 320)) -> bool:
    """
    Write a JPEG thumbnail of src_path to dst_path (blocking; run in a thread).

    Returns False on any failure, including Pillow being unavailable.
    """
    try:
        from PIL import Image

        image = Image.open(src_path).convert("RGB")
        image.thumbnail(max_size)
        image.save(dst_path, format="JPEG", quality=85)
    except Exception:
        return False
    return True
| |
|
async def make_video_thumbnail_async(
    src_video: str,
    dst_path: str,
    *,
    max_size: Tuple[int, int] = (320, 320),
    seek_sec: float = 0.5,
    timeout_sec: float = 20.0,
) -> bool:
    """
    Extract a single-frame JPEG thumbnail from a video via ffmpeg.

    Tries several ffmpeg invocations (fast input-side seek, accurate
    output-side seek, fixed 1.0s seek) and writes through a .tmp.jpg file so
    dst_path only ever holds a complete image. Returns True on success,
    False when ffmpeg is missing or every attempt fails or times out.
    """
    ffmpeg = os.environ.get("FFMPEG_BIN") or shutil.which("ffmpeg")
    if not ffmpeg:
        logger.warning("ffmpeg not found (PATH/FFMPEG_BIN). skip video thumbnail. src=%s", src_video)
        return False

    src_video = os.path.abspath(src_video)
    dst_path = os.path.abspath(dst_path)
    os.makedirs(os.path.dirname(dst_path), exist_ok=True)

    # Temp target; promoted to dst_path with an atomic os.replace on success.
    tmp_path = dst_path + ".tmp.jpg"

    # Scale to fit inside max_size, then pad to exactly max_size (centered).
    vf = (
        f"scale={max_size[0]}:{max_size[1]}:force_original_aspect_ratio=decrease"
        f",pad={max_size[0]}:{max_size[1]}:(ow-iw)/2:(oh-ih)/2"
    )

    async def _run(args: list[str]) -> tuple[bool, str]:
        # Run one ffmpeg command with a hard timeout; returns (ok, stderr_text).
        proc = await asyncio.create_subprocess_exec(
            *args,
            stdout=asyncio.subprocess.DEVNULL,
            stderr=asyncio.subprocess.PIPE,
        )
        try:
            _, err = await asyncio.wait_for(proc.communicate(), timeout=timeout_sec)
        except asyncio.TimeoutError:
            try:
                proc.kill()
            except Exception:
                pass
            # Reap the killed process to avoid a zombie.
            await proc.wait()
            return False, f"timeout after {timeout_sec}s"
        err_text = (err or b"").decode("utf-8", "ignore").strip()
        return (proc.returncode == 0), err_text

    # Shared output options: one video frame, no audio, MJPEG quality 3.
    common_tail = [
        "-an",
        "-frames:v", "1",
        "-vf", vf,
        "-vcodec", "mjpeg",
        "-q:v", "3",
        "-f", "image2",
        tmp_path,
    ]

    attempts = [
        # 1) Fast input-side seek (-ss before -i).
        [ffmpeg, "-hide_banner", "-loglevel", "error", "-y", "-ss", f"{seek_sec}", "-i", src_video] + common_tail,
        # 2) Accurate output-side seek (-ss after -i; slower, decodes up to it).
        [ffmpeg, "-hide_banner", "-loglevel", "error", "-y", "-i", src_video, "-ss", f"{seek_sec}"] + common_tail,
        # 3) Last resort: fixed 1.0s fast seek.
        [ffmpeg, "-hide_banner", "-loglevel", "error", "-y", "-ss", "1.0", "-i", src_video] + common_tail,
    ]

    last_err: Optional[str] = None
    try:
        for args in attempts:
            ok, err = await _run(args)
            if ok and os.path.exists(tmp_path) and os.path.getsize(tmp_path) > 0:
                os.replace(tmp_path, dst_path)
                return True
            last_err = err or last_err
            # Clear any partial output before the next attempt.
            try:
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
            except Exception:
                pass

        logger.warning("ffmpeg thumbnail failed. src=%s dst=%s err=%s", src_video, dst_path, last_err)
        return False
    finally:
        # Belt and braces: never leave the temp file behind.
        try:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
        except Exception:
            pass
| |
|
| | def _env_int(name: str, default: int) -> int: |
| | try: |
| | return int(os.environ.get(name, str(default))) |
| | except Exception: |
| | return default |
| |
|
| | def _env_float(name: str, default: float) -> float: |
| | try: |
| | return float(os.environ.get(name, str(default))) |
| | except Exception: |
| | return float(default) |
| |
|
| | def _rpm_to_rps(rpm: float) -> float: |
| | return float(rpm) / 60.0 |
| |
|
| |
|
| | |
# Honor X-Forwarded-For / X-Real-IP only when explicitly enabled (i.e. the
# server is known to sit behind a trusted reverse proxy).
RATE_LIMIT_TRUST_PROXY_HEADERS = os.environ.get("RATE_LIMIT_TRUST_PROXY_HEADERS", "0") == "1"

@dataclass
class _RateBucket:
    # Token-bucket state for a single rate-limit key.
    tokens: float     # currently available tokens
    last_ts: float    # monotonic time of the last refill
    last_seen: float  # monotonic time of the last access (drives TTL cleanup)
| |
|
class TokenBucketRateLimiter:
    """
    In-memory token-bucket rate limiter with memory-blowup protection:
    - max_buckets: caps the internal bucket table size, so a flood of distinct
      keys (e.g. many IPs) cannot grow the dict without bound
    - evict_batch: how many entries to force-evict per pass once the cap is
      hit (evicted in insertion order, i.e. oldest-created buckets first)
    """
    def __init__(
        self,
        ttl_sec: int = 900,
        cleanup_interval_sec: int = 60,
        *,
        max_buckets: int = 100000,
        evict_batch: int = 2000,
    ):
        self.ttl_sec = int(ttl_sec)
        self.cleanup_interval_sec = int(cleanup_interval_sec)
        self.max_buckets = int(max(1, max_buckets))
        self.evict_batch = int(max(1, evict_batch))

        self._buckets: Dict[str, _RateBucket] = {}
        self._lock = asyncio.Lock()
        self._last_cleanup = time.monotonic()

    async def allow(
        self,
        key: str,
        *,
        capacity: float,
        refill_rate: float,
        cost: float = 1.0,
    ) -> Tuple[bool, float, float]:
        """
        Check (and, if allowed, charge) `cost` tokens against `key`'s bucket.

        Returns: (allowed, retry_after_sec, remaining_tokens)
        """
        now = time.monotonic()
        capacity = float(max(0.0, capacity))
        refill_rate = float(max(0.0, refill_rate))
        cost = float(max(0.0, cost))

        async with self._lock:
            b = self._buckets.get(key)

            if b is None:
                # Periodic TTL cleanup, piggy-backed on new-key creation.
                if now - self._last_cleanup > self.cleanup_interval_sec:
                    self._cleanup_locked(now)
                    self._last_cleanup = now

                # At capacity: try TTL cleanup first, then forced eviction.
                if len(self._buckets) >= self.max_buckets:
                    self._cleanup_locked(now)

                if len(self._buckets) >= self.max_buckets:
                    self._evict_locked()

                if len(self._buckets) >= self.max_buckets:
                    # Still full: fail closed rather than grow the table.
                    return False, 1.0, 0.0

                b = _RateBucket(tokens=capacity, last_ts=now, last_seen=now)
                self._buckets[key] = b
            else:
                b.last_seen = now

            # Refill proportionally to elapsed time, clamped to capacity.
            elapsed = max(0.0, now - b.last_ts)
            if refill_rate > 0:
                b.tokens = min(capacity, b.tokens + elapsed * refill_rate)
            else:
                b.tokens = min(capacity, b.tokens)
            b.last_ts = now

            if b.tokens >= cost:
                b.tokens -= cost
                return True, 0.0, float(max(0.0, b.tokens))

            # Not enough tokens: estimate how long until `cost` is available.
            if refill_rate <= 0:
                retry_after = float(self.ttl_sec)
            else:
                need = cost - b.tokens
                retry_after = need / refill_rate
            return False, float(retry_after), float(max(0.0, b.tokens))

    def _cleanup_locked(self, now: float) -> None:
        # Drop buckets idle longer than ttl_sec. Caller must hold _lock.
        ttl = float(self.ttl_sec)
        dead = [k for k, b in self._buckets.items() if (now - b.last_seen) > ttl]
        for k in dead:
            self._buckets.pop(k, None)

    def _evict_locked(self) -> None:
        # Force-evict up to evict_batch buckets in dict insertion order
        # (oldest-created first). Caller must hold _lock.
        n = min(self.evict_batch, len(self._buckets))
        for _ in range(n):
            try:
                k = next(iter(self._buckets))
            except StopIteration:
                break
            self._buckets.pop(k, None)
| |
|
| | def _headers_to_dict(scope_headers: List[Tuple[bytes, bytes]]) -> Dict[str, str]: |
| | d: Dict[str, str] = {} |
| | for k, v in scope_headers or []: |
| | try: |
| | dk = k.decode("latin1").lower() |
| | dv = v.decode("latin1") |
| | except Exception: |
| | continue |
| | d[dk] = dv |
| | return d |
| |
|
def _client_ip_from_http_scope(scope: dict, trust_proxy_headers: bool) -> str:
    """Best-effort client IP from an ASGI HTTP scope; 'unknown' as last resort."""
    headers = _headers_to_dict(scope.get("headers") or [])
    if trust_proxy_headers:
        forwarded = headers.get("x-forwarded-for")
        if forwarded:
            # Left-most X-Forwarded-For entry is the original client.
            first = forwarded.split(",")[0].strip()
            return first or "unknown"
        real_ip = headers.get("x-real-ip")
        if real_ip:
            return real_ip.strip() or "unknown"

    client = scope.get("client")
    if client and isinstance(client, (list, tuple)) and len(client) >= 1:
        return str(client[0] or "unknown")
    return "unknown"
| |
|
| | def _client_ip_from_ws(ws: WebSocket, trust_proxy_headers: bool) -> str: |
| | try: |
| | if trust_proxy_headers: |
| | xff = ws.headers.get("x-forwarded-for") |
| | if xff: |
| | return xff.split(",")[0].strip() or "unknown" |
| | xri = ws.headers.get("x-real-ip") |
| | if xri: |
| | return xri.strip() or "unknown" |
| | except Exception: |
| | pass |
| |
|
| | try: |
| | if ws.client: |
| | return str(ws.client.host or "unknown") |
| | except Exception: |
| | pass |
| |
|
| | return "unknown" |
| |
|
| | |
# ---- Resumable (chunked) upload settings ----
# Chunk size clients should use for resumable uploads (8 MiB default).
UPLOAD_RESUMABLE_CHUNK_BYTES = _env_int("UPLOAD_RESUMABLE_CHUNK_BYTES", 8 * 1024 * 1024)

# Abandoned resumable uploads expire after this many seconds.
RESUMABLE_UPLOAD_TTL_SEC = _env_int("RESUMABLE_UPLOAD_TTL_SEC", 3600)

# On-disk media naming scheme: media_0001.ext (see make_media_store_filename).
MEDIA_SEQ_WIDTH = 4
MEDIA_PREFIX = "media_"

# ---- Per-IP HTTP rate limits (RPM = requests/minute, BURST = bucket capacity) ----
# Catch-all per-IP limit across every HTTP endpoint.
HTTP_GLOBAL_RPM = _env_int("RATE_LIMIT_HTTP_GLOBAL_RPM", 3000)
HTTP_GLOBAL_BURST = _env_int("RATE_LIMIT_HTTP_GLOBAL_BURST", 600)

# POST /api/sessions
HTTP_CREATE_SESSION_RPM = _env_int("RATE_LIMIT_CREATE_SESSION_RPM", 3000)
HTTP_CREATE_SESSION_BURST = _env_int("RATE_LIMIT_CREATE_SESSION_BURST", 50)

# Media upload endpoints (cost scales with Content-Length in the middleware).
HTTP_UPLOAD_MEDIA_RPM = _env_int("RATE_LIMIT_UPLOAD_MEDIA_RPM", 12000)
HTTP_UPLOAD_MEDIA_BURST = _env_int("RATE_LIMIT_UPLOAD_MEDIA_BURST", 300)

# One rate-limit token is charged per this many uploaded bytes.
UPLOAD_COST_BYTES = _env_int("RATE_LIMIT_UPLOAD_COST_BYTES", 10 * 1024 * 1024)

# Hard caps on media counts per request / per session.
MAX_UPLOAD_FILES_PER_REQUEST = _env_int("MAX_UPLOAD_FILES_PER_REQUEST", 30)
MAX_MEDIA_PER_SESSION = _env_int("MAX_MEDIA_PER_SESSION", 30)
MAX_PENDING_MEDIA_PER_SESSION = _env_int("MAX_PENDING_MEDIA_PER_SESSION", 30)

# Upload request-count limits (independent of the byte-based cost above).
HTTP_UPLOAD_MEDIA_COUNT_RPM = _env_int("RATE_LIMIT_UPLOAD_MEDIA_COUNT_RPM", 50000)
HTTP_UPLOAD_MEDIA_COUNT_BURST = _env_int("RATE_LIMIT_UPLOAD_MEDIA_COUNT_BURST", 1000)

# GET media file / thumbnail endpoints.
HTTP_MEDIA_GET_RPM = _env_int("RATE_LIMIT_MEDIA_GET_RPM", 2400)
HTTP_MEDIA_GET_BURST = _env_int("RATE_LIMIT_MEDIA_GET_BURST", 60)

# POST .../clear (session reset).
HTTP_CLEAR_RPM = _env_int("RATE_LIMIT_CLEAR_SESSION_RPM", 3000)
HTTP_CLEAR_BURST = _env_int("RATE_LIMIT_CLEAR_SESSION_BURST", 50)

# Fallback limit for any other /api/ route.
HTTP_API_RPM = _env_int("RATE_LIMIT_API_RPM", 2400)
HTTP_API_BURST = _env_int("RATE_LIMIT_API_BURST", 120)

# WebSocket connect / chat-send limits (per IP).
WS_CONNECT_RPM = _env_int("RATE_LIMIT_WS_CONNECT_RPM", 600)
WS_CONNECT_BURST = _env_int("RATE_LIMIT_WS_CONNECT_BURST", 50)

WS_CHAT_SEND_RPM = _env_int("RATE_LIMIT_WS_CHAT_SEND_RPM", 300)
WS_CHAT_SEND_BURST = _env_int("RATE_LIMIT_WS_CHAT_SEND_BURST", 20)

# ---- Process-wide limits: all clients combined ("*_ALL_*") ----
HTTP_ALL_RPM = _env_int("RATE_LIMIT_HTTP_ALL_RPM", 1200)
HTTP_ALL_BURST = _env_int("RATE_LIMIT_HTTP_ALL_BURST", 200)

CREATE_SESSION_ALL_RPM = _env_int("RATE_LIMIT_CREATE_SESSION_ALL_RPM", 120)
CREATE_SESSION_ALL_BURST = _env_int("RATE_LIMIT_CREATE_SESSION_ALL_BURST", 20)

UPLOAD_MEDIA_ALL_RPM = _env_int("RATE_LIMIT_UPLOAD_MEDIA_ALL_RPM", 6000)
UPLOAD_MEDIA_ALL_BURST = _env_int("RATE_LIMIT_UPLOAD_MEDIA_ALL_BURST", 2000)

# Upload-count global limits default to the byte-cost global limits above.
UPLOAD_MEDIA_COUNT_ALL_RPM = _env_int("RATE_LIMIT_UPLOAD_MEDIA_COUNT_ALL_RPM", UPLOAD_MEDIA_ALL_RPM)
UPLOAD_MEDIA_COUNT_ALL_BURST = _env_int("RATE_LIMIT_UPLOAD_MEDIA_COUNT_ALL_BURST", UPLOAD_MEDIA_ALL_BURST)

MEDIA_GET_ALL_RPM = _env_int("RATE_LIMIT_MEDIA_GET_ALL_RPM", 600)
MEDIA_GET_ALL_BURST = _env_int("RATE_LIMIT_MEDIA_GET_ALL_BURST", 120)

WS_CONNECT_ALL_RPM = _env_int("RATE_LIMIT_WS_CONNECT_ALL_RPM", 60000)
WS_CONNECT_ALL_BURST = _env_int("RATE_LIMIT_WS_CONNECT_ALL_BURST", 2000)

WS_CHAT_SEND_ALL_RPM = _env_int("RATE_LIMIT_WS_CHAT_SEND_ALL_RPM", 500)
WS_CHAT_SEND_ALL_BURST = _env_int("RATE_LIMIT_WS_CHAT_SEND_ALL_BURST", 30)

# ---- Concurrency caps, enforced with semaphores ----
WS_MAX_CONNECTIONS = _env_int("RATE_LIMIT_WS_MAX_CONNECTIONS", 500)
CHAT_MAX_CONCURRENCY = _env_int("RATE_LIMIT_CHAT_MAX_CONCURRENCY", 80)
UPLOAD_MAX_CONCURRENCY = _env_int("RATE_LIMIT_UPLOAD_MAX_CONCURRENCY", 100)

WS_CONN_SEM = asyncio.Semaphore(WS_MAX_CONNECTIONS)
CHAT_TURN_SEM = asyncio.Semaphore(CHAT_MAX_CONCURRENCY)
UPLOAD_SEM = asyncio.Semaphore(UPLOAD_MAX_CONCURRENCY)
| |
|
def _global_http_rule_limit(rule_name: str) -> Optional[Tuple[int, int]]:
    """Process-wide (burst, rpm) cap for a named HTTP rule, or None."""
    table = {
        "create_session": (CREATE_SESSION_ALL_BURST, CREATE_SESSION_ALL_RPM),
        "upload_media": (UPLOAD_MEDIA_ALL_BURST, UPLOAD_MEDIA_ALL_RPM),
        "media_get": (MEDIA_GET_ALL_BURST, MEDIA_GET_ALL_RPM),
    }
    return table.get(rule_name)
| |
|
| |
|
def _get_content_length(scope: dict) -> Optional[int]:
    """Parse a non-negative Content-Length from an ASGI scope, else None."""
    try:
        headers = _headers_to_dict(scope.get("headers") or [])
        raw = headers.get("content-length")
        if raw is None:
            return None
        length = int(raw)
    except Exception:
        return None
    return length if length >= 0 else None
| |
|
def _match_http_rule(method: str, path: str) -> Tuple[str, int, int, float]:
    """
    Map (method, path) to a rate-limit rule.

    Returns (rule_name, burst, rpm, cost). cost defaults to 1.0; upload
    endpoints get a dynamic Content-Length-based cost in the middleware.
    An empty rule_name means "no route-specific limit applies".
    """
    method = (method or "").upper()
    path = path or ""

    if method == "POST":
        if path == "/api/sessions":
            return ("create_session", HTTP_CREATE_SESSION_BURST, HTTP_CREATE_SESSION_RPM, 1.0)
        if path.startswith("/api/sessions/"):
            # Whole-file and resumable-init uploads.
            if path.endswith(("/media", "/media/init")):
                return ("upload_media", HTTP_UPLOAD_MEDIA_BURST, HTTP_UPLOAD_MEDIA_RPM, 1.0)
            # Resumable chunk / complete / cancel sub-endpoints.
            if "/media/" in path and path.endswith(("/chunk", "/complete", "/cancel")):
                return ("upload_media", HTTP_UPLOAD_MEDIA_BURST, HTTP_UPLOAD_MEDIA_RPM, 1.0)
            if path.endswith("/clear"):
                return ("clear_session", HTTP_CLEAR_BURST, HTTP_CLEAR_RPM, 1.0)
    elif method == "GET" and path.startswith("/api/sessions/") and path.endswith(("/thumb", "/file")):
        return ("media_get", HTTP_MEDIA_GET_BURST, HTTP_MEDIA_GET_RPM, 1.0)

    # Fallback for every other API route.
    if path.startswith("/api/"):
        return ("api_general", HTTP_API_BURST, HTTP_API_RPM, 1.0)

    # Static assets and everything else: no route-specific rule.
    return ("", 0, 0, 1.0)
| |
|
class HttpRateLimitMiddleware:
    """
    ASGI middleware applying token-bucket rate limits to HTTP requests.
    (WebSocket traffic is passed through and limited elsewhere.)
    """
    def __init__(self, app: Any, limiter: TokenBucketRateLimiter, trust_proxy_headers: bool = False):
        self.app = app
        self.limiter = limiter
        self.trust_proxy_headers = bool(trust_proxy_headers)

    async def __call__(self, scope: dict, receive: Any, send: Any):
        # Non-HTTP scopes (websocket, lifespan) bypass the limiter entirely.
        if scope.get("type") != "http":
            return await self.app(scope, receive, send)

        method = scope.get("method", "GET")
        path = scope.get("path", "/")
        ip = _client_ip_from_http_scope(scope, self.trust_proxy_headers)

        # 1) Process-wide cap across all clients.
        ok, retry_after, _ = await self.limiter.allow(
            key="http:all",
            capacity=float(HTTP_ALL_BURST),
            refill_rate=_rpm_to_rps(float(HTTP_ALL_RPM)),
            cost=1.0,
        )
        if not ok:
            return await self._reject(send, retry_after)

        # 2) Per-IP global cap.
        ok, retry_after, _ = await self.limiter.allow(
            key=f"http:global:{ip}",
            capacity=float(HTTP_GLOBAL_BURST),
            refill_rate=_rpm_to_rps(float(HTTP_GLOBAL_RPM)),
            cost=1.0,
        )
        if not ok:
            return await self._reject(send, retry_after)

        # 3) Route-specific rule, if one matches.
        rule_name, burst, rpm, cost = _match_http_rule(method, path)

        # Uploads pay proportionally to their declared Content-Length.
        if rule_name == "upload_media":
            cl = _get_content_length(scope)
            if cl and cl > 0 and UPLOAD_COST_BYTES > 0:
                cost = max(1.0, float(math.ceil(cl / float(UPLOAD_COST_BYTES))))

        if rule_name:
            # Process-wide cap for this rule (all IPs combined), if configured.
            g = _global_http_rule_limit(rule_name)
            if g:
                g_burst, g_rpm = g
                okg, rag, _ = await self.limiter.allow(
                    key=f"http:{rule_name}:all",
                    capacity=float(g_burst),
                    refill_rate=_rpm_to_rps(float(g_rpm)),
                    cost=float(cost),
                )
                if not okg:
                    return await self._reject(send, rag)

            # Per-IP cap for this rule.
            ok2, retry_after2, _ = await self.limiter.allow(
                key=f"http:{rule_name}:{ip}",
                capacity=float(burst),
                refill_rate=_rpm_to_rps(float(rpm)),
                cost=float(cost),
            )
            if not ok2:
                return await self._reject(send, retry_after2)

        return await self.app(scope, receive, send)

    async def _reject(self, send: Any, retry_after: float):
        # Emit a 429 JSON response with a Retry-After header (rounded up).
        ra = int(math.ceil(float(retry_after or 0.0)))
        body = json.dumps(
            {"detail": "Too Many Requests", "retry_after": ra},
            ensure_ascii=False
        ).encode("utf-8")

        headers = [
            (b"content-type", b"application/json; charset=utf-8"),
            (b"retry-after", str(ra).encode("ascii")),
        ]

        await send({"type": "http.response.start", "status": 429, "headers": headers})
        await send({"type": "http.response.body", "body": body, "more_body": False})
| |
|
# Single process-wide limiter instance shared by the HTTP middleware and
# the WebSocket handlers; all knobs are environment-overridable.
RATE_LIMITER = TokenBucketRateLimiter(
    ttl_sec=_env_int("RATE_LIMIT_TTL_SEC", 900),
    cleanup_interval_sec=_env_int("RATE_LIMIT_CLEANUP_INTERVAL_SEC", 60),
    max_buckets=_env_int("RATE_LIMIT_MAX_BUCKETS", 100000),
    evict_batch=_env_int("RATE_LIMIT_EVICT_BATCH", 2000),
)
| |
|
| |
|
@dataclass
class MediaMeta:
    # Metadata for one stored media file, as exposed to the frontend.
    id: str                    # short random id (10 hex chars)
    name: str                  # display name (original filename)
    kind: str                  # "image" | "video" | "unknown"
    path: str                  # absolute path of the stored file
    thumb_path: Optional[str]  # absolute thumbnail path, or None
    ts: float                  # creation timestamp (time.time())
| |
|
@dataclass
class ResumableUpload:
    # State of one in-flight chunked upload (temp file + chunk bookkeeping).
    upload_id: str       # opaque id handed to the client
    filename: str        # original (display) filename
    store_filename: str  # final on-disk name (media_NNNN.ext)
    size: int            # total expected bytes
    chunk_size: int      # bytes per chunk
    total_chunks: int    # number of chunks expected
    tmp_path: str        # temp file receiving the chunks
    kind: str            # "image" | "video" | "unknown"
    created_ts: float    # creation time
    last_ts: float       # last activity time (drives TTL expiry)
    received: Set[int] = field(default_factory=set)  # chunk indices received so far
    closed: bool = False                             # True once completed/cancelled
    lock: asyncio.Lock = field(default_factory=asyncio.Lock)  # serializes chunk writes
| |
|
class MediaStore:
    """
    Filesystem layer for session media:
    - saves uploaded files (async, chunked copy)
    - generates thumbnails (images: worker thread; videos: async subprocess)
    - deletes files (only ever removes files located under media_dir)
    """
    def __init__(self, media_dir: str):
        self.media_dir = os.path.abspath(media_dir)
        os.makedirs(self.media_dir, exist_ok=True)
        self.thumbs_dir = ensure_thumbs_dir(self.media_dir)

    async def save_upload(self, uf: UploadFile, *, store_filename: str, display_name: str) -> MediaMeta:
        """
        Stream an UploadFile into media_dir/store_filename and build a thumbnail.

        Raises HTTPException(409) when store_filename already exists.
        """
        media_id = uuid.uuid4().hex[:10]

        display_name = sanitize_filename(display_name or uf.filename or "unnamed")
        store_filename = sanitize_filename(store_filename)

        kind = detect_media_kind(display_name)

        save_path = os.path.join(self.media_dir, store_filename)
        os.makedirs(os.path.dirname(save_path), exist_ok=True)

        if os.path.exists(save_path):
            raise HTTPException(status_code=409, detail=f"media filename exists: {store_filename}")

        # Chunked async copy keeps memory bounded for large uploads.
        async with await anyio.open_file(save_path, "wb") as out:
            while True:
                chunk = await uf.read(CHUNK_SIZE)
                if not chunk:
                    break
                await out.write(chunk)

        try:
            await uf.close()
        except Exception:
            pass

        thumb_path: Optional[str] = None
        if kind in ("image", "video"):
            thumb_path = os.path.join(self.thumbs_dir, f"{media_id}.jpg")

            if kind == "image":
                ok = await anyio.to_thread.run_sync(make_image_thumbnail_sync, save_path, thumb_path)
            else:
                ok = await make_video_thumbnail_async(save_path, thumb_path)

            if not ok:
                # Fallback: an image serves itself as the thumbnail; a video
                # with no extractable frame gets no thumbnail at all.
                thumb_path = save_path if kind == "image" else None

        return MediaMeta(
            id=media_id,
            name=os.path.basename(display_name),
            kind=kind,
            path=os.path.abspath(save_path),
            thumb_path=os.path.abspath(thumb_path) if thumb_path else None,
            ts=time.time(),
        )

    async def save_from_path(
        self,
        src_path: str,
        *,
        store_filename: str,
        display_name: str,
    ) -> MediaMeta:
        """
        Move a completed chunk-upload temp file to its final home in media_dir.

        - display_name: name shown in the UI (the original filename)
        - store_filename: on-disk name (media_0001.mp4) recording upload order

        Raises HTTPException(400) when the temp file is missing and
        HTTPException(409) when the target name already exists.
        """
        media_id = uuid.uuid4().hex[:10]

        display_name = sanitize_filename(display_name or "unnamed")
        store_filename = sanitize_filename(store_filename or "unnamed")

        kind = detect_media_kind(display_name)

        src_path = os.path.abspath(src_path)
        if not os.path.exists(src_path):
            raise HTTPException(status_code=400, detail="upload temp file missing")

        save_path = os.path.abspath(os.path.join(self.media_dir, store_filename))
        os.makedirs(os.path.dirname(save_path), exist_ok=True)

        if os.path.exists(save_path):
            raise HTTPException(status_code=409, detail=f"media already exists: {store_filename}")

        # os.replace is atomic when src and dst share a filesystem —
        # presumably the uploads temp dir lives under media_dir; verify.
        os.replace(src_path, save_path)

        thumb_path: Optional[str] = None
        if kind in ("image", "video"):
            thumb_path = os.path.join(self.thumbs_dir, f"{media_id}.jpg")

            if kind == "image":
                ok = await anyio.to_thread.run_sync(make_image_thumbnail_sync, save_path, thumb_path)
            else:
                ok = await make_video_thumbnail_async(save_path, thumb_path)

            if not ok:
                thumb_path = save_path if kind == "image" else None

        return MediaMeta(
            id=media_id,
            name=os.path.basename(display_name),
            kind=kind,
            path=os.path.abspath(save_path),
            thumb_path=os.path.abspath(thumb_path) if thumb_path else None,
            ts=time.time(),
        )

    async def delete_files(self, meta: MediaMeta) -> None:
        """Best-effort removal of the media file and its thumbnail (media_dir only)."""
        root = self.media_dir
        for p in {meta.path, meta.thumb_path}:
            if not p:
                continue
            ap = os.path.abspath(p)
            # Refuse to touch anything outside media_dir (defense in depth).
            if not _is_under_dir(ap, root):
                continue
            if os.path.isdir(ap):
                continue
            if os.path.exists(ap):
                try:
                    os.remove(ap)
                except Exception:
                    pass
| |
|
| |
|
class ChatSession:
    """
    All state for one session:
    - agent / lc_messages (the LangChain context)
    - history (replayed to the frontend)
    - load_media / pending_media (staging area)
    - tool trace index (lets tool events be updated "in place")
    """
    def __init__(self, session_id: str, cfg: Settings):
        self.session_id = session_id
        self.cfg = cfg
        self.lang = "zh"

        # Model keys default from developer config, with hard-coded fallbacks.
        default_llm = _s(getattr(getattr(cfg, "developer", None), "default_llm", "")) or "deepseek-chat"
        default_vlm = _s(getattr(getattr(cfg, "developer", None), "default_vlm", "")) or "qwen3-vl-8b-instruct"

        self.chat_models = [default_llm, CUSTOM_MODEL_KEY]
        self.chat_model_key = default_llm

        self.vlm_models = [default_vlm, CUSTOM_MODEL_KEY]
        self.vlm_model_key = default_vlm

        self.developer_mode = is_developer_mode(cfg)

        # Per-session media storage plus resumable-upload bookkeeping.
        self.media_dir = resolve_media_dir(cfg.project.media_dir, session_id)
        self.media_store = MediaStore(self.media_dir)

        self.uploads_dir = ensure_uploads_dir(self.media_dir)
        self.resumable_uploads: Dict[str, ResumableUpload] = {}

        # Direct (non-resumable) uploads currently in flight; counted against
        # the media caps before the files are actually saved.
        self._direct_upload_reservations = 0

        self.agent: Any = None
        self.node_manager = None
        self.client_context = None

        # chat_lock serializes LLM turns; media_lock guards media/upload state.
        self.chat_lock = asyncio.Lock()
        self.media_lock = asyncio.Lock()

        self.sent_media_total: int = 0
        self._attach_stats_msg_idx = 1  # index of the media-status SystemMessage below

        self.lc_messages: List[BaseMessage] = [
            SystemMessage(content=get_prompt("instruction.system", lang=self.lang)),
            SystemMessage(content="【User media upload status】{}"),
        ]
        self.history: List[Dict[str, Any]] = []

        # All known media by id, plus the ids staged for the next message.
        self.load_media: Dict[str, MediaMeta] = {}
        self.pending_media_ids: List[str] = []

        # tool_call_id -> index into self.history, for in-place tool updates.
        self._tool_history_index: Dict[str, int] = {}

        self.cancel_event = asyncio.Event()

        # Per-session service overrides (custom LLM/VLM, TTS, Pexels).
        self.custom_llm_config: Optional[Dict[str, Any]] = None
        self.custom_vlm_config: Optional[Dict[str, Any]] = None
        self.tts_config: Dict[str, Any] = {}
        self._agent_build_key: Optional[Tuple[Any, ...]] = None

        self.pexels_key_mode: str = "default"
        self.pexels_custom_key: str = ""

        # Lazily initialized monotonically increasing media filename sequence.
        self._media_seq_inited = False
        self._media_seq_next = 1

    def _ensure_system_prompt(self) -> None:
        """Insert the system instruction prompt into lc_messages if absent."""
        sys = (get_prompt("instruction.system", lang=self.lang) or "").strip()
        if not sys:
            return

        # Already present? Compare stripped content against every SystemMessage.
        for m in self.lc_messages:
            if isinstance(m, SystemMessage) and (getattr(m, "content", "") or "").strip() == sys:
                return

        self.lc_messages.insert(0, SystemMessage(content=sys))

    def _init_media_seq_locked(self) -> None:
        """
        Initialize self._media_seq_next:
        - numbering keeps increasing after a chat clear, so old files are never
          overwritten
        """
        if self._media_seq_inited:
            return

        max_seq = 0

        # Files already on disk.
        try:
            for fn in os.listdir(self.media_dir):
                s = parse_media_seq(fn)
                if s is not None:
                    max_seq = max(max_seq, s)
        except Exception:
            pass

        # Media already registered in memory.
        for meta in (self.load_media or {}).values():
            s = parse_media_seq(os.path.basename(meta.path or ""))
            if s is not None:
                max_seq = max(max_seq, s)

        # In-flight resumable uploads that already reserved a store name.
        for u in (self.resumable_uploads or {}).values():
            s = parse_media_seq(getattr(u, "store_filename", "") or "")
            if s is not None:
                max_seq = max(max_seq, s)

        self._media_seq_next = max_seq + 1
        self._media_seq_inited = True

    def _reserve_store_filenames_locked(self, display_filenames: List[str]) -> List[str]:
        """
        Generate a batch of store filenames (media_0001.ext ...) in the given
        order. Note: this order is exactly the upload order being pinned down.
        Caller must hold media_lock.
        """
        self._init_media_seq_locked()

        out: List[str] = []
        seq = int(self._media_seq_next)

        for disp in display_filenames:
            disp = sanitize_filename(disp or "unnamed")
            ext = os.path.splitext(disp)[1].lower()

            # Skip sequence numbers whose file already exists on disk.
            while True:
                store = make_media_store_filename(seq, ext)
                if not os.path.exists(os.path.join(self.media_dir, store)):
                    break
                seq += 1

            out.append(store)
            seq += 1

        self._media_seq_next = seq
        return out

    def apply_service_config(self, service_cfg: Any) -> Tuple[bool, Optional[str]]:
        """Apply a client-supplied service config; returns (ok, error_message)."""
        llm, vlm, tts, pexels, err = _parse_service_config(service_cfg)
        if err:
            return False, err

        if llm is not None:
            self.custom_llm_config = llm
        if vlm is not None:
            self.custom_vlm_config = vlm

        # Only replace the TTS config when a non-empty dict was provided.
        if isinstance(tts, dict) and tts:
            self.tts_config = tts

        # Pexels: "custom" keeps the user key; any other mode resets to default.
        if isinstance(pexels, dict) and pexels:
            mode = _s(pexels.get("mode")).lower()
            if mode == "custom":
                self.pexels_key_mode = "custom"
                self.pexels_custom_key = _s(pexels.get("api_key"))
            else:
                self.pexels_key_mode = "default"
                self.pexels_custom_key = ""

        return True, None

    async def ensure_agent(self) -> None:
        """
        (Re)build the agent when the effective model configuration changed and
        keep the ClientContext in sync with the session settings.

        Raises:
            RuntimeError: when a required custom model config is missing or a
                default model override cannot be resolved.
        """
        # Resolve the LLM override (custom dict or named default).
        if self.chat_model_key == CUSTOM_MODEL_KEY:
            if not isinstance(self.custom_llm_config, dict):
                raise RuntimeError("please fill in model/base_url/api_key of custom LLM")
            llm_override = self.custom_llm_config
        else:
            llm_override, err = _resolve_default_model_override(self.cfg, self.chat_model_key)
            if err:
                raise RuntimeError(err)

        # Resolve the VLM override the same way.
        if self.vlm_model_key == CUSTOM_MODEL_KEY:
            if not isinstance(self.custom_vlm_config, dict):
                raise RuntimeError("please fill in model/base_url/api_key of custom VLM")
            vlm_override = self.custom_vlm_config
        else:
            vlm_override, err = _resolve_default_model_override(self.cfg, self.vlm_model_key)
            if err:
                raise RuntimeError(err)

        # Rebuild only when the (llm, vlm) pair actually changed.
        agent_build_key: Tuple[Any, ...] = (
            "models",
            _stable_dict_key(llm_override),
            _stable_dict_key(vlm_override),
        )

        if self.agent is None or self._agent_build_key != agent_build_key:
            artifact_store = ArtifactStore(self.cfg.project.outputs_dir, session_id=self.session_id)
            self.agent, self.node_manager = await build_agent(
                cfg=self.cfg,
                session_id=self.session_id,
                store=artifact_store,
                tool_interceptors=[
                    ToolInterceptor.inject_media_content_before,
                    ToolInterceptor.save_media_content_after,
                    ToolInterceptor.inject_tts_config,
                    ToolInterceptor.inject_pexels_api_key,
                ],
                llm_override=llm_override,
                vlm_override=vlm_override,
            )
            self._agent_build_key = agent_build_key

        if self.client_context is None:
            self.client_context = ClientContext(
                cfg=self.cfg,
                session_id=self.session_id,
                media_dir=self.media_dir,
                bgm_dir=self.cfg.project.bgm_dir,
                outputs_dir=self.cfg.project.outputs_dir,
                node_manager=self.node_manager,
                chat_model_key=self.chat_model_key,
                vlm_model_key=self.vlm_model_key,
                tts_config=(self.tts_config or None),
                pexels_api_key=None,
                lang=self.lang,
            )
        else:
            # Refresh only the mutable per-turn fields.
            self.client_context.chat_model_key = self.chat_model_key
            self.client_context.vlm_model_key = self.vlm_model_key
            self.client_context.tts_config = (self.tts_config or None)
            self.client_context.lang = self.lang

        # Effective Pexels API key: a custom key beats the configured default.
        pexels_api_key = ""
        if (self.pexels_key_mode or "").lower() == "custom":
            pexels_api_key = _s(self.pexels_custom_key)
        else:
            pexels_api_key = _get_default_pexels_api_key(self.cfg)

        self.client_context.pexels_api_key = (pexels_api_key or None)

    def public_media(self, meta: MediaMeta) -> Dict[str, Any]:
        """Serialize one MediaMeta for the frontend (URLs instead of paths)."""
        return {
            "id": meta.id,
            "name": meta.name,
            "kind": meta.kind,
            "thumb_url": f"/api/sessions/{self.session_id}/media/{meta.id}/thumb",
            "file_url": f"/api/sessions/{self.session_id}/media/{meta.id}/file",
        }

    def public_pending_media(self) -> List[Dict[str, Any]]:
        """Serialize currently staged (pending) media, in staging order."""
        out: List[Dict[str, Any]] = []
        for aid in self.pending_media_ids:
            meta = self.load_media.get(aid)
            if meta:
                out.append(self.public_media(meta))
        return out

    def snapshot(self) -> Dict[str, Any]:
        """Full session state for the frontend (history, limits, model keys)."""
        return {
            "session_id": self.session_id,
            "developer_mode": self.developer_mode,
            "pending_media": self.public_pending_media(),
            "history": self.history,
            "limits": {
                "max_upload_files_per_request": MAX_UPLOAD_FILES_PER_REQUEST,
                "max_media_per_session": MAX_MEDIA_PER_SESSION,
                "max_pending_media_per_session": MAX_PENDING_MEDIA_PER_SESSION,
                "upload_chunk_bytes": UPLOAD_RESUMABLE_CHUNK_BYTES,
            },
            "stats": {
                "media_count": len(self.load_media),
                "pending_count": len(self.pending_media_ids),
                "inflight_uploads": len(self.resumable_uploads),
            },
            "chat_model_key": self.chat_model_key,
            "chat_models": self.chat_models,
            # NOTE(review): llm_* appear to be aliases of chat_* kept for older
            # frontends — confirm before removing.
            "llm_model_key": self.chat_model_key,
            "llm_models": self.chat_models,
            "vlm_model_key": self.vlm_model_key,
            "vlm_models": self.vlm_models,
            "lang": self.lang,
        }

    def _cleanup_stale_uploads_locked(self, now: Optional[float] = None) -> None:
        """Drop resumable uploads idle past the TTL and delete their temp files.
        Caller must hold media_lock."""
        now = float(now or time.time())
        ttl = float(RESUMABLE_UPLOAD_TTL_SEC)
        dead = [uid for uid, u in self.resumable_uploads.items() if (now - u.last_ts) > ttl]
        for uid in dead:
            u = self.resumable_uploads.pop(uid, None)
            if not u:
                continue
            try:
                if u.tmp_path and os.path.exists(u.tmp_path):
                    os.remove(u.tmp_path)
            except Exception:
                pass

    def _check_media_caps_locked(self, add: int = 0) -> None:
        """
        Raise HTTP 400 when adding ``add`` more items would exceed the session
        caps. In-flight uploads and direct-upload reservations count toward
        both totals. Caller must hold media_lock.
        """
        add = int(max(0, add))
        total = len(self.load_media) + len(self.resumable_uploads) + int(self._direct_upload_reservations)
        pending = len(self.pending_media_ids) + len(self.resumable_uploads) + int(self._direct_upload_reservations)

        if MAX_MEDIA_PER_SESSION > 0 and (total + add) > MAX_MEDIA_PER_SESSION:
            raise HTTPException(
                status_code=400,
                detail=f"会话素材总数已达上限:{total}/{MAX_MEDIA_PER_SESSION}",
            )

        if MAX_PENDING_MEDIA_PER_SESSION > 0 and (pending + add) > MAX_PENDING_MEDIA_PER_SESSION:
            raise HTTPException(
                status_code=400,
                detail=f"待发送素材数量已达上限:{pending}/{MAX_PENDING_MEDIA_PER_SESSION}",
            )

    async def add_uploads(self, files: List[UploadFile], store_filenames: List[str]) -> List[MediaMeta]:
        """Persist uploaded files and stage them as pending media."""
        if len(store_filenames) != len(files):
            raise HTTPException(status_code=500, detail="store_filenames mismatch")

        metas: List[MediaMeta] = []
        for uf, store_fn in zip(files, store_filenames):
            display_name = sanitize_filename(uf.filename or "unnamed")
            metas.append(await self.media_store.save_upload(
                uf,
                store_filename=store_fn,
                display_name=display_name,
            ))

        async with self.media_lock:
            for m in metas:
                self.load_media[m.id] = m
                self.pending_media_ids.append(m.id)

            # Keep pending order aligned with the on-disk media_NNNN names.
            self.pending_media_ids.sort(
                key=lambda aid: os.path.basename(self.load_media[aid].path or "")
                if aid in self.load_media else ""
            )

        return metas

    async def delete_pending_media(self, media_id: str) -> None:
        """Physically delete one staged media item; refuses non-pending ids."""
        async with self.media_lock:
            if media_id not in self.pending_media_ids:
                raise HTTPException(status_code=400, detail="media is not pending (refuse physical delete)")
            self.pending_media_ids = [x for x in self.pending_media_ids if x != media_id]
            meta = self.load_media.pop(media_id, None)

        if meta:
            await self.media_store.delete_files(meta)

    async def take_pending_media_for_message(self, attachment_ids: Optional[List[str]]) -> List[MediaMeta]:
        """
        Consume staged media for an outgoing message: with explicit
        attachment_ids only those are taken, otherwise everything pending.
        """
        async with self.media_lock:
            if attachment_ids:
                pick = [aid for aid in attachment_ids if aid in self.pending_media_ids]
            else:
                pick = list(self.pending_media_ids)

            pick_set = set(pick)
            self.pending_media_ids = [aid for aid in self.pending_media_ids if aid not in pick_set]
            metas = [self.load_media[aid] for aid in pick if aid in self.load_media]
            return metas

    def _ensure_tool_record(self, tcid: str, server: str, name: str, args: Any) -> Dict[str, Any]:
        """Return the history record for a tool call, creating it on first sight."""
        idx = self._tool_history_index.get(tcid)
        if idx is None:
            rec = {
                "id": f"tool_{tcid}",
                "role": "tool",
                "tool_call_id": tcid,
                "server": server,
                "name": name,
                "args": args,
                "state": "running",
                "progress": 0.0,
                "message": "",
                "summary": None,
                "ts": time.time(),
            }
            self.history.append(rec)
            self._tool_history_index[tcid] = len(self.history) - 1
            return rec
        return self.history[idx]

    def apply_tool_event(self, raw: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """
        Fold a tool_start/tool_progress/tool_end event into the history record
        of its tool_call_id; returns the updated record, or None when ignored.
        """
        et = raw.get("type")
        tcid = raw.get("tool_call_id")
        if et not in ("tool_start", "tool_progress", "tool_end") or not tcid:
            return None

        server = raw.get("server") or ""
        name = raw.get("name") or ""
        args = raw.get("args") or {}

        rec = self._ensure_tool_record(tcid, server, name, args)

        if et == "tool_start":
            rec.update({
                "server": server,
                "name": name,
                "args": args,
                "state": "running",
                "progress": 0.0,
                "message": "Starting...",
                "summary": None,
            })

        elif et == "tool_progress":
            # Normalize progress into [0, 1]: an explicit total wins, otherwise
            # values above 1 are treated as percentages.
            progress = float(raw.get("progress", 0.0))
            total = raw.get("total")
            if total and float(total) > 0:
                p = progress / float(total)
            else:
                p = progress / 100.0 if progress > 1 else progress
            p = max(0.0, min(1.0, p))
            rec.update({
                "state": "running",
                "progress": p,
                "message": raw.get("message") or "",
            })

        elif et == "tool_end":
            is_error = bool(raw.get("is_error"))

            # Ensure the summary is JSON-serializable before storing it.
            summary = raw.get("summary")
            try:
                json.dumps(summary, ensure_ascii=False)
            except Exception:
                summary = str(summary) if summary is not None else None
            rec.update({
                "state": "error" if is_error else "complete",
                "progress": 1.0,
                "summary": summary,
                "message": raw.get("message") or rec.get("message") or "",
            })

        return rec
| |
|
| |
|
class SessionStore:
    """In-memory registry of ChatSession objects, keyed by session id."""

    def __init__(self, cfg: Settings):
        self.cfg = cfg
        self._lock = asyncio.Lock()
        self._sessions: Dict[str, ChatSession] = {}

    async def create(self) -> ChatSession:
        """Create, register and return a new session with a random id."""
        session = ChatSession(uuid.uuid4().hex, self.cfg)
        async with self._lock:
            self._sessions[session.session_id] = session
        return session

    async def get(self, sid: str) -> Optional[ChatSession]:
        """Look up a session by id; None when unknown."""
        async with self._lock:
            return self._sessions.get(sid)

    async def get_or_404(self, sid: str) -> ChatSession:
        """Look up a session by id; raise HTTP 404 when unknown."""
        found = await self.get(sid)
        if found is None:
            raise HTTPException(status_code=404, detail="session not found")
        return found
| |
|
| |
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load settings once at startup and attach shared state to the app."""
    settings = load_settings(default_config_path())
    state = app.state
    state.cfg = settings
    state.developer_mode = is_developer_mode(settings)
    state.sessions = SessionStore(settings)
    yield
| |
|
| |
|
app = FastAPI(title="OpenStoryline Web", version="1.0.0", lifespan=lifespan)

# Global token-bucket rate limiting applied to every HTTP route.
app.add_middleware(
    HttpRateLimitMiddleware,
    limiter=RATE_LIMITER,
    trust_proxy_headers=RATE_LIMIT_TRUST_PROXY_HEADERS,
)

# Static assets are optional; mount only when the directories exist.
if os.path.isdir(STATIC_DIR):
    app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")

if os.path.isdir(NODE_MAP_DIR):
    app.mount("/node_map", StaticFiles(directory=NODE_MAP_DIR), name="node_map")

# All JSON endpoints below live under /api.
# NOTE(review): presumably included via app.include_router(api) later — confirm.
api = APIRouter(prefix="/api")
| |
|
def _rate_limit_reject_json(retry_after: float) -> JSONResponse:
    """Build a 429 response with a whole-second Retry-After header and body field."""
    seconds = int(math.ceil(float(retry_after or 0.0)))
    payload = {"detail": "Too Many Requests", "retry_after": seconds}
    return JSONResponse(payload, status_code=429, headers={"Retry-After": str(seconds)})
| |
|
async def _enforce_upload_media_count_limit(request: Request, cost: float) -> Optional[JSONResponse]:
    """
    Charge ``cost`` (number of media items) against the global and then the
    per-IP upload-count buckets; return a 429 JSONResponse when either
    rejects, None when both accept.
    """
    ip = _client_ip_from_http_scope(request.scope, RATE_LIMIT_TRUST_PROXY_HEADERS)
    cost = float(max(0.0, cost))

    # Global bucket is checked (and charged) first, then the per-IP bucket.
    buckets = (
        ("http:upload_media_count:all", UPLOAD_MEDIA_COUNT_ALL_BURST, UPLOAD_MEDIA_COUNT_ALL_RPM),
        (f"http:upload_media_count:{ip}", HTTP_UPLOAD_MEDIA_COUNT_BURST, HTTP_UPLOAD_MEDIA_COUNT_RPM),
    )
    for key, burst, rpm in buckets:
        allowed, retry_after, _ = await RATE_LIMITER.allow(
            key=key,
            capacity=float(burst),
            refill_rate=_rpm_to_rps(float(rpm)),
            cost=cost,
        )
        if not allowed:
            return _rate_limit_reject_json(retry_after)

    return None
| |
|
| | _TTS_UI_SECRET_KEYS = { |
| | "api_key", |
| | "access_token", |
| | "authorization", |
| | "token", |
| | "password", |
| | "secret", |
| | "x-api-key", |
| | "apikey", |
| | "access_key", |
| | "accesskey", |
| | } |
| |
|
| | def _is_secret_field_name(k: str) -> bool: |
| | if str(k or "").strip().lower() in _TTS_UI_SECRET_KEYS: |
| | return True |
| | return False |
| |
|
| | def _read_config_toml(path: str) -> dict: |
| | if tomllib is None: |
| | return {} |
| | try: |
| | p = Path(path) |
| | with p.open("rb") as f: |
| | return tomllib.load(f) or {} |
| | except Exception: |
| | return {} |
| |
|
def _get_default_pexels_api_key(cfg: Settings) -> str:
    """Read the configured default Pexels API key; "" when unset or on error."""
    try:
        section = getattr(cfg, "search_media", None)
        key = _s(getattr(section, "pexels_api_key", None) if section else None)
        return key or ""
    except Exception:
        return ""
| |
|
def _normalize_field_item(item) -> dict | None:
    """
    Normalize one field spec.

    Supported forms (as the original docstring promised, though the dict form
    was previously unimplemented and silently dropped):
    - "uid"
    - { key="uid", label="UID", required=true, secret=false, placeholder="..." }

    Returns a dict containing at least "key" and "secret", or None when the
    item is empty/invalid.
    """
    if isinstance(item, str):
        key = item.strip()
        if not key:
            return None
        return {
            "key": key,
            "secret": _is_secret_field_name(key),
        }
    if isinstance(item, dict):
        key = str(item.get("key") or "").strip()
        if not key:
            return None
        out = {
            "key": key,
            # An explicit secret flag wins over the name-based heuristic.
            "secret": bool(item.get("secret", _is_secret_field_name(key))),
        }
        for extra in ("label", "placeholder"):
            if item.get(extra):
                out[extra] = item[extra]
        if "required" in item:
            out["required"] = bool(item.get("required"))
        return out
    return None
| |
|
| | def _build_provider_schema(provider: str, label: str | None, fields: list[dict]) -> dict: |
| | seen = set() |
| | out = [] |
| | for f in fields: |
| | k = str(f.get("key") or "").strip() |
| | if not k or k in seen: |
| | continue |
| | seen.add(k) |
| | out.append({ |
| | "key": k, |
| | "label": f.get("label") or k, |
| | "placeholder": f.get("placeholder") or f.get("label") or k, |
| | "required": bool(f.get("required", False)), |
| | "secret": bool(f.get("secret", False)), |
| | }) |
| | return {"provider": provider, "label": label or provider, "fields": out} |
| |
|
def _build_tts_ui_schema_from_config(config_path: str) -> dict:
    """
    Build the TTS provider form schema for the frontend from config.toml.

    Returns:
        {
          "providers": [
            {"provider":"bytedance","label":"字节跳动","fields":[{"key":"uid",...}, ...]},
            ...
          ]
        }

    Malformed config sections (non-table values) are skipped instead of
    raising — previously a scalar where a table was expected crashed the
    endpoint with an AttributeError (HTTP 500).
    """
    cfg = _read_config_toml(config_path)
    tts = cfg.get("generate_voiceover", {})

    providers_out: list[dict] = []

    providers = tts.get("providers") if isinstance(tts, dict) else None
    if isinstance(providers, dict):
        for provider, provider_cfg in providers.items():
            # A provider entry must itself be a table of fields.
            if not isinstance(provider_cfg, dict):
                continue
            label = str(provider_cfg.get("label") or provider_cfg.get("name") or provider)
            fields: list[dict] = []
            # Every key of the provider table becomes a form field.
            for key in provider_cfg.keys():
                f = _normalize_field_item(str(key))
                if f:
                    fields.append(f)

            providers_out.append(_build_provider_schema(provider, label, fields))

    return {"providers": providers_out}
| |
|
@app.get("/")
async def index():
    """Serve the single-page app; 404 with a hint when index.html is absent."""
    if os.path.exists(INDEX_HTML):
        return FileResponse(INDEX_HTML, media_type="text/html")
    return Response("index.html not found. Put it under ./web/index.html", media_type="text/plain", status_code=404)
| |
|
@app.get("/node-map")
async def node_map():
    """Serve the node-map page; 404 with a hint when the file is absent."""
    if os.path.exists(NODE_MAP_HTML):
        return FileResponse(NODE_MAP_HTML, media_type="text/html")
    return Response(
        "node_map.html not found. Put it under ./web/node_map/node_map.html",
        media_type="text/plain",
        status_code=404,
    )
| |
|
@api.get("/meta/tts")
async def get_tts_ui_schema():
    """Expose the TTS provider form schema derived from the config file."""
    return JSONResponse(_build_tts_ui_schema_from_config(default_config_path()))
| |
|
| | |
| | |
| | |
@api.post("/sessions")
async def create_session():
    """Create a fresh chat session and return its snapshot."""
    sessions: SessionStore = app.state.sessions
    new_sess = await sessions.create()
    return JSONResponse(new_sess.snapshot())
| |
|
| |
|
@api.get("/sessions/{session_id}")
async def get_session(session_id: str):
    """Return the snapshot of an existing session (404 when unknown)."""
    sessions: SessionStore = app.state.sessions
    sess = await sessions.get_or_404(session_id)
    return JSONResponse(sess.snapshot())
| |
|
| |
|
@api.post("/sessions/{session_id}/clear")
async def clear_session_chat(session_id: str):
    """
    Reset the chat context of a session: system prompts, LangChain messages,
    history and the tool-event index. Media files and staging are untouched.

    Fix: the original assigned ``sess._attach_stats_msg_idx = 1`` twice
    (before and after rebuilding lc_messages); the duplicate is removed.
    """
    store: SessionStore = app.state.sessions
    sess = await store.get_or_404(session_id)
    async with sess.chat_lock:
        sess.sent_media_total = 0
        sess.lc_messages = [
            SystemMessage(content=get_prompt("instruction.system", lang=sess.lang)),
            SystemMessage(content="【User media upload status】{}"),
        ]
        # The media-status message sits right after the system prompt.
        sess._attach_stats_msg_idx = 1

        sess.history = []
        sess._tool_history_index = {}
    return JSONResponse({"ok": True})
| |
|
@api.post("/sessions/{session_id}/cancel")
async def cancel_session_turn(session_id: str):
    """
    Interrupt the LLM turn currently in progress (streaming reply / tool calls).
    - history / lc_messages are left untouched
    - only sets cancel_event; the WS side notices it in the streaming loop
      and winds down safely
    """
    sessions: SessionStore = app.state.sessions
    target = await sessions.get_or_404(session_id)
    target.cancel_event.set()
    return JSONResponse({"ok": True})
| |
|
| | |
| | |
| | |
@api.post("/sessions/{session_id}/media")
async def upload_media(session_id: str, request: Request, files: List[UploadFile] = File(...)):
    """
    Direct (non-resumable) multi-file upload.

    Enforces per-request file count, media-count rate limits and a bounded
    upload concurrency, then stores the files and stages them as pending.

    Fix: the capacity reservation (``_direct_upload_reservations += n``) was
    taken before the try/finally that releases it, so an exception raised by
    ``_reserve_store_filenames_locked`` leaked the reservation and permanently
    shrank the session's media capacity. The release now covers everything
    after the reservation.
    """
    if not isinstance(files, list) or not files:
        raise HTTPException(status_code=400, detail="no files")

    if MAX_UPLOAD_FILES_PER_REQUEST > 0 and len(files) > MAX_UPLOAD_FILES_PER_REQUEST:
        raise HTTPException(status_code=400, detail=f"单次上传最多 {MAX_UPLOAD_FILES_PER_REQUEST} 个文件")

    # Rate-limit by number of media items, not requests.
    rej = await _enforce_upload_media_count_limit(request, cost=float(len(files)))
    if rej:
        return rej

    # Fast-reject when the upload concurrency budget is exhausted.
    if UPLOAD_SEM.locked():
        raise HTTPException(status_code=429, detail="上传并发过高,请稍后重试")
    await UPLOAD_SEM.acquire()

    n = len(files)
    try:
        store: SessionStore = app.state.sessions
        sess = await store.get_or_404(session_id)

        reserved = False
        try:
            async with sess.media_lock:
                sess._cleanup_stale_uploads_locked()
                sess._check_media_caps_locked(add=n)
                sess._direct_upload_reservations += n
                reserved = True

                display_names = [sanitize_filename(uf.filename or "unnamed") for uf in files]
                store_filenames = sess._reserve_store_filenames_locked(display_names)

            metas = await sess.add_uploads(files, store_filenames=store_filenames)
        finally:
            # Always release the capacity reservation once taken.
            if reserved:
                async with sess.media_lock:
                    sess._direct_upload_reservations = max(0, sess._direct_upload_reservations - n)

        return JSONResponse({
            "media": [sess.public_media(m) for m in metas],
            "pending_media": sess.public_pending_media(),
        })
    finally:
        try:
            UPLOAD_SEM.release()
        except Exception:
            pass
| |
|
@api.post("/sessions/{session_id}/media/init")
async def init_resumable_media_upload(session_id: str, request: Request):
    """
    Start a resumable (chunked) upload: validate the request, reserve a store
    filename, pre-create an empty temp file and return the upload plan
    (upload_id, chunk_size, total_chunks).
    """
    # Tolerate missing or invalid JSON bodies.
    try:
        data = await request.json()
        if not isinstance(data, dict):
            data = {}
    except Exception:
        data = {}

    filename = sanitize_filename((data.get("filename") or data.get("name") or "unnamed"))
    size = int(data.get("size") or 0)
    if size <= 0:
        raise HTTPException(status_code=400, detail="invalid size")

    # One resumable upload counts as one media item for rate limiting.
    rej = await _enforce_upload_media_count_limit(request, cost=1.0)
    if rej:
        return rej

    store: SessionStore = app.state.sessions
    sess = await store.get_or_404(session_id)

    async with sess.media_lock:
        sess._cleanup_stale_uploads_locked()
        sess._check_media_caps_locked(add=1)

        # Reserving the store name here pins the upload order.
        store_filename = sess._reserve_store_filenames_locked([filename])[0]

        upload_id = uuid.uuid4().hex
        chunk_size = int(max(1, UPLOAD_RESUMABLE_CHUNK_BYTES))
        total_chunks = int(math.ceil(size / float(chunk_size)))

        # Pre-create the temp file so chunk writes can open it with "r+b".
        tmp_path = os.path.join(sess.uploads_dir, f"{upload_id}.part")
        os.makedirs(os.path.dirname(tmp_path), exist_ok=True)
        try:
            with open(tmp_path, "wb"):
                pass
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"cannot create temp file: {e}")

        u = ResumableUpload(
            upload_id=upload_id,
            filename=filename,
            store_filename=store_filename,
            size=size,
            chunk_size=chunk_size,
            total_chunks=total_chunks,
            tmp_path=os.path.abspath(tmp_path),
            kind=detect_media_kind(filename),
            created_ts=time.time(),
            last_ts=time.time(),
        )
        sess.resumable_uploads[upload_id] = u

    return JSONResponse({
        "upload_id": upload_id,
        "chunk_size": chunk_size,
        "total_chunks": total_chunks,
        "filename": filename,
    })
| |
|
| |
|
@api.post("/sessions/{session_id}/media/{upload_id}/chunk")
async def upload_resumable_media_chunk(
    session_id: str,
    upload_id: str,
    index: int = Form(...),
    chunk: UploadFile = File(...),
):
    """
    Receive one chunk of a resumable upload and write it at its offset inside
    the pre-created temp file. Chunks may arrive in any order and be retried.
    """
    # Fast-reject when the upload concurrency budget is exhausted.
    if UPLOAD_SEM.locked():
        raise HTTPException(status_code=429, detail="上传并发过高,请稍后重试")
    await UPLOAD_SEM.acquire()
    try:
        store: SessionStore = app.state.sessions
        sess = await store.get_or_404(session_id)

        async with sess.media_lock:
            sess._cleanup_stale_uploads_locked()
            u = sess.resumable_uploads.get(upload_id)

        if not u:
            raise HTTPException(status_code=404, detail="upload_id not found or expired")

        idx = int(index)
        if idx < 0 or idx >= u.total_chunks:
            raise HTTPException(status_code=400, detail="invalid chunk index")

        # The final chunk may be shorter than chunk_size.
        expected_len = u.size - idx * u.chunk_size
        if expected_len <= 0:
            raise HTTPException(status_code=400, detail="invalid chunk index")
        expected_len = min(u.chunk_size, expected_len)

        written = 0
        # Per-upload lock: serialize writes into the shared temp file.
        async with u.lock:
            if u.closed:
                raise HTTPException(status_code=400, detail="upload already closed")

            async with await anyio.open_file(u.tmp_path, "r+b") as out:
                await out.seek(idx * u.chunk_size)
                while True:
                    buf = await chunk.read(CHUNK_SIZE)
                    if not buf:
                        break
                    written += len(buf)
                    # Reject oversized chunks before writing past the slot.
                    if written > expected_len:
                        raise HTTPException(status_code=400, detail="chunk too large")
                    await out.write(buf)

        try:
            await chunk.close()
        except Exception:
            pass

        if written != expected_len:
            raise HTTPException(status_code=400, detail=f"chunk size mismatch: {written} != {expected_len}")

        # Mark the chunk received and refresh the TTL timestamp.
        u.received.add(idx)
        u.last_ts = time.time()

        return JSONResponse({
            "ok": True,
            "received_chunks": len(u.received),
            "total_chunks": u.total_chunks,
        })
    finally:
        try:
            UPLOAD_SEM.release()
        except Exception:
            pass
| |
|
| |
|
@api.post("/sessions/{session_id}/media/{upload_id}/complete")
async def complete_resumable_media_upload(session_id: str, upload_id: str):
    """
    Finalize a resumable upload: verify all chunks arrived, move the temp
    file into the media dir and stage the new item as pending media.

    Fix: the original set ``u.closed = True`` BEFORE checking for missing
    chunks, so a premature complete call bricked the upload — every further
    chunk POST was rejected with "upload already closed" until the TTL
    expired. The upload is now only closed once it is verified complete.
    """
    if UPLOAD_SEM.locked():
        raise HTTPException(status_code=429, detail="上传并发过高,请稍后重试")
    await UPLOAD_SEM.acquire()
    try:
        store: SessionStore = app.state.sessions
        sess = await store.get_or_404(session_id)

        async with sess.media_lock:
            sess._cleanup_stale_uploads_locked()
            u = sess.resumable_uploads.get(upload_id)

        if not u:
            raise HTTPException(status_code=404, detail="upload_id not found or expired")

        async with u.lock:
            if len(u.received) != u.total_chunks:
                # Leave the upload open so the client can send the rest.
                missing = u.total_chunks - len(u.received)
                raise HTTPException(status_code=400, detail=f"chunks missing: {missing}")
            u.closed = True

        async with sess.media_lock:
            u2 = sess.resumable_uploads.pop(upload_id, None)

        if not u2:
            raise HTTPException(status_code=404, detail="upload_id not found")

        meta = await sess.media_store.save_from_path(
            u2.tmp_path,
            store_filename=u2.store_filename,
            display_name=u2.filename,
        )

        async with sess.media_lock:
            sess.load_media[meta.id] = meta
            sess.pending_media_ids.append(meta.id)

        return JSONResponse({
            "media": sess.public_media(meta),
            "pending_media": sess.public_pending_media(),
        })
    finally:
        try:
            UPLOAD_SEM.release()
        except Exception:
            pass
| |
|
| |
|
@api.post("/sessions/{session_id}/media/{upload_id}/cancel")
async def cancel_resumable_media_upload(session_id: str, upload_id: str):
    """Abort a resumable upload and remove its temp file; idempotent."""
    sessions: SessionStore = app.state.sessions
    sess = await sessions.get_or_404(session_id)

    async with sess.media_lock:
        upload = sess.resumable_uploads.pop(upload_id, None)

    if upload is None:
        return JSONResponse({"ok": True})

    async with upload.lock:
        upload.closed = True
        # Best-effort temp-file removal.
        try:
            if upload.tmp_path and os.path.exists(upload.tmp_path):
                os.remove(upload.tmp_path)
        except Exception:
            pass

    return JSONResponse({"ok": True})
| |
|
@api.get("/sessions/{session_id}/media/pending")
async def get_pending_media(session_id: str):
    """List media currently staged for the next message."""
    sessions: SessionStore = app.state.sessions
    sess = await sessions.get_or_404(session_id)
    return JSONResponse({"pending_media": sess.public_pending_media()})
| |
|
| |
|
@api.delete("/sessions/{session_id}/media/pending/{media_id}")
async def delete_pending_media(session_id: str, media_id: str):
    """Remove one staged media item (physical delete) and return the rest."""
    sessions: SessionStore = app.state.sessions
    sess = await sessions.get_or_404(session_id)
    await sess.delete_pending_media(media_id)
    return JSONResponse({"ok": True, "pending_media": sess.public_pending_media()})
| |
|
| |
|
@api.get("/sessions/{session_id}/media/{media_id}/thumb")
async def get_media_thumb(session_id: str, media_id: str):
    """Serve a thumbnail for a media item, with graceful fallbacks.

    Order of preference: pre-rendered JPEG thumb -> SVG placeholder for
    videos -> the original file itself -> 404.
    """
    sessions: SessionStore = app.state.sessions
    sess = await sessions.get_or_404(session_id)

    meta = sess.load_media.get(media_id)
    if not meta:
        raise HTTPException(status_code=404, detail="media not found")

    # Preferred: a pre-rendered thumbnail (always stored as JPEG).
    thumb = meta.thumb_path
    if thumb and os.path.exists(thumb):
        return FileResponse(thumb, media_type="image/jpeg")

    # Videos without a thumbnail get a static placeholder image.
    if meta.kind == "video":
        return Response(content=video_placeholder_svg_bytes(), media_type="image/svg+xml")

    # Last resort: serve the original file directly.
    original = meta.path
    if original and os.path.exists(original):
        return FileResponse(original, media_type=guess_media_type(original))

    raise HTTPException(status_code=404, detail="thumb not available")
|
@api.get("/sessions/{session_id}/media/{media_id}/file")
async def get_media_file(session_id: str, media_id: str):
    """Download the original bytes of an uploaded media item."""
    sessions: SessionStore = app.state.sessions
    sess = await sessions.get_or_404(session_id)

    meta = sess.load_media.get(media_id)
    if not meta:
        raise HTTPException(status_code=404, detail="media not found")

    path = meta.path
    if not path or (not os.path.exists(path)):
        raise HTTPException(status_code=404, detail="file not found")

    # Refuse to serve anything outside the session's media directory
    # (defends against metadata pointing at arbitrary server paths).
    if not _is_under_dir(path, sess.media_store.media_dir):
        raise HTTPException(status_code=403, detail="forbidden")

    return FileResponse(
        path,
        media_type=guess_media_type(path),
        filename=meta.name,
    )
|
@api.get("/sessions/{session_id}/preview")
async def preview_local_file(session_id: str, path: str):
    """
    Safely turn a "server-local path" (as found in ``summary.preview_urls``)
    into a servable response.

    Only files under these roots are allowed:
    media_dir / outputs_dir / bgm_dir / .server_cache.

    Raises:
        HTTPException 400: empty path or path containing a NUL byte.
        HTTPException 403: path resolves outside every allowed root.
        HTTPException 404: path does not exist or is a directory.
    """
    store: SessionStore = app.state.sessions
    sess = await store.get_or_404(session_id)

    p = (path or "").strip()
    if not p:
        raise HTTPException(status_code=400, detail="empty path")
    if "\x00" in p:
        raise HTTPException(status_code=400, detail="bad path")

    # Accept file:// URLs by stripping the scheme prefix.
    if p.startswith("file://"):
        p = p[len("file://"):]

    # Relative paths are resolved against the project root.
    if os.path.isabs(p):
        ap = os.path.abspath(p)
    else:
        ap = os.path.abspath(os.path.join(ROOT_DIR, p))

    # FIX: outputs_dir was previously listed twice here (copy-paste
    # duplicate); the whitelist is otherwise unchanged.
    allowed_roots = [
        os.path.abspath(sess.media_dir),
        os.path.abspath(app.state.cfg.project.outputs_dir),
        os.path.abspath(app.state.cfg.project.bgm_dir),
        os.path.abspath(SERVER_CACHE_DIR),
    ]

    if not any(_is_under_dir(ap, r) for r in allowed_roots):
        raise HTTPException(status_code=403, detail="forbidden")

    if (not os.path.exists(ap)) or os.path.isdir(ap):
        raise HTTPException(status_code=404, detail="file not found")

    # Files under the server cache dir get long-lived immutable caching;
    # presumably their names never get reused with different content —
    # TODO(review): confirm cache files are content-addressed.
    headers = {"Cache-Control": "public, max-age=31536000, immutable"} if _is_under_dir(ap, SERVER_CACHE_DIR) else None

    return FileResponse(
        ap,
        media_type=guess_media_type(ap),
        filename=os.path.basename(ap),
        headers=headers,
    )
|
| | app.include_router(api) |
| |
|
| |
|
| | |
| | |
| | |
def extract_text_delta(msg_chunk: Any) -> str:
    """Return the streamed text delta carried by an LLM message chunk.

    Prefers the structured ``content_blocks`` form, concatenating every
    ``type == "text"`` block; otherwise falls back to a plain-string
    ``content`` attribute. Anything non-textual yields ``""``.
    """
    blocks = getattr(msg_chunk, "content_blocks", None) or []
    if blocks:
        return "".join(
            b.get("text", "")
            for b in blocks
            if isinstance(b, dict) and b.get("type") == "text"
        )
    content = getattr(msg_chunk, "content", "")
    return content if isinstance(content, str) else ""
| |
|
| |
|
async def ws_send(ws: WebSocket, type_: str, data: Any = None):
    """Best-effort JSON send over *ws*.

    Returns True on success, False when the socket is not connected or the
    send fails for a disconnect-like reason. Only genuinely unexpected
    failures are logged; the caller decides whether False is fatal.
    """
    if getattr(ws, "client_state", None) != WebSocketState.CONNECTED:
        return False
    try:
        await ws.send_json({"type": type_, "data": data})
    except WebSocketDisconnect:
        return False
    except RuntimeError:
        # starlette raises RuntimeError when sending after a close frame.
        return False
    except Exception as e:
        if ClientDisconnected is not None and isinstance(e, ClientDisconnected):
            return False
        logger.exception("ws_send failed: type=%s err=%r", type_, e)
        return False
    return True
|
@asynccontextmanager
async def mcp_sink_context(sink_func):
    """Install *sink_func* as the MCP log sink for the duration of the
    ``async with`` block, restoring the previous sink afterwards."""
    token = set_mcp_log_sink(sink_func)
    try:
        yield
    finally:
        # Always restore, even if the body raises or is cancelled.
        reset_mcp_log_sink(token)
|
| |
|
@app.websocket("/ws/sessions/{session_id}/chat")
async def ws_chat(ws: WebSocket, session_id: str):
    """
    Main chat websocket endpoint.

    Client -> server message types: "ping", "session.set_lang",
    "chat.clear", "chat.send". Server -> client events include
    "session.snapshot", "chat.user", "assistant.start/delta/flush/end",
    "tool.start/progress/end", "chat.cleared", "pong", and "error".

    Admission control, in order: per-IP connect rate limit, global
    websocket-connection semaphore (WS_CONN_SEM), then per-send global and
    per-IP rate limits, and a model-concurrency semaphore (CHAT_TURN_SEM)
    held for the whole agent turn. sess.chat_lock serializes turns within
    one session.
    """
    client_ip = _client_ip_from_ws(ws, RATE_LIMIT_TRUST_PROXY_HEADERS)

    # Connection-level token-bucket rate limit keyed by client IP.
    ok, retry_after, _ = await RATE_LIMITER.allow(
        key=f"ws:connect:{client_ip}",
        capacity=float(WS_CONNECT_BURST),
        refill_rate=_rpm_to_rps(float(WS_CONNECT_RPM)),
        cost=1.0,
    )
    if not ok:
        try:
            # 1013 = "Try Again Later".
            await ws.close(code=1013, reason=f"rate_limited, retry after {int(math.ceil(retry_after))}s")
        except Exception:
            debug_traceback_print(app.state.cfg)
            pass
        return

    # Reject immediately when the connection cap is reached rather than
    # queueing behind the semaphore.
    if WS_CONN_SEM.locked():
        try:
            await ws.close(code=1013, reason="Server busy (websocket connections limit)")
        except Exception:
            debug_traceback_print(app.state.cfg)
            pass
        return

    await WS_CONN_SEM.acquire()

    try:
        await ws.accept()

        store: SessionStore = app.state.sessions
        sess = await store.get(session_id)
        if not sess:
            # 4404: application-defined "not found" close code.
            await ws.close(code=4404, reason="session not found")
            return
        # NOTE(review): redundant second lookup — `sess` was already
        # fetched and validated just above.
        sess = await store.get_or_404(session_id)

        # Give the client the full current session state on connect.
        await ws_send(ws, "session.snapshot", sess.snapshot())

        try:
            while True:
                req = await ws.receive_json()
                if not isinstance(req, dict):
                    continue

                t = req.get("type")
                if t == "ping":
                    await ws_send(ws, "pong", {"ts": time.time()})
                    continue

                if t == "session.set_lang":
                    data = (req.get("data") or {})
                    lang = (data.get("lang") or "").strip().lower()
                    # Only zh/en are supported; default to zh.
                    if lang not in ("zh", "en"):
                        lang = "zh"

                    sess.lang = lang
                    if sess.client_context:
                        sess.client_context.lang = lang

                    await ws_send(ws, "session.lang", {"lang": lang})
                    continue

                if t == "chat.clear":
                    # Reset conversation state under the chat lock so a
                    # running turn cannot interleave with the reset.
                    async with sess.chat_lock:
                        sess.sent_media_total = 0
                        sess._attach_stats_msg_idx = 1
                        # lc_messages[0] = system prompt,
                        # lc_messages[1] = media-upload-stats slot (rewritten each turn).
                        sess.lc_messages = [
                            SystemMessage(content=get_prompt("instruction.system", lang=sess.lang)),
                            SystemMessage(content="【User media upload status】{}"),
                        ]
                        # NOTE(review): duplicate assignment — _attach_stats_msg_idx
                        # was already set to 1 a few lines above.
                        sess._attach_stats_msg_idx = 1
                        sess.history = []
                        sess._tool_history_index = {}
                    await ws_send(ws, "chat.cleared", {"ok": True})
                    continue

                if t != "chat.send":
                    await ws_send(ws, "error", {"message": f"unknown type: {t}"})
                    continue

                # Fast-path rejection if the previous turn is still running.
                if sess.chat_lock.locked():
                    await ws_send(ws, "error", {"message": "上一条消息尚未完成,请稍后再发送"})
                    continue

                # Global (all clients) chat-send rate limit.
                ok, retry_after, _ = await RATE_LIMITER.allow(
                    key="ws:chat_send:all",
                    capacity=float(WS_CHAT_SEND_ALL_BURST),
                    refill_rate=_rpm_to_rps(float(WS_CHAT_SEND_ALL_RPM)),
                    cost=1.0,
                )
                if not ok:
                    await ws_send(ws, "error", {
                        "message": f"触发全局限流:请 {int(math.ceil(retry_after))} 秒后再试",
                        "retry_after": int(math.ceil(retry_after)),
                    })
                    continue

                # Per-IP chat-send rate limit.
                ok, retry_after, _ = await RATE_LIMITER.allow(
                    key=f"ws:chat_send:{client_ip}",
                    capacity=float(WS_CHAT_SEND_BURST),
                    refill_rate=_rpm_to_rps(float(WS_CHAT_SEND_RPM)),
                    cost=1.0,
                )
                if not ok:
                    await ws_send(ws, "error", {
                        "message": f"触发限流:请 {int(math.ceil(retry_after))} 秒后再试",
                        "retry_after": int(math.ceil(retry_after)),
                    })
                    continue

                # Model concurrency cap: reject rather than queue.
                if CHAT_TURN_SEM.locked():
                    await ws_send(ws, "error", {"message": "服务器繁忙(模型并发已满),请稍后再试"})
                    continue

                await CHAT_TURN_SEM.acquire()
                try:
                    # Re-check after acquiring the semaphore (the earlier
                    # check was lock-free and could have raced).
                    if sess.chat_lock.locked():
                        await ws_send(ws, "error", {"message": "上一条消息尚未完成,请稍后再发送"})
                        continue

                    data = (req.get("data", {}) or {})

                    prompt = data.get("text", "")
                    prompt = (prompt or "").strip()
                    if not prompt:
                        # Empty prompt: silently ignore (semaphore released in finally).
                        continue

                    requested_llm = data.get("llm_model")
                    requested_vlm = data.get("vlm_model")

                    # Optional explicit list of pending-media ids to attach.
                    attachment_ids = data.get("attachment_ids")
                    if not isinstance(attachment_ids, list):
                        attachment_ids = None

                    async with sess.chat_lock:
                        # Fresh turn: clear any stale cancel signal.
                        sess.cancel_event.clear()

                        ok_cfg, err_cfg = sess.apply_service_config(data.get("service_config"))
                        if not ok_cfg:
                            await ws_send(ws, "error", {"message": err_cfg or "service_config invalid"})
                            continue

                        # Per-turn model overrides (persisted on the session).
                        if isinstance(requested_llm, str):
                            m = requested_llm.strip()
                            if m:
                                sess.chat_model_key = m
                                if sess.client_context:
                                    sess.client_context.chat_model_key = m

                        if isinstance(requested_vlm, str):
                            m2 = requested_vlm.strip()
                            if m2:
                                sess.vlm_model_key = m2
                                if sess.client_context:
                                    sess.client_context.vlm_model_key = m2

                        requested_lang = data.get("lang")
                        if isinstance(requested_lang, str):
                            lang = requested_lang.strip().lower()
                            if lang in ("zh", "en"):
                                sess.lang = lang

                        try:
                            await sess.ensure_agent()
                        except Exception as e:
                            await ws_send(ws, "error", {"message": f"{type(e).__name__}: {e}"})
                            continue

                        sess._ensure_system_prompt()

                        if sess.client_context:
                            sess.client_context.lang = sess.lang

                        # Move pending media onto this message.
                        attachments = await sess.take_pending_media_for_message(attachment_ids)
                        attachments_public = [sess.public_media(m) for m in attachments]

                        turn_attached_count = len(attachments)
                        sess.sent_media_total = int(getattr(sess, "sent_media_total", 0)) + turn_attached_count

                        # Media-usage stats injected into the model context
                        # (keys are read by the prompt, do not rename).
                        stats = {
                            "Number of media carried in this message sent by the user": turn_attached_count,
                            "Total number of media sent by the user in all conversations": sess.sent_media_total,
                            "Total number of media in user's media library": scan_media_dir(resolve_media_dir(app.state.cfg.project.media_dir, session_id=session_id)),
                        }

                        # Ensure the stats slot exists, padding with empty
                        # system messages if the list is shorter than expected.
                        idx = int(getattr(sess, "_attach_stats_msg_idx", 1))
                        if len(sess.lc_messages) <= idx:
                            while len(sess.lc_messages) <= idx:
                                sess.lc_messages.append(SystemMessage(content=""))

                        # Overwrite (not append) the stats slot each turn.
                        sess.lc_messages[idx] = SystemMessage(
                            content="【User media upload status】The following fields are used to determine the nature of the media provided by the user: \n"
                            + json.dumps(stats, ensure_ascii=False)
                        )

                        # Record the user turn in replayable history and in
                        # the LangChain message list.
                        user_msg = {
                            "id": uuid.uuid4().hex[:12],
                            "role": "user",
                            "content": prompt,
                            "attachments": attachments_public,
                            "ts": time.time(),
                        }
                        sess.history.append(user_msg)
                        sess.lc_messages.append(HumanMessage(content=prompt))

                        # Echo the user message back so the UI can render it.
                        await ws_send(ws, "chat.user", {
                            "text": prompt,
                            "attachments": attachments_public,
                            "pending_media": sess.public_pending_media(),
                            "llm_model_key": sess.chat_model_key,
                            "vlm_model_key": sess.vlm_model_key,
                        })

                        # Event queue bridging the agent stream / MCP sink
                        # (possibly called from worker threads) to this task.
                        loop = asyncio.get_running_loop()
                        out_q: asyncio.Queue[Tuple[str, Any]] = asyncio.Queue()

                        def sink(ev: Any):
                            # May be invoked off-loop; hop onto the event loop
                            # thread before touching the asyncio queue.
                            if isinstance(ev, dict):
                                loop.call_soon_threadsafe(out_q.put_nowait, ("mcp", ev))

                        # Messages produced by the agent this turn ("updates"
                        # stream); committed to lc_messages when the turn ends.
                        new_messages: List[BaseMessage] = []

                        async def pump_agent():
                            # Drive the agent stream and translate it into
                            # out_q events consumed by the dispatch loop below.
                            nonlocal new_messages
                            try:
                                stream = sess.agent.astream(
                                    {"messages": sess.lc_messages},
                                    context=sess.client_context,
                                    stream_mode=["messages", "updates"],
                                )
                                async for mode, chunk in stream:
                                    if mode == "messages":
                                        msg_chunk, meta = chunk
                                        # Only the model node's tokens are
                                        # user-visible deltas.
                                        if meta.get("langgraph_node") == "model":
                                            delta = extract_text_delta(msg_chunk)
                                            if delta:
                                                await out_q.put(("assistant.delta", delta))

                                    elif mode == "updates":
                                        if isinstance(chunk, dict):
                                            for _step, data in chunk.items():
                                                msgs = (data or {}).get("messages") or []
                                                new_messages.extend(msgs)

                                await out_q.put(("agent.done", None))
                            except asyncio.CancelledError:
                                # Cancelled by the user: notify the dispatch
                                # loop, then re-raise to finish cancellation.
                                try:
                                    out_q.put_nowait(("agent.cancelled", None))
                                except Exception:
                                    debug_traceback_print(app.state.cfg)
                                    pass
                                raise

                            except Exception as e:
                                await out_q.put(("agent.error", f"{type(e).__name__}: {e}"))

                        # NOTE(review): safe_send is defined but never called
                        # below — all sends go through ws_send directly.
                        async def safe_send(type_: str, data: Any = None) -> bool:
                            try:
                                await ws_send(ws, type_, data)
                                return True
                            except WebSocketDisconnect:
                                return False
                            except RuntimeError as e:
                                if 'Cannot call "send" once a close message has been sent.' in str(e):
                                    return False
                                raise
                            except Exception as e:
                                if e.__class__.__name__ == "ClientDisconnected":
                                    return False
                                raise

                        # If the client is already gone, skip the whole turn.
                        if not await ws_send(ws, "assistant.start", {}):
                            return

                        # Current assistant "bubble": accumulated text and
                        # the timestamp of its first token.
                        seg_text = ""
                        seg_ts: Optional[float] = None

                        async def flush_segment(send_flush_event: bool):
                            """
                            - send_flush_event=True: tell the frontend to close the
                              current assistant bubble now (without ending the turn)
                            - if seg_text is non-empty: persist it into history
                              (used for page refresh / replay)
                            """
                            nonlocal seg_text, seg_ts

                            if send_flush_event:
                                if not await ws_send(ws, "assistant.flush", {}):
                                    return

                            text = (seg_text or "").strip()
                            if text:
                                sess.history.append({
                                    "id": uuid.uuid4().hex[:12],
                                    "role": "assistant",
                                    "content": text,
                                    "ts": seg_ts or time.time(),
                                })

                            seg_text = ""
                            seg_ts = None

                        # NOTE(review): this initialization is repeated below,
                        # just before the dispatch loop.
                        pump_task: Optional[asyncio.Task] = None

                        def _tool_call_ids_from_ai_message(m: BaseMessage) -> set[str]:
                            # Collect tool-call ids from both the normalized
                            # .tool_calls attribute and additional_kwargs.
                            ids: set[str] = set()

                            tc = getattr(m, "tool_calls", None) or []
                            for c in tc:
                                _id = None
                                if isinstance(c, dict):
                                    _id = c.get("id") or c.get("tool_call_id")
                                else:
                                    _id = getattr(c, "id", None) or getattr(c, "tool_call_id", None)
                                if _id:
                                    ids.add(str(_id))

                            ak = getattr(m, "additional_kwargs", None) or {}
                            tc2 = ak.get("tool_calls") or []
                            for c in tc2:
                                if isinstance(c, dict):
                                    _id = c.get("id") or c.get("tool_call_id")
                                    if _id:
                                        ids.add(str(_id))

                            return ids

                        def _tool_call_ids_in_msgs(msgs: List[BaseMessage]) -> set[str]:
                            # Union of tool-call ids requested by AI messages.
                            ids: set[str] = set()
                            for m in msgs:
                                if isinstance(m, AIMessage):
                                    ids |= _tool_call_ids_from_ai_message(m)
                            return ids

                        def _tool_result_ids_in_msgs(msgs: List[BaseMessage]) -> set[str]:
                            # Ids of tool calls that already have a ToolMessage result.
                            ids: set[str] = set()
                            for m in msgs:
                                if isinstance(m, ToolMessage):
                                    tcid = getattr(m, "tool_call_id", None)
                                    if tcid:
                                        ids.add(str(tcid))
                            return ids

                        def _force_cancelled_tool_results(msgs: List[BaseMessage], cancel_ids: set[str]) -> List[BaseMessage]:
                            # Replace existing results of cancelled tool calls with a
                            # {"cancelled": true} stub so the model never sees partial output.
                            if not cancel_ids:
                                return msgs
                            cancelled_content = json.dumps({"cancelled": True}, ensure_ascii=False)
                            out: List[BaseMessage] = []
                            for m in msgs:
                                if isinstance(m, ToolMessage):
                                    tcid = getattr(m, "tool_call_id", None)
                                    if tcid and str(tcid) in cancel_ids:
                                        out.append(ToolMessage(content=cancelled_content, tool_call_id=str(tcid)))
                                        continue
                                out.append(m)
                            return out

                        def _inject_cancelled_tool_messages(msgs: List[BaseMessage], tool_call_ids: List[str]) -> List[BaseMessage]:
                            # For tool calls that never produced a result, insert a
                            # {"cancelled": true} ToolMessage right after the AIMessage
                            # that requested them (OpenAI-style transcripts require every
                            # tool_call to be answered).
                            if not tool_call_ids:
                                return msgs

                            out = list(msgs)

                            existing = set()
                            for m in out:
                                if isinstance(m, ToolMessage):
                                    tcid = getattr(m, "tool_call_id", None)
                                    if tcid:
                                        existing.add(str(tcid))

                            cancelled_content = json.dumps({"cancelled": True}, ensure_ascii=False)

                            for tcid in tool_call_ids:
                                tcid = str(tcid)
                                if tcid in existing:
                                    continue

                                # Find the latest AIMessage that requested this id.
                                insert_at = None
                                for i in range(len(out) - 1, -1, -1):
                                    m = out[i]
                                    if isinstance(m, AIMessage) and (tcid in _tool_call_ids_from_ai_message(m)):
                                        insert_at = i + 1
                                        break

                                if insert_at is None:
                                    continue

                                out.insert(insert_at, ToolMessage(content=cancelled_content, tool_call_id=tcid))
                                existing.add(tcid)

                            return out

                        def _sanitize_new_messages_on_cancel(
                            new_messages: List[BaseMessage],
                            *,
                            interrupted_text: str,
                            cancelled_tool_ids_from_ui: List[str],
                        ) -> List[BaseMessage]:
                            """
                            Return the message sequence that should be committed back to
                            sess.lc_messages (only the part the user actually saw).
                            - Tools: unanswered tool_calls get a ToolMessage({"cancelled": true}).
                            - Reply: the trailing final AIMessage is replaced with
                              interrupted_text, so the full never-shown reply does not
                              leak into the context.
                            """
                            msgs = list(new_messages or [])
                            interrupted_text = (interrupted_text or "").strip()

                            ai_tool_ids = _tool_call_ids_in_msgs(msgs)
                            tool_result_ids = _tool_result_ids_in_msgs(msgs)
                            pending_tool_ids = ai_tool_ids - tool_result_ids

                            ui_cancel_ids = {str(x) for x in (cancelled_tool_ids_from_ui or [])}

                            # Cancel both what the UI marked and whatever never returned.
                            cancel_ids = set(ui_cancel_ids) | set(pending_tool_ids)

                            # Stub out existing results, then fill in missing ones.
                            msgs = _force_cancelled_tool_results(msgs, cancel_ids)

                            msgs = _inject_cancelled_tool_messages(msgs, list(cancel_ids))

                            def _is_toolcall_ai(m: BaseMessage) -> bool:
                                return isinstance(m, AIMessage) and bool(_tool_call_ids_from_ai_message(m))

                            def _is_text_ai(m: BaseMessage) -> bool:
                                # Plain-text assistant reply (no tool calls, non-empty).
                                if not isinstance(m, AIMessage):
                                    return False
                                if _tool_call_ids_from_ai_message(m):
                                    return False
                                c = getattr(m, "content", None)
                                return isinstance(c, str) and bool(c.strip())

                            last_text_ai_idx = None
                            for i in range(len(msgs) - 1, -1, -1):
                                if _is_text_ai(msgs[i]):
                                    last_text_ai_idx = i
                                    break

                            if interrupted_text:
                                if last_text_ai_idx is None:
                                    msgs.append(AIMessage(content=interrupted_text))
                                else:
                                    # Drop the full final reply; keep only the part
                                    # that was streamed to the user.
                                    msgs = msgs[:last_text_ai_idx] + [AIMessage(content=interrupted_text)]
                                return msgs

                            # No partial text: drop a trailing text-only reply entirely
                            # (the user never saw any of it).
                            if last_text_ai_idx is not None:
                                has_toolcall_after = any(_is_toolcall_ai(m) for m in msgs[last_text_ai_idx + 1 :])
                                if not has_toolcall_after:
                                    msgs = msgs[:last_text_ai_idx]

                            return msgs

                        pump_task: Optional[asyncio.Task] = None
                        cancel_wait_task: Optional[asyncio.Task] = None

                        # Ensures the cancel path runs at most once per turn.
                        was_interrupted = False

                        try:
                            async with mcp_sink_context(sink):
                                pump_task = asyncio.create_task(pump_agent())
                                cancel_wait_task = asyncio.create_task(sess.cancel_event.wait())

                                # Event-dispatch loop: consume queue events and
                                # race them against the user's cancel signal.
                                while True:
                                    get_task = asyncio.create_task(out_q.get())
                                    done, _ = await asyncio.wait(
                                        {get_task, cancel_wait_task},
                                        return_when=asyncio.FIRST_COMPLETED,
                                    )

                                    if get_task in done:
                                        kind, payload = get_task.result()
                                    else:
                                        # Cancel won the race: drop the pending
                                        # queue read and synthesize a cancel event.
                                        try:
                                            get_task.cancel()
                                            await get_task
                                        except asyncio.CancelledError:
                                            debug_traceback_print(app.state.cfg)
                                            pass
                                        except Exception:
                                            debug_traceback_print(app.state.cfg)
                                            pass

                                        kind, payload = ("agent.cancelled", None)

                                    if kind == "agent.cancelled":
                                        # Run the interruption path once; a second
                                        # cancelled event just ends the loop.
                                        if was_interrupted:
                                            break
                                        was_interrupted = True

                                        if pump_task and (not pump_task.done()):
                                            pump_task.cancel()

                                        # Mark still-running tool records as errored.
                                        cancelled_tool_recs: List[Dict[str, Any]] = []
                                        for tcid, idx in list(sess._tool_history_index.items()):
                                            rec = sess.history[idx]
                                            if rec.get("role") == "tool" and rec.get("state") == "running":
                                                rec.update({
                                                    "state": "error",
                                                    "progress": 1.0,
                                                    "message": "Cancelled by user",
                                                    "summary": {"cancelled": True},
                                                })
                                                cancelled_tool_recs.append(rec)

                                        # Tell the UI each cancelled tool finished with error.
                                        for rec in cancelled_tool_recs:
                                            await ws_send(ws, "tool.end", {
                                                "tool_call_id": rec["tool_call_id"],
                                                "server": rec["server"],
                                                "name": rec["name"],
                                                "is_error": True,
                                                "summary": rec.get("summary"),
                                            })

                                        # Persist whatever partial text was streamed.
                                        interrupted_text = (seg_text or "").strip()
                                        if interrupted_text:
                                            sess.history.append({
                                                "id": uuid.uuid4().hex[:12],
                                                "role": "assistant",
                                                "content": interrupted_text,
                                                "ts": seg_ts or time.time(),
                                            })

                                        cancelled_tool_ids = [rec["tool_call_id"] for rec in cancelled_tool_recs]

                                        # Commit only the sanitized, user-visible slice
                                        # of this turn into the model context.
                                        commit_msgs = _sanitize_new_messages_on_cancel(
                                            new_messages,
                                            interrupted_text=interrupted_text,
                                            cancelled_tool_ids_from_ui=cancelled_tool_ids,
                                        )

                                        if commit_msgs:
                                            sess.lc_messages.extend(commit_msgs)
                                        elif interrupted_text:
                                            sess.lc_messages.append(AIMessage(content=interrupted_text))

                                        await ws_send(ws, "assistant.end", {"text": interrupted_text, "interrupted": True})

                                        sess.cancel_event.clear()
                                        break

                                    if kind == "assistant.delta":
                                        delta = payload or ""
                                        if delta:
                                            if seg_ts is None:
                                                seg_ts = time.time()
                                            seg_text += delta
                                            # Failed delivery => treat as disconnect
                                            # and unwind the whole turn.
                                            if not await ws_send(ws, "assistant.delta", {"delta": delta}):
                                                raise WebSocketDisconnect()
                                        continue

                                    if kind == "mcp":
                                        raw = payload

                                        # Close the current text bubble before a tool
                                        # block starts, so ordering is preserved in the UI.
                                        if raw.get("type") == "tool_start":
                                            await flush_segment(send_flush_event=True)

                                        rec = sess.apply_tool_event(raw)
                                        if rec:
                                            if raw["type"] == "tool_start":
                                                await ws_send(ws, "tool.start", {
                                                    "tool_call_id": rec["tool_call_id"],
                                                    "server": rec["server"],
                                                    "name": rec["name"],
                                                    "args": rec["args"],
                                                })
                                            elif raw["type"] == "tool_progress":
                                                await ws_send(ws, "tool.progress", {
                                                    "tool_call_id": rec["tool_call_id"],
                                                    "server": rec["server"],
                                                    "name": rec["name"],
                                                    "progress": rec["progress"],
                                                    "message": rec["message"],
                                                })
                                            elif raw["type"] == "tool_end":
                                                await ws_send(ws, "tool.end", {
                                                    "tool_call_id": rec["tool_call_id"],
                                                    "server": rec["server"],
                                                    "name": rec["name"],
                                                    "is_error": rec["state"] == "error",
                                                    "summary": rec["summary"],
                                                })
                                        continue

                                    if kind == "agent.done":
                                        final_text = (seg_text or "").strip()

                                        if final_text:
                                            sess.history.append({
                                                "id": uuid.uuid4().hex[:12],
                                                "role": "assistant",
                                                "content": final_text,
                                                "ts": seg_ts or time.time(),
                                            })

                                        # Commit everything the agent produced.
                                        if new_messages:
                                            sess.lc_messages.extend(new_messages)

                                        if not await ws_send(ws, "assistant.end", {"text": final_text}):
                                            return
                                        break

                                    if kind == "agent.error":
                                        err_text = str(payload or "unknown error")
                                        partial = (seg_text or "").strip()

                                        # Keep whatever partial text made it out.
                                        if partial:
                                            sess.history.append({
                                                "id": uuid.uuid4().hex[:12],
                                                "role": "assistant",
                                                "content": partial,
                                                "ts": seg_ts or time.time(),
                                            })
                                            sess.lc_messages.append(AIMessage(content=partial))

                                        if new_messages:
                                            sess.lc_messages.extend(new_messages)

                                        await ws_send(ws, "error", {"message": err_text, "partial_text": partial})
                                        break

                        except WebSocketDisconnect:
                            return
                        except asyncio.CancelledError:
                            # Server shutdown / task cancellation: exit quietly.
                            return
                        except Exception as e:
                            # If we already ran the interruption path, suppress
                            # the secondary error.
                            if was_interrupted:
                                return
                            await ws_send(ws, "error", {"message": f"{type(e).__name__}: {e}", "partial_text": (seg_text or "").strip()})
                            return
                        finally:
                            # Tear down helper tasks regardless of how the
                            # turn ended.
                            if cancel_wait_task and (not cancel_wait_task.done()):
                                cancel_wait_task.cancel()

                            if pump_task and (not pump_task.done()):
                                pump_task.cancel()
                            if pump_task:
                                try:
                                    # Bounded wait so a stuck agent cannot leak
                                    # this websocket handler.
                                    await asyncio.wait_for(pump_task, timeout=2.0)
                                except asyncio.TimeoutError:
                                    debug_traceback_print(app.state.cfg)
                                    pass
                                except asyncio.CancelledError:
                                    debug_traceback_print(app.state.cfg)
                                    pass
                                except Exception:
                                    debug_traceback_print(app.state.cfg)
                                    pass
                finally:
                    # Release the model-concurrency slot for this turn.
                    try:
                        CHAT_TURN_SEM.release()
                    except Exception:
                        debug_traceback_print(app.state.cfg)
                        pass

    except WebSocketDisconnect:
        return
    finally:
        # Always release the connection slot.
        # NOTE(review): bare `except:` also swallows KeyboardInterrupt /
        # SystemExit — `except Exception:` would be safer.
        try:
            WS_CONN_SEM.release()
        except:
            pass
| |
|