Upload 126 files
Browse files. This view is limited to 50 files because it contains too many changes. See the raw diff.
- Dockerfile +76 -0
- app/.DS_Store +0 -0
- app/api/pages/__init__.py +13 -0
- app/api/pages/admin.py +32 -0
- app/api/pages/public.py +51 -0
- app/api/v1/admin_api/__init__.py +15 -0
- app/api/v1/admin_api/cache.py +445 -0
- app/api/v1/admin_api/config.py +53 -0
- app/api/v1/admin_api/token.py +395 -0
- app/api/v1/chat.py +862 -0
- app/api/v1/files.py +69 -0
- app/api/v1/image.py +452 -0
- app/api/v1/models.py +28 -0
- app/api/v1/public_api/__init__.py +18 -0
- app/api/v1/public_api/imagine.py +505 -0
- app/api/v1/public_api/video.py +274 -0
- app/api/v1/public_api/voice.py +80 -0
- app/api/v1/response.py +81 -0
- app/api/v1/video.py +3 -0
- app/core/auth.py +198 -0
- app/core/batch.py +233 -0
- app/core/config.py +326 -0
- app/core/exceptions.py +232 -0
- app/core/logger.py +151 -0
- app/core/response_middleware.py +85 -0
- app/core/storage.py +1478 -0
- app/services/cf_refresh/README.md +49 -0
- app/services/cf_refresh/__init__.py +5 -0
- app/services/cf_refresh/config.py +41 -0
- app/services/cf_refresh/scheduler.py +98 -0
- app/services/cf_refresh/solver.py +122 -0
- app/services/grok/batch_services/assets.py +234 -0
- app/services/grok/batch_services/nsfw.py +112 -0
- app/services/grok/batch_services/usage.py +89 -0
- app/services/grok/defaults.py +34 -0
- app/services/grok/services/chat.py +1115 -0
- app/services/grok/services/image.py +794 -0
- app/services/grok/services/image_edit.py +567 -0
- app/services/grok/services/model.py +270 -0
- app/services/grok/services/responses.py +824 -0
- app/services/grok/services/video.py +688 -0
- app/services/grok/services/voice.py +31 -0
- app/services/grok/utils/cache.py +110 -0
- app/services/grok/utils/download.py +298 -0
- app/services/grok/utils/locks.py +86 -0
- app/services/grok/utils/process.py +152 -0
- app/services/grok/utils/response.py +144 -0
- app/services/grok/utils/retry.py +66 -0
- app/services/grok/utils/stream.py +46 -0
- app/services/grok/utils/tool_call.py +319 -0
Dockerfile
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---- Build stage: install dependencies into a self-contained venv ----
FROM python:3.13-alpine AS builder

ENV PYTHONDONTWRITEBYTECODE=1 \
    PYTHONUNBUFFERED=1 \
    TZ=Asia/Shanghai \
    # Install project dependencies into a dedicated virtualenv
    UV_PROJECT_ENVIRONMENT=/opt/venv

# Make the venv's bin directory take precedence on PATH
ENV PATH="$UV_PROJECT_ENVIRONMENT/bin:$PATH"

RUN apk add --no-cache \
    tzdata \
    ca-certificates \
    build-base \
    linux-headers \
    libffi-dev \
    openssl-dev \
    curl-dev \
    cargo \
    rust

WORKDIR /app

# Install uv from the official image
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/

COPY pyproject.toml uv.lock ./

# Sync dependencies, then trim caches/tests and strip native extensions.
# FIX: "|| true" is now scoped to the strip step only (via { ...; });
# previously, with flat left-associative && / || chaining, it could mask
# a failure of ANY preceding command in the chain.
RUN uv sync --frozen --no-dev --no-install-project \
    && find /opt/venv -type d -name "__pycache__" -prune -exec rm -rf {} + \
    && find /opt/venv -type f -name "*.pyc" -delete \
    && find /opt/venv -type d -name "tests" -prune -exec rm -rf {} + \
    && find /opt/venv -type d -name "test" -prune -exec rm -rf {} + \
    && find /opt/venv -type d -name "testing" -prune -exec rm -rf {} + \
    && { find /opt/venv -type f -name "*.so" -exec strip --strip-unneeded {} + || true; } \
    && rm -rf /root/.cache /tmp/uv-cache

# ---- Runtime stage ----
FROM python:3.13-alpine

ENV PYTHONDONTWRITEBYTECODE=1 \
    PYTHONUNBUFFERED=1 \
    TZ=Asia/Shanghai \
    VIRTUAL_ENV=/opt/venv

ENV PATH="$VIRTUAL_ENV/bin:$PATH"

# Runtime-only shared libraries matching the -dev packages from the builder
RUN apk add --no-cache \
    tzdata \
    ca-certificates \
    libffi \
    openssl \
    libgcc \
    libstdc++ \
    libcurl

WORKDIR /app

COPY --from=builder /opt/venv /opt/venv

COPY config.defaults.toml ./
COPY app ./app
COPY main.py ./
COPY scripts ./scripts

# Single layer: create runtime dirs and mark both scripts executable.
# FIX: entrypoint.sh was previously chmod'ed twice in separate layers.
RUN mkdir -p /app/data /app/logs \
    && chmod +x /app/scripts/entrypoint.sh /app/scripts/init_storage.sh

EXPOSE 7860

ENTRYPOINT ["/app/scripts/entrypoint.sh"]

CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
app/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
app/api/pages/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""UI pages router."""

from fastapi import APIRouter

from app.api.pages.admin import router as admin_router
from app.api.pages.public import router as public_router

router = APIRouter()

# Mount sub-routers; order matters for route resolution (public first).
for _sub_router in (public_router, admin_router):
    router.include_router(_sub_router)

__all__ = ["router"]
app/api/pages/admin.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from pathlib import Path

from fastapi import APIRouter
from fastapi.responses import FileResponse, RedirectResponse

router = APIRouter()
# .../app/static, resolved relative to this module (app/api/pages/admin.py)
STATIC_DIR = Path(__file__).resolve().parents[2] / "static"


def _admin_page(name: str) -> FileResponse:
    """Serve a static admin HTML page from static/admin/pages."""
    return FileResponse(STATIC_DIR / "admin/pages" / name)


@router.get("/admin", include_in_schema=False)
async def admin_root():
    """Redirect the bare /admin path to the admin login page."""
    return RedirectResponse(url="/admin/login")


@router.get("/admin/login", include_in_schema=False)
async def admin_login():
    """Admin login page."""
    return _admin_page("login.html")


@router.get("/admin/config", include_in_schema=False)
async def admin_config():
    """Admin configuration page."""
    return _admin_page("config.html")


@router.get("/admin/cache", include_in_schema=False)
async def admin_cache():
    """Admin cache management page."""
    return _admin_page("cache.html")


@router.get("/admin/token", include_in_schema=False)
async def admin_token():
    """Admin token management page."""
    return _admin_page("token.html")
app/api/pages/public.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from pathlib import Path

from fastapi import APIRouter, HTTPException
from fastapi.responses import FileResponse, RedirectResponse

from app.core.auth import is_public_enabled

router = APIRouter()
# .../app/static, resolved relative to this module (app/api/pages/public.py)
STATIC_DIR = Path(__file__).resolve().parents[2] / "static"


def _public_page(name: str) -> FileResponse:
    """Serve a public HTML page, or raise 404 when the public UI is disabled.

    Factored out: the enable-check + FileResponse pair was previously
    duplicated in every public page handler.
    """
    if not is_public_enabled():
        raise HTTPException(status_code=404, detail="Not Found")
    return FileResponse(STATIC_DIR / "public/pages" / name)


@router.get("/", include_in_schema=False)
async def root():
    """Send visitors to the public login when enabled, else to admin login."""
    if is_public_enabled():
        return RedirectResponse(url="/login")
    return RedirectResponse(url="/admin/login")


@router.get("/login", include_in_schema=False)
async def public_login():
    """Public login page (404 when the public UI is disabled)."""
    return _public_page("login.html")


@router.get("/imagine", include_in_schema=False)
async def public_imagine():
    """Public image-generation page (404 when the public UI is disabled)."""
    return _public_page("imagine.html")


@router.get("/voice", include_in_schema=False)
async def public_voice():
    """Public voice page (404 when the public UI is disabled)."""
    return _public_page("voice.html")


@router.get("/video", include_in_schema=False)
async def public_video():
    """Public video page (404 when the public UI is disabled)."""
    return _public_page("video.html")


@router.get("/chat", include_in_schema=False)
async def public_chat():
    """Public chat page (404 when the public UI is disabled)."""
    return _public_page("chat.html")
app/api/v1/admin_api/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Admin API router (app_key protected)."""

from fastapi import APIRouter

from app.api.v1.admin_api.cache import router as cache_router
from app.api.v1.admin_api.config import router as config_router
from app.api.v1.admin_api.token import router as tokens_router

router = APIRouter()

# Mount sub-routers in registration order: config, tokens, cache.
for _sub_router in (config_router, tokens_router, cache_router):
    router.include_router(_sub_router)

__all__ = ["router"]
app/api/v1/admin_api/cache.py
ADDED
|
@@ -0,0 +1,445 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List
|
| 2 |
+
|
| 3 |
+
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
| 4 |
+
|
| 5 |
+
from app.core.auth import verify_app_key
|
| 6 |
+
from app.core.batch import create_task, expire_task
|
| 7 |
+
from app.services.grok.batch_services.assets import ListService, DeleteService
|
| 8 |
+
from app.services.token.manager import get_token_manager
|
| 9 |
+
router = APIRouter()
|
| 10 |
+
|
| 11 |
+
|
@router.get("/cache", dependencies=[Depends(verify_app_key)])
async def cache_stats(request: Request):
    """Aggregate local and online cache statistics.

    Local stats come from CacheService; online stats are fetched per token.
    Query parameters select the online scope:
      - ``tokens`` (comma-separated) -> per-token details, scope "selected"
      - ``scope=all``                -> details for every known token
      - ``token``                    -> a single token's stats
      - otherwise                    -> online stats reported as "not_loaded"
    """
    from app.services.grok.utils.cache import CacheService

    try:
        cache_service = CacheService()
        image_stats = cache_service.get_stats("image")
        video_stats = cache_service.get_stats("video")

        # Build a flat account list across all token pools; the "sso="
        # cookie prefix is stripped so tokens match the raw form used below.
        mgr = await get_token_manager()
        pools = mgr.pools
        accounts = []
        for pool_name, pool in pools.items():
            for info in pool.list():
                raw_token = (
                    info.token[4:] if info.token.startswith("sso=") else info.token
                )
                # Masked form shown in the UI; short tokens are left as-is.
                masked = (
                    f"{raw_token[:8]}...{raw_token[-16:]}"
                    if len(raw_token) > 24
                    else raw_token
                )
                accounts.append(
                    {
                        "token": raw_token,
                        "token_masked": masked,
                        "pool": pool_name,
                        "status": info.status,
                        "last_asset_clear_at": info.last_asset_clear_at,
                    }
                )

        scope = request.query_params.get("scope")
        selected_token = request.query_params.get("token")
        tokens_param = request.query_params.get("tokens")
        selected_tokens = []
        if tokens_param:
            selected_tokens = [t.strip() for t in tokens_param.split(",") if t.strip()]

        online_stats = {
            "count": 0,
            "status": "unknown",
            "token": None,
            "last_asset_clear_at": None,
        }
        online_details = []
        account_map = {a["token"]: a for a in accounts}
        if selected_tokens:
            # Explicit token list: fetch each token's asset details.
            total = 0
            raw_results = await ListService.fetch_assets_details(
                selected_tokens,
                account_map,
            )
            for token, res in raw_results.items():
                if res.get("ok"):
                    data = res.get("data", {})
                    detail = data.get("detail")
                    total += data.get("count", 0)
                else:
                    # Synthesize an error detail row for failed fetches.
                    account = account_map.get(token)
                    detail = {
                        "token": token,
                        "token_masked": account["token_masked"] if account else token,
                        "count": 0,
                        "status": f"error: {res.get('error')}",
                        "last_asset_clear_at": account["last_asset_clear_at"]
                        if account
                        else None,
                    }
                if detail:
                    online_details.append(detail)
            online_stats = {
                "count": total,
                "status": "ok" if selected_tokens else "no_token",
                "token": None,
                "last_asset_clear_at": None,
            }
            # Explicit tokens override any scope query parameter.
            scope = "selected"
        elif scope == "all":
            # Same aggregation as above, over every known account token.
            total = 0
            tokens = list(dict.fromkeys([account["token"] for account in accounts]))
            raw_results = await ListService.fetch_assets_details(
                tokens,
                account_map,
            )
            for token, res in raw_results.items():
                if res.get("ok"):
                    data = res.get("data", {})
                    detail = data.get("detail")
                    total += data.get("count", 0)
                else:
                    account = account_map.get(token)
                    detail = {
                        "token": token,
                        "token_masked": account["token_masked"] if account else token,
                        "count": 0,
                        "status": f"error: {res.get('error')}",
                        "last_asset_clear_at": account["last_asset_clear_at"]
                        if account
                        else None,
                    }
                if detail:
                    online_details.append(detail)
            online_stats = {
                "count": total,
                "status": "ok" if accounts else "no_token",
                "token": None,
                "last_asset_clear_at": None,
            }
        else:
            # Single-token (or no-token) fallback path.
            token = selected_token
            if token:
                raw_results = await ListService.fetch_assets_details(
                    [token],
                    account_map,
                )
                res = raw_results.get(token, {})
                data = res.get("data", {})
                detail = data.get("detail") if res.get("ok") else None
                if detail:
                    online_stats = {
                        "count": data.get("count", 0),
                        "status": detail.get("status", "ok"),
                        "token": detail.get("token"),
                        "token_masked": detail.get("token_masked"),
                        "last_asset_clear_at": detail.get("last_asset_clear_at"),
                    }
                else:
                    match = next((a for a in accounts if a["token"] == token), None)
                    online_stats = {
                        "count": 0,
                        "status": f"error: {res.get('error')}",
                        "token": token,
                        "token_masked": match["token_masked"] if match else token,
                        "last_asset_clear_at": match["last_asset_clear_at"]
                        if match
                        else None,
                    }
            else:
                online_stats = {
                    "count": 0,
                    "status": "not_loaded",
                    "token": None,
                    "last_asset_clear_at": None,
                }

        response = {
            "local_image": image_stats,
            "local_video": video_stats,
            "online": online_stats,
            "online_accounts": accounts,
            "online_scope": scope or "none",
            "online_details": online_details,
        }
        return response
    except Exception as e:
        # NOTE(review): this also converts any HTTPException raised above
        # into a 500 — confirm whether that is intentional.
        raise HTTPException(status_code=500, detail=str(e))
@router.get("/cache/list", dependencies=[Depends(verify_app_key)])
async def list_local(
    cache_type: str = "image",
    type_: str = Query(default=None, alias="type"),
    page: int = 1,
    page_size: int = 1000,
):
    """List locally cached files for a cache type, paginated."""
    from app.services.grok.utils.cache import CacheService

    # The "type" query parameter, when given, overrides cache_type.
    effective_type = type_ if type_ else cache_type
    try:
        listing = CacheService().list_files(effective_type, page, page_size)
        return {"status": "success", **listing}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
| 192 |
+
@router.post("/cache/clear", dependencies=[Depends(verify_app_key)])
|
| 193 |
+
async def clear_local(data: dict):
|
| 194 |
+
"""清理本地缓存"""
|
| 195 |
+
from app.services.grok.utils.cache import CacheService
|
| 196 |
+
|
| 197 |
+
cache_type = data.get("type", "image")
|
| 198 |
+
|
| 199 |
+
try:
|
| 200 |
+
cache_service = CacheService()
|
| 201 |
+
result = cache_service.clear(cache_type)
|
| 202 |
+
return {"status": "success", "result": result}
|
| 203 |
+
except Exception as e:
|
| 204 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
@router.post("/cache/item/delete", dependencies=[Depends(verify_app_key)])
|
| 208 |
+
async def delete_local_item(data: dict):
|
| 209 |
+
"""删除单个本地缓存文件"""
|
| 210 |
+
from app.services.grok.utils.cache import CacheService
|
| 211 |
+
|
| 212 |
+
cache_type = data.get("type", "image")
|
| 213 |
+
name = data.get("name")
|
| 214 |
+
if not name:
|
| 215 |
+
raise HTTPException(status_code=400, detail="Missing file name")
|
| 216 |
+
try:
|
| 217 |
+
cache_service = CacheService()
|
| 218 |
+
result = cache_service.delete_file(cache_type, name)
|
| 219 |
+
return {"status": "success", "result": result}
|
| 220 |
+
except Exception as e:
|
| 221 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
@router.post("/cache/online/clear", dependencies=[Depends(verify_app_key)])
|
| 225 |
+
async def clear_online(data: dict):
|
| 226 |
+
"""清理在线缓存"""
|
| 227 |
+
try:
|
| 228 |
+
mgr = await get_token_manager()
|
| 229 |
+
tokens = data.get("tokens")
|
| 230 |
+
|
| 231 |
+
if isinstance(tokens, list):
|
| 232 |
+
token_list = [t.strip() for t in tokens if isinstance(t, str) and t.strip()]
|
| 233 |
+
if not token_list:
|
| 234 |
+
raise HTTPException(status_code=400, detail="No tokens provided")
|
| 235 |
+
|
| 236 |
+
token_list = list(dict.fromkeys(token_list))
|
| 237 |
+
|
| 238 |
+
results = {}
|
| 239 |
+
raw_results = await DeleteService.clear_assets(
|
| 240 |
+
token_list,
|
| 241 |
+
mgr,
|
| 242 |
+
)
|
| 243 |
+
for token, res in raw_results.items():
|
| 244 |
+
if res.get("ok"):
|
| 245 |
+
results[token] = res.get("data", {})
|
| 246 |
+
else:
|
| 247 |
+
results[token] = {"status": "error", "error": res.get("error")}
|
| 248 |
+
|
| 249 |
+
return {"status": "success", "results": results}
|
| 250 |
+
|
| 251 |
+
token = data.get("token") or mgr.get_token()
|
| 252 |
+
if not token:
|
| 253 |
+
raise HTTPException(
|
| 254 |
+
status_code=400, detail="No available token to perform cleanup"
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
raw_results = await DeleteService.clear_assets(
|
| 258 |
+
[token],
|
| 259 |
+
mgr,
|
| 260 |
+
)
|
| 261 |
+
res = raw_results.get(token, {})
|
| 262 |
+
data = res.get("data", {})
|
| 263 |
+
if res.get("ok") and data.get("status") == "success":
|
| 264 |
+
return {"status": "success", "result": data.get("result")}
|
| 265 |
+
return {"status": "error", "error": data.get("error") or res.get("error")}
|
| 266 |
+
except Exception as e:
|
| 267 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
@router.post("/cache/online/clear/async", dependencies=[Depends(verify_app_key)])
|
| 271 |
+
async def clear_online_async(data: dict):
|
| 272 |
+
"""清理在线缓存(异步批量 + SSE 进度)"""
|
| 273 |
+
mgr = await get_token_manager()
|
| 274 |
+
tokens = data.get("tokens")
|
| 275 |
+
if not isinstance(tokens, list):
|
| 276 |
+
raise HTTPException(status_code=400, detail="No tokens provided")
|
| 277 |
+
|
| 278 |
+
token_list = [t.strip() for t in tokens if isinstance(t, str) and t.strip()]
|
| 279 |
+
if not token_list:
|
| 280 |
+
raise HTTPException(status_code=400, detail="No tokens provided")
|
| 281 |
+
|
| 282 |
+
task = create_task(len(token_list))
|
| 283 |
+
|
| 284 |
+
async def _run():
|
| 285 |
+
try:
|
| 286 |
+
async def _on_item(item: str, res: dict):
|
| 287 |
+
ok = bool(res.get("data", {}).get("ok"))
|
| 288 |
+
task.record(ok)
|
| 289 |
+
|
| 290 |
+
raw_results = await DeleteService.clear_assets(
|
| 291 |
+
token_list,
|
| 292 |
+
mgr,
|
| 293 |
+
include_ok=True,
|
| 294 |
+
on_item=_on_item,
|
| 295 |
+
should_cancel=lambda: task.cancelled,
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
if task.cancelled:
|
| 299 |
+
task.finish_cancelled()
|
| 300 |
+
return
|
| 301 |
+
|
| 302 |
+
results = {}
|
| 303 |
+
ok_count = 0
|
| 304 |
+
fail_count = 0
|
| 305 |
+
for token, res in raw_results.items():
|
| 306 |
+
data = res.get("data", {})
|
| 307 |
+
if data.get("ok"):
|
| 308 |
+
ok_count += 1
|
| 309 |
+
results[token] = {"status": "success", "result": data.get("result")}
|
| 310 |
+
else:
|
| 311 |
+
fail_count += 1
|
| 312 |
+
results[token] = {"status": "error", "error": data.get("error")}
|
| 313 |
+
|
| 314 |
+
result = {
|
| 315 |
+
"status": "success",
|
| 316 |
+
"summary": {
|
| 317 |
+
"total": len(token_list),
|
| 318 |
+
"ok": ok_count,
|
| 319 |
+
"fail": fail_count,
|
| 320 |
+
},
|
| 321 |
+
"results": results,
|
| 322 |
+
}
|
| 323 |
+
task.finish(result)
|
| 324 |
+
except Exception as e:
|
| 325 |
+
task.fail_task(str(e))
|
| 326 |
+
finally:
|
| 327 |
+
import asyncio
|
| 328 |
+
asyncio.create_task(expire_task(task.id, 300))
|
| 329 |
+
|
| 330 |
+
import asyncio
|
| 331 |
+
asyncio.create_task(_run())
|
| 332 |
+
|
| 333 |
+
return {
|
| 334 |
+
"status": "success",
|
| 335 |
+
"task_id": task.id,
|
| 336 |
+
"total": len(token_list),
|
| 337 |
+
}
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
@router.post("/cache/online/load/async", dependencies=[Depends(verify_app_key)])
|
| 341 |
+
async def load_cache_async(data: dict):
|
| 342 |
+
"""在线资产统计(异步批量 + SSE 进度)"""
|
| 343 |
+
from app.services.grok.utils.cache import CacheService
|
| 344 |
+
|
| 345 |
+
mgr = await get_token_manager()
|
| 346 |
+
|
| 347 |
+
accounts = []
|
| 348 |
+
for pool_name, pool in mgr.pools.items():
|
| 349 |
+
for info in pool.list():
|
| 350 |
+
raw_token = info.token[4:] if info.token.startswith("sso=") else info.token
|
| 351 |
+
masked = (
|
| 352 |
+
f"{raw_token[:8]}...{raw_token[-16:]}"
|
| 353 |
+
if len(raw_token) > 24
|
| 354 |
+
else raw_token
|
| 355 |
+
)
|
| 356 |
+
accounts.append(
|
| 357 |
+
{
|
| 358 |
+
"token": raw_token,
|
| 359 |
+
"token_masked": masked,
|
| 360 |
+
"pool": pool_name,
|
| 361 |
+
"status": info.status,
|
| 362 |
+
"last_asset_clear_at": info.last_asset_clear_at,
|
| 363 |
+
}
|
| 364 |
+
)
|
| 365 |
+
|
| 366 |
+
account_map = {a["token"]: a for a in accounts}
|
| 367 |
+
|
| 368 |
+
tokens = data.get("tokens")
|
| 369 |
+
scope = data.get("scope")
|
| 370 |
+
selected_tokens: List[str] = []
|
| 371 |
+
if isinstance(tokens, list):
|
| 372 |
+
selected_tokens = [str(t).strip() for t in tokens if str(t).strip()]
|
| 373 |
+
|
| 374 |
+
if not selected_tokens and scope == "all":
|
| 375 |
+
selected_tokens = [account["token"] for account in accounts]
|
| 376 |
+
scope = "all"
|
| 377 |
+
elif selected_tokens:
|
| 378 |
+
scope = "selected"
|
| 379 |
+
else:
|
| 380 |
+
raise HTTPException(status_code=400, detail="No tokens provided")
|
| 381 |
+
|
| 382 |
+
task = create_task(len(selected_tokens))
|
| 383 |
+
|
| 384 |
+
async def _run():
|
| 385 |
+
try:
|
| 386 |
+
cache_service = CacheService()
|
| 387 |
+
image_stats = cache_service.get_stats("image")
|
| 388 |
+
video_stats = cache_service.get_stats("video")
|
| 389 |
+
|
| 390 |
+
async def _on_item(item: str, res: dict):
|
| 391 |
+
ok = bool(res.get("data", {}).get("ok"))
|
| 392 |
+
task.record(ok)
|
| 393 |
+
|
| 394 |
+
raw_results = await ListService.fetch_assets_details(
|
| 395 |
+
selected_tokens,
|
| 396 |
+
account_map,
|
| 397 |
+
include_ok=True,
|
| 398 |
+
on_item=_on_item,
|
| 399 |
+
should_cancel=lambda: task.cancelled,
|
| 400 |
+
)
|
| 401 |
+
|
| 402 |
+
if task.cancelled:
|
| 403 |
+
task.finish_cancelled()
|
| 404 |
+
return
|
| 405 |
+
|
| 406 |
+
online_details = []
|
| 407 |
+
total = 0
|
| 408 |
+
for token, res in raw_results.items():
|
| 409 |
+
data = res.get("data", {})
|
| 410 |
+
detail = data.get("detail")
|
| 411 |
+
if detail:
|
| 412 |
+
online_details.append(detail)
|
| 413 |
+
total += data.get("count", 0)
|
| 414 |
+
|
| 415 |
+
online_stats = {
|
| 416 |
+
"count": total,
|
| 417 |
+
"status": "ok" if selected_tokens else "no_token",
|
| 418 |
+
"token": None,
|
| 419 |
+
"last_asset_clear_at": None,
|
| 420 |
+
}
|
| 421 |
+
|
| 422 |
+
result = {
|
| 423 |
+
"local_image": image_stats,
|
| 424 |
+
"local_video": video_stats,
|
| 425 |
+
"online": online_stats,
|
| 426 |
+
"online_accounts": accounts,
|
| 427 |
+
"online_scope": scope or "none",
|
| 428 |
+
"online_details": online_details,
|
| 429 |
+
}
|
| 430 |
+
task.finish(result)
|
| 431 |
+
except Exception as e:
|
| 432 |
+
task.fail_task(str(e))
|
| 433 |
+
finally:
|
| 434 |
+
import asyncio
|
| 435 |
+
asyncio.create_task(expire_task(task.id, 300))
|
| 436 |
+
|
| 437 |
+
import asyncio
|
| 438 |
+
asyncio.create_task(_run())
|
| 439 |
+
|
| 440 |
+
return {
|
| 441 |
+
"status": "success",
|
| 442 |
+
"task_id": task.id,
|
| 443 |
+
"total": len(selected_tokens),
|
| 444 |
+
}
|
| 445 |
+
|
app/api/v1/admin_api/config.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
from fastapi import APIRouter, Depends, HTTPException
|
| 4 |
+
|
| 5 |
+
from app.core.auth import verify_app_key
|
| 6 |
+
from app.core.config import config
|
| 7 |
+
from app.core.storage import get_storage as resolve_storage, LocalStorage, RedisStorage, SQLStorage
|
| 8 |
+
|
| 9 |
+
router = APIRouter()
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@router.get("/verify", dependencies=[Depends(verify_app_key)])
|
| 13 |
+
async def admin_verify():
|
| 14 |
+
"""验证后台访问密钥(app_key)"""
|
| 15 |
+
return {"status": "success"}
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@router.get("/config", dependencies=[Depends(verify_app_key)])
|
| 19 |
+
async def get_config():
|
| 20 |
+
"""获取当前配置"""
|
| 21 |
+
# 暴露原始配置字典
|
| 22 |
+
return config._config
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@router.post("/config", dependencies=[Depends(verify_app_key)])
|
| 26 |
+
async def update_config(data: dict):
|
| 27 |
+
"""更新配置"""
|
| 28 |
+
try:
|
| 29 |
+
await config.update(data)
|
| 30 |
+
return {"status": "success", "message": "配置已更新"}
|
| 31 |
+
except Exception as e:
|
| 32 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@router.get("/storage", dependencies=[Depends(verify_app_key)])
|
| 36 |
+
async def get_storage_mode():
|
| 37 |
+
"""获取当前存储模式"""
|
| 38 |
+
storage_type = os.getenv("SERVER_STORAGE_TYPE", "").lower()
|
| 39 |
+
if not storage_type:
|
| 40 |
+
storage = resolve_storage()
|
| 41 |
+
if isinstance(storage, LocalStorage):
|
| 42 |
+
storage_type = "local"
|
| 43 |
+
elif isinstance(storage, RedisStorage):
|
| 44 |
+
storage_type = "redis"
|
| 45 |
+
elif isinstance(storage, SQLStorage):
|
| 46 |
+
storage_type = {
|
| 47 |
+
"mysql": "mysql",
|
| 48 |
+
"mariadb": "mysql",
|
| 49 |
+
"postgres": "pgsql",
|
| 50 |
+
"postgresql": "pgsql",
|
| 51 |
+
"pgsql": "pgsql",
|
| 52 |
+
}.get(storage.dialect, storage.dialect)
|
| 53 |
+
return {"type": storage_type or "local"}
|
app/api/v1/admin_api/token.py
ADDED
|
@@ -0,0 +1,395 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
|
| 3 |
+
import orjson
|
| 4 |
+
from fastapi import APIRouter, Depends, HTTPException, Request
|
| 5 |
+
from fastapi.responses import StreamingResponse
|
| 6 |
+
|
| 7 |
+
from app.core.auth import get_app_key, verify_app_key
|
| 8 |
+
from app.core.batch import create_task, expire_task, get_task
|
| 9 |
+
from app.core.logger import logger
|
| 10 |
+
from app.core.storage import get_storage
|
| 11 |
+
from app.services.grok.batch_services.usage import UsageService
|
| 12 |
+
from app.services.grok.batch_services.nsfw import NSFWService
|
| 13 |
+
from app.services.token.manager import get_token_manager
|
| 14 |
+
|
| 15 |
+
router = APIRouter()
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@router.get("/tokens", dependencies=[Depends(verify_app_key)])
|
| 19 |
+
async def get_tokens():
|
| 20 |
+
"""获取所有 Token"""
|
| 21 |
+
storage = get_storage()
|
| 22 |
+
tokens = await storage.load_tokens()
|
| 23 |
+
return tokens or {}
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@router.post("/tokens", dependencies=[Depends(verify_app_key)])
|
| 27 |
+
async def update_tokens(data: dict):
|
| 28 |
+
"""更新 Token 信息"""
|
| 29 |
+
storage = get_storage()
|
| 30 |
+
try:
|
| 31 |
+
from app.services.token.models import TokenInfo
|
| 32 |
+
|
| 33 |
+
async with storage.acquire_lock("tokens_save", timeout=10):
|
| 34 |
+
existing = await storage.load_tokens() or {}
|
| 35 |
+
normalized = {}
|
| 36 |
+
allowed_fields = set(TokenInfo.model_fields.keys())
|
| 37 |
+
existing_map = {}
|
| 38 |
+
for pool_name, tokens in existing.items():
|
| 39 |
+
if not isinstance(tokens, list):
|
| 40 |
+
continue
|
| 41 |
+
pool_map = {}
|
| 42 |
+
for item in tokens:
|
| 43 |
+
if isinstance(item, str):
|
| 44 |
+
token_data = {"token": item}
|
| 45 |
+
elif isinstance(item, dict):
|
| 46 |
+
token_data = dict(item)
|
| 47 |
+
else:
|
| 48 |
+
continue
|
| 49 |
+
raw_token = token_data.get("token")
|
| 50 |
+
if isinstance(raw_token, str) and raw_token.startswith("sso="):
|
| 51 |
+
token_data["token"] = raw_token[4:]
|
| 52 |
+
token_key = token_data.get("token")
|
| 53 |
+
if isinstance(token_key, str):
|
| 54 |
+
pool_map[token_key] = token_data
|
| 55 |
+
existing_map[pool_name] = pool_map
|
| 56 |
+
for pool_name, tokens in (data or {}).items():
|
| 57 |
+
if not isinstance(tokens, list):
|
| 58 |
+
continue
|
| 59 |
+
pool_list = []
|
| 60 |
+
for item in tokens:
|
| 61 |
+
if isinstance(item, str):
|
| 62 |
+
token_data = {"token": item}
|
| 63 |
+
elif isinstance(item, dict):
|
| 64 |
+
token_data = dict(item)
|
| 65 |
+
else:
|
| 66 |
+
continue
|
| 67 |
+
|
| 68 |
+
raw_token = token_data.get("token")
|
| 69 |
+
if isinstance(raw_token, str) and raw_token.startswith("sso="):
|
| 70 |
+
token_data["token"] = raw_token[4:]
|
| 71 |
+
|
| 72 |
+
base = existing_map.get(pool_name, {}).get(
|
| 73 |
+
token_data.get("token"), {}
|
| 74 |
+
)
|
| 75 |
+
merged = dict(base)
|
| 76 |
+
merged.update(token_data)
|
| 77 |
+
if merged.get("tags") is None:
|
| 78 |
+
merged["tags"] = []
|
| 79 |
+
|
| 80 |
+
filtered = {k: v for k, v in merged.items() if k in allowed_fields}
|
| 81 |
+
try:
|
| 82 |
+
info = TokenInfo(**filtered)
|
| 83 |
+
pool_list.append(info.model_dump())
|
| 84 |
+
except Exception as e:
|
| 85 |
+
logger.warning(f"Skip invalid token in pool '{pool_name}': {e}")
|
| 86 |
+
continue
|
| 87 |
+
normalized[pool_name] = pool_list
|
| 88 |
+
|
| 89 |
+
await storage.save_tokens(normalized)
|
| 90 |
+
mgr = await get_token_manager()
|
| 91 |
+
await mgr.reload()
|
| 92 |
+
return {"status": "success", "message": "Token 已更新"}
|
| 93 |
+
except Exception as e:
|
| 94 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
@router.post("/tokens/refresh", dependencies=[Depends(verify_app_key)])
|
| 98 |
+
async def refresh_tokens(data: dict):
|
| 99 |
+
"""刷新 Token 状态"""
|
| 100 |
+
try:
|
| 101 |
+
mgr = await get_token_manager()
|
| 102 |
+
tokens = []
|
| 103 |
+
if isinstance(data.get("token"), str) and data["token"].strip():
|
| 104 |
+
tokens.append(data["token"].strip())
|
| 105 |
+
if isinstance(data.get("tokens"), list):
|
| 106 |
+
tokens.extend([str(t).strip() for t in data["tokens"] if str(t).strip()])
|
| 107 |
+
|
| 108 |
+
if not tokens:
|
| 109 |
+
raise HTTPException(status_code=400, detail="No tokens provided")
|
| 110 |
+
|
| 111 |
+
unique_tokens = list(dict.fromkeys(tokens))
|
| 112 |
+
|
| 113 |
+
raw_results = await UsageService.batch(
|
| 114 |
+
unique_tokens,
|
| 115 |
+
mgr,
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
results = {}
|
| 119 |
+
for token, res in raw_results.items():
|
| 120 |
+
if res.get("ok"):
|
| 121 |
+
results[token] = res.get("data", False)
|
| 122 |
+
else:
|
| 123 |
+
results[token] = False
|
| 124 |
+
|
| 125 |
+
response = {"status": "success", "results": results}
|
| 126 |
+
return response
|
| 127 |
+
except Exception as e:
|
| 128 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
@router.post("/tokens/refresh/async", dependencies=[Depends(verify_app_key)])
|
| 132 |
+
async def refresh_tokens_async(data: dict):
|
| 133 |
+
"""刷新 Token 状态(异步批量 + SSE 进度)"""
|
| 134 |
+
mgr = await get_token_manager()
|
| 135 |
+
tokens = []
|
| 136 |
+
if isinstance(data.get("token"), str) and data["token"].strip():
|
| 137 |
+
tokens.append(data["token"].strip())
|
| 138 |
+
if isinstance(data.get("tokens"), list):
|
| 139 |
+
tokens.extend([str(t).strip() for t in data["tokens"] if str(t).strip()])
|
| 140 |
+
|
| 141 |
+
if not tokens:
|
| 142 |
+
raise HTTPException(status_code=400, detail="No tokens provided")
|
| 143 |
+
|
| 144 |
+
unique_tokens = list(dict.fromkeys(tokens))
|
| 145 |
+
|
| 146 |
+
task = create_task(len(unique_tokens))
|
| 147 |
+
|
| 148 |
+
async def _run():
|
| 149 |
+
try:
|
| 150 |
+
|
| 151 |
+
async def _on_item(item: str, res: dict):
|
| 152 |
+
task.record(bool(res.get("ok")))
|
| 153 |
+
|
| 154 |
+
raw_results = await UsageService.batch(
|
| 155 |
+
unique_tokens,
|
| 156 |
+
mgr,
|
| 157 |
+
on_item=_on_item,
|
| 158 |
+
should_cancel=lambda: task.cancelled,
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
if task.cancelled:
|
| 162 |
+
task.finish_cancelled()
|
| 163 |
+
return
|
| 164 |
+
|
| 165 |
+
results: dict[str, bool] = {}
|
| 166 |
+
ok_count = 0
|
| 167 |
+
fail_count = 0
|
| 168 |
+
for token, res in raw_results.items():
|
| 169 |
+
if res.get("ok") and res.get("data") is True:
|
| 170 |
+
ok_count += 1
|
| 171 |
+
results[token] = True
|
| 172 |
+
else:
|
| 173 |
+
fail_count += 1
|
| 174 |
+
results[token] = False
|
| 175 |
+
|
| 176 |
+
await mgr._save(force=True)
|
| 177 |
+
|
| 178 |
+
result = {
|
| 179 |
+
"status": "success",
|
| 180 |
+
"summary": {
|
| 181 |
+
"total": len(unique_tokens),
|
| 182 |
+
"ok": ok_count,
|
| 183 |
+
"fail": fail_count,
|
| 184 |
+
},
|
| 185 |
+
"results": results,
|
| 186 |
+
}
|
| 187 |
+
task.finish(result)
|
| 188 |
+
except Exception as e:
|
| 189 |
+
task.fail_task(str(e))
|
| 190 |
+
finally:
|
| 191 |
+
import asyncio
|
| 192 |
+
asyncio.create_task(expire_task(task.id, 300))
|
| 193 |
+
|
| 194 |
+
import asyncio
|
| 195 |
+
asyncio.create_task(_run())
|
| 196 |
+
|
| 197 |
+
return {
|
| 198 |
+
"status": "success",
|
| 199 |
+
"task_id": task.id,
|
| 200 |
+
"total": len(unique_tokens),
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
@router.get("/batch/{task_id}/stream")
|
| 205 |
+
async def batch_stream(task_id: str, request: Request):
|
| 206 |
+
app_key = get_app_key()
|
| 207 |
+
if app_key:
|
| 208 |
+
key = request.query_params.get("app_key")
|
| 209 |
+
if key != app_key:
|
| 210 |
+
raise HTTPException(status_code=401, detail="Invalid authentication token")
|
| 211 |
+
task = get_task(task_id)
|
| 212 |
+
if not task:
|
| 213 |
+
raise HTTPException(status_code=404, detail="Task not found")
|
| 214 |
+
|
| 215 |
+
async def event_stream():
|
| 216 |
+
queue = task.attach()
|
| 217 |
+
try:
|
| 218 |
+
yield f"data: {orjson.dumps({'type': 'snapshot', **task.snapshot()}).decode()}\n\n"
|
| 219 |
+
|
| 220 |
+
final = task.final_event()
|
| 221 |
+
if final:
|
| 222 |
+
yield f"data: {orjson.dumps(final).decode()}\n\n"
|
| 223 |
+
return
|
| 224 |
+
|
| 225 |
+
while True:
|
| 226 |
+
try:
|
| 227 |
+
event = await asyncio.wait_for(queue.get(), timeout=15)
|
| 228 |
+
except asyncio.TimeoutError:
|
| 229 |
+
yield ": ping\n\n"
|
| 230 |
+
final = task.final_event()
|
| 231 |
+
if final:
|
| 232 |
+
yield f"data: {orjson.dumps(final).decode()}\n\n"
|
| 233 |
+
return
|
| 234 |
+
continue
|
| 235 |
+
|
| 236 |
+
yield f"data: {orjson.dumps(event).decode()}\n\n"
|
| 237 |
+
if event.get("type") in ("done", "error", "cancelled"):
|
| 238 |
+
return
|
| 239 |
+
finally:
|
| 240 |
+
task.detach(queue)
|
| 241 |
+
|
| 242 |
+
return StreamingResponse(event_stream(), media_type="text/event-stream")
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
@router.post("/batch/{task_id}/cancel", dependencies=[Depends(verify_app_key)])
|
| 246 |
+
async def batch_cancel(task_id: str):
|
| 247 |
+
task = get_task(task_id)
|
| 248 |
+
if not task:
|
| 249 |
+
raise HTTPException(status_code=404, detail="Task not found")
|
| 250 |
+
task.cancel()
|
| 251 |
+
return {"status": "success"}
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
@router.post("/tokens/nsfw/enable", dependencies=[Depends(verify_app_key)])
|
| 255 |
+
async def enable_nsfw(data: dict):
|
| 256 |
+
"""批量开启 NSFW (Unhinged) 模式"""
|
| 257 |
+
try:
|
| 258 |
+
mgr = await get_token_manager()
|
| 259 |
+
|
| 260 |
+
tokens = []
|
| 261 |
+
if isinstance(data.get("token"), str) and data["token"].strip():
|
| 262 |
+
tokens.append(data["token"].strip())
|
| 263 |
+
if isinstance(data.get("tokens"), list):
|
| 264 |
+
tokens.extend([str(t).strip() for t in data["tokens"] if str(t).strip()])
|
| 265 |
+
|
| 266 |
+
if not tokens:
|
| 267 |
+
for pool_name, pool in mgr.pools.items():
|
| 268 |
+
for info in pool.list():
|
| 269 |
+
raw = (
|
| 270 |
+
info.token[4:] if info.token.startswith("sso=") else info.token
|
| 271 |
+
)
|
| 272 |
+
tokens.append(raw)
|
| 273 |
+
|
| 274 |
+
if not tokens:
|
| 275 |
+
raise HTTPException(status_code=400, detail="No tokens available")
|
| 276 |
+
|
| 277 |
+
unique_tokens = list(dict.fromkeys(tokens))
|
| 278 |
+
|
| 279 |
+
raw_results = await NSFWService.batch(
|
| 280 |
+
unique_tokens,
|
| 281 |
+
mgr,
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
results = {}
|
| 285 |
+
ok_count = 0
|
| 286 |
+
fail_count = 0
|
| 287 |
+
|
| 288 |
+
for token, res in raw_results.items():
|
| 289 |
+
masked = f"{token[:8]}...{token[-8:]}" if len(token) > 20 else token
|
| 290 |
+
if res.get("ok") and res.get("data", {}).get("success"):
|
| 291 |
+
ok_count += 1
|
| 292 |
+
results[masked] = res.get("data", {})
|
| 293 |
+
else:
|
| 294 |
+
fail_count += 1
|
| 295 |
+
results[masked] = res.get("data") or {"error": res.get("error")}
|
| 296 |
+
|
| 297 |
+
response = {
|
| 298 |
+
"status": "success",
|
| 299 |
+
"summary": {
|
| 300 |
+
"total": len(unique_tokens),
|
| 301 |
+
"ok": ok_count,
|
| 302 |
+
"fail": fail_count,
|
| 303 |
+
},
|
| 304 |
+
"results": results,
|
| 305 |
+
}
|
| 306 |
+
|
| 307 |
+
return response
|
| 308 |
+
|
| 309 |
+
except HTTPException:
|
| 310 |
+
raise
|
| 311 |
+
except Exception as e:
|
| 312 |
+
logger.error(f"Enable NSFW failed: {e}")
|
| 313 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
@router.post("/tokens/nsfw/enable/async", dependencies=[Depends(verify_app_key)])
|
| 317 |
+
async def enable_nsfw_async(data: dict):
|
| 318 |
+
"""批量开启 NSFW (Unhinged) 模式(异步批量 + SSE 进度)"""
|
| 319 |
+
mgr = await get_token_manager()
|
| 320 |
+
|
| 321 |
+
tokens = []
|
| 322 |
+
if isinstance(data.get("token"), str) and data["token"].strip():
|
| 323 |
+
tokens.append(data["token"].strip())
|
| 324 |
+
if isinstance(data.get("tokens"), list):
|
| 325 |
+
tokens.extend([str(t).strip() for t in data["tokens"] if str(t).strip()])
|
| 326 |
+
|
| 327 |
+
if not tokens:
|
| 328 |
+
for pool_name, pool in mgr.pools.items():
|
| 329 |
+
for info in pool.list():
|
| 330 |
+
raw = info.token[4:] if info.token.startswith("sso=") else info.token
|
| 331 |
+
tokens.append(raw)
|
| 332 |
+
|
| 333 |
+
if not tokens:
|
| 334 |
+
raise HTTPException(status_code=400, detail="No tokens available")
|
| 335 |
+
|
| 336 |
+
unique_tokens = list(dict.fromkeys(tokens))
|
| 337 |
+
|
| 338 |
+
task = create_task(len(unique_tokens))
|
| 339 |
+
|
| 340 |
+
async def _run():
|
| 341 |
+
try:
|
| 342 |
+
|
| 343 |
+
async def _on_item(item: str, res: dict):
|
| 344 |
+
ok = bool(res.get("ok") and res.get("data", {}).get("success"))
|
| 345 |
+
task.record(ok)
|
| 346 |
+
|
| 347 |
+
raw_results = await NSFWService.batch(
|
| 348 |
+
unique_tokens,
|
| 349 |
+
mgr,
|
| 350 |
+
on_item=_on_item,
|
| 351 |
+
should_cancel=lambda: task.cancelled,
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
if task.cancelled:
|
| 355 |
+
task.finish_cancelled()
|
| 356 |
+
return
|
| 357 |
+
|
| 358 |
+
results = {}
|
| 359 |
+
ok_count = 0
|
| 360 |
+
fail_count = 0
|
| 361 |
+
for token, res in raw_results.items():
|
| 362 |
+
masked = f"{token[:8]}...{token[-8:]}" if len(token) > 20 else token
|
| 363 |
+
if res.get("ok") and res.get("data", {}).get("success"):
|
| 364 |
+
ok_count += 1
|
| 365 |
+
results[masked] = res.get("data", {})
|
| 366 |
+
else:
|
| 367 |
+
fail_count += 1
|
| 368 |
+
results[masked] = res.get("data") or {"error": res.get("error")}
|
| 369 |
+
|
| 370 |
+
await mgr._save(force=True)
|
| 371 |
+
|
| 372 |
+
result = {
|
| 373 |
+
"status": "success",
|
| 374 |
+
"summary": {
|
| 375 |
+
"total": len(unique_tokens),
|
| 376 |
+
"ok": ok_count,
|
| 377 |
+
"fail": fail_count,
|
| 378 |
+
},
|
| 379 |
+
"results": results,
|
| 380 |
+
}
|
| 381 |
+
task.finish(result)
|
| 382 |
+
except Exception as e:
|
| 383 |
+
task.fail_task(str(e))
|
| 384 |
+
finally:
|
| 385 |
+
import asyncio
|
| 386 |
+
asyncio.create_task(expire_task(task.id, 300))
|
| 387 |
+
|
| 388 |
+
import asyncio
|
| 389 |
+
asyncio.create_task(_run())
|
| 390 |
+
|
| 391 |
+
return {
|
| 392 |
+
"status": "success",
|
| 393 |
+
"task_id": task.id,
|
| 394 |
+
"total": len(unique_tokens),
|
| 395 |
+
}
|
app/api/v1/chat.py
ADDED
|
@@ -0,0 +1,862 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Chat Completions API 路由
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from typing import Any, AsyncGenerator, AsyncIterable, Dict, List, Optional, Union
|
| 6 |
+
import base64
|
| 7 |
+
import binascii
|
| 8 |
+
import time
|
| 9 |
+
import uuid
|
| 10 |
+
|
| 11 |
+
from fastapi import APIRouter
|
| 12 |
+
from fastapi.responses import StreamingResponse, JSONResponse
|
| 13 |
+
from pydantic import BaseModel, Field
|
| 14 |
+
import orjson
|
| 15 |
+
|
| 16 |
+
from app.services.grok.services.chat import ChatService
|
| 17 |
+
from app.services.grok.services.image import ImageGenerationService
|
| 18 |
+
from app.services.grok.services.image_edit import ImageEditService
|
| 19 |
+
from app.services.grok.services.model import ModelService
|
| 20 |
+
from app.services.grok.services.video import VideoService
|
| 21 |
+
from app.services.grok.utils.response import make_chat_response
|
| 22 |
+
from app.services.token import get_token_manager
|
| 23 |
+
from app.core.config import get_config
|
| 24 |
+
from app.core.exceptions import ValidationException, AppException, ErrorType
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class MessageItem(BaseModel):
    """A single chat message (OpenAI Chat Completions shape)."""

    # Message author role; validated elsewhere against VALID_ROLES.
    role: str
    # Plain text, a single content block, or a list of content blocks.
    content: Optional[Union[str, Dict[str, Any], List[Dict[str, Any]]]]
    # Tool calls emitted by an assistant message, if any.
    tool_calls: Optional[List[Dict[str, Any]]] = None
    # For role="tool": id of the tool call this message responds to.
    tool_call_id: Optional[str] = None
    # Optional participant name.
    name: Optional[str] = None
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class VideoConfig(BaseModel):
    """Video generation parameters (aspect ratio, length, resolution, preset)."""

    aspect_ratio: Optional[str] = Field("3:2", description="视频比例: 1280x720(16:9), 720x1280(9:16), 1792x1024(3:2), 1024x1792(2:3), 1024x1024(1:1)")
    video_length: Optional[int] = Field(6, description="视频时长(秒): 6 / 10 / 15")
    resolution_name: Optional[str] = Field("480p", description="视频分辨率: 480p, 720p")
    preset: Optional[str] = Field("custom", description="风格预设: fun, normal, spicy")
|
| 45 |
+
|
| 46 |
+
class ImageConfig(BaseModel):
    """Image generation parameters (count, size, response format)."""

    n: Optional[int] = Field(1, ge=1, le=10, description="生成数量 (1-10)")
    size: Optional[str] = Field("1024x1024", description="图片尺寸")
    # None means "use the server default"; validated in _validate_image_config.
    response_format: Optional[str] = Field(None, description="响应格式")
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class ChatCompletionRequest(BaseModel):
    """Chat Completions request body (OpenAI-compatible, with extensions)."""

    model: str = Field(..., description="模型名称")
    messages: List[MessageItem] = Field(..., description="消息数组")
    stream: Optional[bool] = Field(None, description="是否流式输出")
    reasoning_effort: Optional[str] = Field(None, description="推理强度: none/minimal/low/medium/high/xhigh")
    temperature: Optional[float] = Field(0.8, description="采样温度: 0-2")
    top_p: Optional[float] = Field(0.95, description="nucleus 采样: 0-1")
    # Video generation settings (extension; used by video-capable models).
    video_config: Optional[VideoConfig] = Field(None, description="视频生成参数")
    # Image generation settings (extension; used by image-capable models).
    image_config: Optional[ImageConfig] = Field(None, description="图片生成参数")
    # Tool calling
    tools: Optional[List[Dict[str, Any]]] = Field(None, description="Tool definitions")
    tool_choice: Optional[Union[str, Dict[str, Any]]] = Field(None, description="Tool choice: auto/required/none/specific")
    parallel_tool_calls: Optional[bool] = Field(True, description="Allow parallel tool calls")
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
# Roles accepted in incoming chat messages.
VALID_ROLES = {"developer", "system", "user", "assistant", "tool"}
# Content block types allowed inside user messages.
USER_CONTENT_TYPES = {"text", "image_url", "input_audio", "file"}
# Image sizes (WxH) accepted by the image generation config.
ALLOWED_IMAGE_SIZES = {
    "1280x720",
    "720x1280",
    "1792x1024",
    "1024x1792",
    "1024x1024",
}
# Model id that routes to the server-configured "imagine fast" image path.
IMAGINE_FAST_MODEL_ID = "grok-imagine-1.0-fast"
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def _validate_media_input(value: str, field_name: str, param: str):
    """Ensure *value* is a usable media reference: a data URI or http(s) URL.

    Raises ValidationException for empty input, for bare base64 payloads
    (which must be wrapped in a data URI), and for anything else.
    """
    if not isinstance(value, str) or not value.strip():
        raise ValidationException(
            message=f"{field_name} cannot be empty",
            param=param,
            code="empty_media",
        )
    trimmed = value.strip()
    if trimmed.startswith("data:") or trimmed.startswith(("http://", "https://")):
        return
    # Heuristic: detect a bare base64 blob so the caller gets a targeted hint
    # rather than the generic error below.
    compact = "".join(trimmed.split())
    if len(compact) >= 32 and len(compact) % 4 == 0:
        try:
            base64.b64decode(compact, validate=True)
        except binascii.Error:
            pass
        else:
            raise ValidationException(
                message=f"{field_name} base64 must be provided as a data URI (data:<mime>;base64,...)",
                param=param,
                code="invalid_media",
            )
    raise ValidationException(
        message=f"{field_name} must be a URL or data URI",
        param=param,
        code="invalid_media",
    )
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def _extract_prompt_images(messages: List[MessageItem]) -> tuple[str, List[str]]:
    """Walk the messages, keeping the last non-empty text as the prompt and
    collecting image URLs from user messages in order of appearance."""
    prompt = ""
    images: List[str] = []

    for message in messages:
        role = message.role or "user"
        body = message.content

        if isinstance(body, str):
            stripped = body.strip()
            if stripped:
                prompt = stripped
            continue

        # Normalize a single content block into a one-element list.
        blocks = [body] if isinstance(body, dict) else body
        if not isinstance(blocks, list):
            continue

        for block in blocks:
            if not isinstance(block, dict):
                continue
            kind = block.get("type")
            if kind == "text":
                text = block.get("text", "")
                if isinstance(text, str) and text.strip():
                    prompt = text.strip()
            elif kind == "image_url" and role == "user":
                url = (block.get("image_url") or {}).get("url", "")
                if isinstance(url, str) and url.strip():
                    images.append(url.strip())

    return prompt, images
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def _resolve_image_format(value: Optional[str]) -> str:
    """Normalize an image response format, falling back to app.image_format.

    "base64" is an alias for "b64_json"; anything outside
    url/base64/b64_json raises a ValidationException.
    """
    chosen = value or get_config("app.image_format") or "url"
    if isinstance(chosen, str):
        chosen = chosen.lower()
    if chosen == "base64":
        return "b64_json"
    if chosen in ("b64_json", "url"):
        return chosen
    raise ValidationException(
        message="image_format must be one of url, base64, b64_json",
        param="image_format",
        code="invalid_image_format",
    )
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def _image_field(response_format: str) -> str:
|
| 166 |
+
if response_format == "url":
|
| 167 |
+
return "url"
|
| 168 |
+
return "b64_json"
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def _imagine_fast_server_image_config() -> ImageConfig:
    """Build the server-side ImageConfig used for grok-imagine-1.0-fast.

    Values come from the imagine_fast.* config keys; the response format
    falls back to app.image_format and finally to "url".
    """
    count = int(get_config("imagine_fast.n", 1) or 1)
    dimensions = str(get_config("imagine_fast.size", "1024x1024") or "1024x1024")
    fallback_format = get_config("app.image_format") or "url"
    fmt = str(get_config("imagine_fast.response_format", fallback_format) or "url")
    return ImageConfig(n=count, size=dimensions, response_format=fmt)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
async def _safe_sse_stream(stream: AsyncIterable[str]) -> AsyncGenerator[str, None]:
    """Relay an SSE stream, converting mid-stream failures into SSE error
    frames followed by a [DONE] sentinel instead of transport-level 5xx breaks.
    """

    def _error_frames(message: str, error_type: str, code: str) -> tuple[str, str]:
        # Build the error frame plus the closing [DONE] frame.
        # (Extracted: this construction was duplicated in both handlers.)
        payload = {"error": {"message": message, "type": error_type, "code": code}}
        return (
            f"event: error\ndata: {orjson.dumps(payload).decode()}\n\n",
            "data: [DONE]\n\n",
        )

    try:
        async for chunk in stream:
            yield chunk
    except AppException as e:
        # Application errors carry structured type/code information.
        for frame in _error_frames(e.message, e.error_type, e.code):
            yield frame
    except Exception as e:
        for frame in _error_frames(str(e) or "stream_error", "server_error", "stream_error"):
            yield frame
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def _streaming_error_response(exc: Exception) -> StreamingResponse:
    """Wrap an exception as a one-shot SSE response: one error frame, then [DONE]."""
    if isinstance(exc, AppException):
        message, error_type, code = exc.message, exc.error_type, exc.code
    else:
        message = str(exc) or "stream_error"
        error_type = "server_error"
        code = "stream_error"
    payload = {"error": {"message": message, "type": error_type, "code": code}}

    async def _one_shot_error() -> AsyncGenerator[str, None]:
        yield f"event: error\ndata: {orjson.dumps(payload).decode()}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(
        _one_shot_error(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
    )
|
| 236 |
+
|
| 237 |
+
def _validate_image_config(image_conf: ImageConfig, *, stream: bool):
    """Validate image generation parameters.

    Raises ValidationException for an out-of-range n, an n unsupported by
    streaming, an unknown response_format, or an unsupported size.
    """
    count = image_conf.n or 1
    if not 1 <= count <= 10:
        raise ValidationException(
            message="n must be between 1 and 10",
            param="image_config.n",
            code="invalid_n",
        )
    if stream and count not in (1, 2):
        raise ValidationException(
            message="Streaming is only supported when n=1 or n=2",
            param="image_config.n",
            code="invalid_stream_n",
        )
    if image_conf.response_format and image_conf.response_format not in {
        "b64_json",
        "base64",
        "url",
    }:
        raise ValidationException(
            message="response_format must be one of b64_json, base64, url",
            param="image_config.response_format",
            code="invalid_response_format",
        )
    if image_conf.size and image_conf.size not in ALLOWED_IMAGE_SIZES:
        raise ValidationException(
            message=f"size must be one of {sorted(ALLOWED_IMAGE_SIZES)}",
            param="image_config.size",
            code="invalid_size",
        )
|
def validate_request(request: ChatCompletionRequest):
    """Validate and normalize a chat-completions request in place.

    Checks the model, every message (role, tool fields, content shape),
    then the generic parameters (stream, reasoning_effort, temperature,
    top_p, tools, tool_choice), and finally the model-specific image /
    image-edit / video configuration.  Defaults are written back onto
    ``request`` (temperature, top_p, stream coercion, image/video config).

    Raises:
        ValidationException: on the first invalid field encountered.
    """
    # Validate the model name first.
    if not ModelService.valid(request.model):
        raise ValidationException(
            message=f"The model `{request.model}` does not exist or you do not have access to it.",
            param="model",
            code="model_not_found",
        )

    # Validate every message.
    for idx, msg in enumerate(request.messages):
        if not isinstance(msg.role, str) or msg.role not in VALID_ROLES:
            raise ValidationException(
                message=f"role must be one of {sorted(VALID_ROLES)}",
                param=f"messages.{idx}.role",
                code="invalid_role",
            )

        # tool role: requires tool_call_id, content can be None/empty
        if msg.role == "tool":
            if not msg.tool_call_id:
                raise ValidationException(
                    message="tool messages must have a 'tool_call_id' field",
                    param=f"messages.{idx}.tool_call_id",
                    code="missing_tool_call_id",
                )
            continue

        # assistant with tool_calls: content can be None
        if msg.role == "assistant" and msg.tool_calls:
            continue

        content = msg.content

        # Compatibility: some clients send empty assistant/tool content
        # (e.g. the intermediate state of a tool call).
        if content is None:
            if msg.role in {"assistant", "tool"}:
                continue
            raise ValidationException(
                message="Message content cannot be null",
                param=f"messages.{idx}.content",
                code="empty_content",
            )

        # String content: must be non-blank.
        if isinstance(content, str):
            if not content.strip():
                raise ValidationException(
                    message="Message content cannot be empty",
                    param=f"messages.{idx}.content",
                    code="empty_content",
                )

        # Single object content: wrapped in a list and validated as
        # text-only (only {"type": "text", "text": ...} is accepted here).
        elif isinstance(content, dict):
            content = [content]
            for c_idx, item in enumerate(content):
                if not isinstance(item, dict):
                    raise ValidationException(
                        message="Message content items must be objects",
                        param=f"messages.{idx}.content.{c_idx}",
                        code="invalid_content_item",
                    )
                item_type = item.get("type")
                if item_type != "text":
                    raise ValidationException(
                        message="When content is an object, type must be 'text'",
                        param=f"messages.{idx}.content.{c_idx}.type",
                        code="invalid_content_type",
                    )
                text = item.get("text", "")
                if not isinstance(text, str) or not text.strip():
                    raise ValidationException(
                        message="messages.%d.content.%d.text must be a non-empty string"
                        % (idx, c_idx),
                        param=f"messages.{idx}.content.{c_idx}.text",
                        code="empty_content",
                    )

        # List content: validate each block individually.
        elif isinstance(content, list):
            if not content:
                raise ValidationException(
                    message="Message content cannot be an empty array",
                    param=f"messages.{idx}.content",
                    code="empty_content",
                )

            for block_idx, block in enumerate(content):
                # Reject non-object and empty blocks.
                if not isinstance(block, dict):
                    raise ValidationException(
                        message="Content block must be an object",
                        param=f"messages.{idx}.content.{block_idx}",
                        code="invalid_block",
                    )
                if not block:
                    raise ValidationException(
                        message="Content block cannot be empty",
                        param=f"messages.{idx}.content.{block_idx}",
                        code="empty_block",
                    )

                # The 'type' field must be present...
                if "type" not in block:
                    raise ValidationException(
                        message="Content block must have a 'type' field",
                        param=f"messages.{idx}.content.{block_idx}",
                        code="missing_type",
                    )

                block_type = block.get("type")

                # ...and be a non-blank string.
                if (
                    not block_type
                    or not isinstance(block_type, str)
                    or not block_type.strip()
                ):
                    raise ValidationException(
                        message="Content block 'type' cannot be empty",
                        param=f"messages.{idx}.content.{block_idx}.type",
                        code="empty_type",
                    )

                # Validate the type: user messages may carry any of the
                # USER_CONTENT_TYPES; every other role is text-only.
                if msg.role == "user":
                    if block_type not in USER_CONTENT_TYPES:
                        raise ValidationException(
                            message=f"Invalid content block type: '{block_type}'",
                            param=f"messages.{idx}.content.{block_idx}.type",
                            code="invalid_type",
                        )
                else:
                    if block_type != "text":
                        raise ValidationException(
                            message=f"The `{msg.role}` role only supports 'text' type, got '{block_type}'",
                            param=f"messages.{idx}.content.{block_idx}.type",
                            code="invalid_type",
                        )

                # Validate the payload field for each block type is present & non-empty.
                if block_type == "text":
                    text = block.get("text", "")
                    if not isinstance(text, str) or not text.strip():
                        raise ValidationException(
                            message="Text content cannot be empty",
                            param=f"messages.{idx}.content.{block_idx}.text",
                            code="empty_text",
                        )
                elif block_type == "image_url":
                    image_url = block.get("image_url")
                    if not image_url or not isinstance(image_url, dict):
                        raise ValidationException(
                            message="image_url must have a 'url' field",
                            param=f"messages.{idx}.content.{block_idx}.image_url",
                            code="missing_url",
                        )
                    _validate_media_input(
                        image_url.get("url", ""),
                        "image_url.url",
                        f"messages.{idx}.content.{block_idx}.image_url.url",
                    )
                elif block_type == "input_audio":
                    audio = block.get("input_audio")
                    if not audio or not isinstance(audio, dict):
                        raise ValidationException(
                            message="input_audio must have a 'data' field",
                            param=f"messages.{idx}.content.{block_idx}.input_audio",
                            code="missing_audio",
                        )
                    _validate_media_input(
                        audio.get("data", ""),
                        "input_audio.data",
                        f"messages.{idx}.content.{block_idx}.input_audio.data",
                    )
                elif block_type == "file":
                    file_data = block.get("file")
                    if not file_data or not isinstance(file_data, dict):
                        raise ValidationException(
                            message="file must have a 'file_data' field",
                            param=f"messages.{idx}.content.{block_idx}.file",
                            code="missing_file",
                        )
                    _validate_media_input(
                        file_data.get("file_data", ""),
                        "file.file_data",
                        f"messages.{idx}.content.{block_idx}.file.file_data",
                    )
        # NOTE(review): this branch is unreachable — content is None was
        # already handled (and raised or continued) earlier in the loop.
        elif content is None:
            raise ValidationException(
                message="Message content cannot be empty",
                param=f"messages.{idx}.content",
                code="empty_content",
            )
        else:
            raise ValidationException(
                message="Message content must be a string or array",
                param=f"messages.{idx}.content",
                code="invalid_content",
            )

    # Generic parameter validation.
    # stream: accept booleans, and coerce common string spellings in place.
    if request.stream is not None:
        if isinstance(request.stream, bool):
            pass
        elif isinstance(request.stream, str):
            if request.stream.lower() in ("true", "1", "yes"):
                request.stream = True
            elif request.stream.lower() in ("false", "0", "no"):
                request.stream = False
            else:
                raise ValidationException(
                    message="stream must be a boolean",
                    param="stream",
                    code="invalid_stream",
                )
        else:
            raise ValidationException(
                message="stream must be a boolean",
                param="stream",
                code="invalid_stream",
            )

    allowed_efforts = {"none", "minimal", "low", "medium", "high", "xhigh"}
    if request.reasoning_effort is not None:
        if not isinstance(request.reasoning_effort, str) or (
            request.reasoning_effort not in allowed_efforts
        ):
            raise ValidationException(
                message=f"reasoning_effort must be one of {sorted(allowed_efforts)}",
                param="reasoning_effort",
                code="invalid_reasoning_effort",
            )

    # temperature: default 0.8, otherwise coerce to float and range-check.
    if request.temperature is None:
        request.temperature = 0.8
    else:
        try:
            request.temperature = float(request.temperature)
        except Exception:
            raise ValidationException(
                message="temperature must be a float",
                param="temperature",
                code="invalid_temperature",
            )
        if not (0 <= request.temperature <= 2):
            raise ValidationException(
                message="temperature must be between 0 and 2",
                param="temperature",
                code="invalid_temperature",
            )

    # top_p: default 0.95, otherwise coerce to float and range-check.
    if request.top_p is None:
        request.top_p = 0.95
    else:
        try:
            request.top_p = float(request.top_p)
        except Exception:
            raise ValidationException(
                message="top_p must be a float",
                param="top_p",
                code="invalid_top_p",
            )
        if not (0 <= request.top_p <= 1):
            raise ValidationException(
                message="top_p must be between 0 and 1",
                param="top_p",
                code="invalid_top_p",
            )

    # tools: must be a list of {"type": "function", "function": {"name": ...}}.
    if request.tools is not None:
        if not isinstance(request.tools, list):
            raise ValidationException(
                message="tools must be an array",
                param="tools",
                code="invalid_tools",
            )
        for t_idx, tool in enumerate(request.tools):
            if not isinstance(tool, dict) or tool.get("type") != "function":
                raise ValidationException(
                    message="Each tool must have type='function'",
                    param=f"tools.{t_idx}.type",
                    code="invalid_tool_type",
                )
            func = tool.get("function")
            if not isinstance(func, dict) or not func.get("name"):
                raise ValidationException(
                    message="Each tool function must have a 'name'",
                    param=f"tools.{t_idx}.function.name",
                    code="missing_function_name",
                )

    # tool_choice: a known keyword or a specific function object.
    if request.tool_choice is not None:
        if isinstance(request.tool_choice, str):
            if request.tool_choice not in ("auto", "required", "none"):
                raise ValidationException(
                    message="tool_choice must be 'auto', 'required', 'none', or a specific function object",
                    param="tool_choice",
                    code="invalid_tool_choice",
                )
        elif isinstance(request.tool_choice, dict):
            if request.tool_choice.get("type") != "function" or not request.tool_choice.get("function", {}).get("name"):
                raise ValidationException(
                    message="tool_choice object must have type='function' and function.name",
                    param="tool_choice",
                    code="invalid_tool_choice",
                )

    model_info = ModelService.get(request.model)
    # Image (and image-edit) model validation: normalizes image_config in place.
    if model_info and (model_info.is_image or model_info.is_image_edit):
        prompt, image_urls = _extract_prompt_images(request.messages)
        if not prompt:
            raise ValidationException(
                message="Prompt cannot be empty",
                param="messages",
                code="empty_prompt",
            )
        # The "imagine fast" model uses a server-side config; others use the client's.
        image_conf = _imagine_fast_server_image_config() if request.model == IMAGINE_FAST_MODEL_ID else (request.image_config or ImageConfig())
        n = image_conf.n or 1
        if not (1 <= n <= 10):
            raise ValidationException(
                message="n must be between 1 and 10",
                param="image_config.n",
                code="invalid_n",
            )
        if request.stream and n not in (1, 2):
            raise ValidationException(
                message="Streaming is only supported when n=1 or n=2",
                param="stream",
                code="invalid_stream_n",
            )

        # Write normalized defaults back so downstream handlers see them.
        response_format = _resolve_image_format(image_conf.response_format)
        image_conf.n = n
        image_conf.response_format = response_format
        if not image_conf.size:
            image_conf.size = "1024x1024"
        allowed_sizes = {
            "1280x720",
            "720x1280",
            "1792x1024",
            "1024x1792",
            "1024x1024",
        }
        if image_conf.size not in allowed_sizes:
            raise ValidationException(
                message=f"size must be one of {sorted(allowed_sizes)}",
                param="image_config.size",
                code="invalid_size",
            )
        request.image_config = image_conf

    # Image-edit validation: at least one input image is required.
    if model_info and model_info.is_image_edit:
        _, image_urls = _extract_prompt_images(request.messages)
        if not image_urls:
            raise ValidationException(
                message="image_url is required for image edits",
                param="messages",
                code="missing_image",
            )

    # Video model validation: normalizes video_config in place.
    if model_info and model_info.is_video:
        config = request.video_config or VideoConfig()
        # Accept both pixel sizes and ratio strings; map both to a ratio.
        ratio_map = {
            "1280x720": "16:9",
            "720x1280": "9:16",
            "1792x1024": "3:2",
            "1024x1792": "2:3",
            "1024x1024": "1:1",
            "16:9": "16:9",
            "9:16": "9:16",
            "3:2": "3:2",
            "2:3": "2:3",
            "1:1": "1:1",
        }
        if config.aspect_ratio is None:
            config.aspect_ratio = "3:2"
        if config.aspect_ratio not in ratio_map:
            raise ValidationException(
                message=f"aspect_ratio must be one of {list(ratio_map.keys())}",
                param="video_config.aspect_ratio",
                code="invalid_aspect_ratio",
            )
        config.aspect_ratio = ratio_map[config.aspect_ratio]

        if config.video_length not in (6, 10, 15):
            raise ValidationException(
                message="video_length must be 6, 10, or 15 seconds",
                param="video_config.video_length",
                code="invalid_video_length",
            )
        if config.resolution_name not in ("480p", "720p"):
            raise ValidationException(
                message="resolution_name must be one of ['480p', '720p']",
                param="video_config.resolution_name",
                code="invalid_resolution",
            )
        if config.preset not in ("fun", "normal", "spicy", "custom"):
            raise ValidationException(
                message="preset must be one of ['fun', 'normal', 'spicy', 'custom']",
                param="video_config.preset",
                code="invalid_preset",
            )
        request.video_config = config
| 675 |
+
|
| 676 |
+
|
# Router for the OpenAI-compatible chat endpoints; exported via __all__.
router = APIRouter(tags=["Chat"])
| 678 |
+
|
| 679 |
+
|
@router.post("/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """Chat Completions API - OpenAI compatible.

    Validates the request, then dispatches on the model type:
    image-edit and image models are served through the image services
    (wrapped in chat-response format), video models through
    ``VideoService.completions``, everything else through
    ``ChatService.completions``.  Returns either a JSONResponse or an
    SSE StreamingResponse depending on ``stream``.
    """
    from app.core.logger import logger

    # Parameter validation (also normalizes defaults in place).
    validate_request(request)

    logger.debug(f"Chat request: model={request.model}, stream={request.stream}")

    # Detect the model type.
    model_info = ModelService.get(request.model)
    if model_info and model_info.is_image_edit:
        prompt, image_urls = _extract_prompt_images(request.messages)
        if not image_urls:
            raise ValidationException(
                message="Image is required",
                param="image",
                code="missing_image",
            )

        # Fall back to the configured default when the client did not set stream.
        is_stream = (
            request.stream if request.stream is not None else get_config("app.stream")
        )
        image_conf = request.image_config or ImageConfig()
        _validate_image_config(image_conf, stream=bool(is_stream))
        response_format = _resolve_image_format(image_conf.response_format)
        # NOTE(review): response_field is computed but never used in this branch.
        response_field = _image_field(response_format)
        n = image_conf.n or 1

        token_mgr = await get_token_manager()
        await token_mgr.reload_if_stale()

        # Pick the first pool that can supply a token for this model.
        token = None
        for pool_name in ModelService.pool_candidates_for_model(request.model):
            token = token_mgr.get_token(pool_name)
            if token:
                break

        if not token:
            raise AppException(
                message="No available tokens. Please try again later.",
                error_type=ErrorType.RATE_LIMIT.value,
                code="rate_limit_exceeded",
                status_code=429,
            )

        result = await ImageEditService().edit(
            token_mgr=token_mgr,
            token=token,
            model_info=model_info,
            prompt=prompt,
            images=image_urls,
            n=n,
            response_format=response_format,
            stream=bool(is_stream),
            chat_format=True,
        )

        if result.stream:
            return StreamingResponse(
                _safe_sse_stream(result.data),
                media_type="text/event-stream",
                headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
            )

        content = result.data[0] if result.data else ""
        return JSONResponse(
            content=make_chat_response(request.model, content)
        )

    if model_info and model_info.is_image:
        prompt, _ = _extract_prompt_images(request.messages)

        is_stream = (
            request.stream if request.stream is not None else get_config("app.stream")
        )
        # The "imagine fast" model uses a server-side config; others use the client's.
        image_conf = _imagine_fast_server_image_config() if request.model == IMAGINE_FAST_MODEL_ID else (request.image_config or ImageConfig())
        _validate_image_config(image_conf, stream=bool(is_stream))
        response_format = _resolve_image_format(image_conf.response_format)
        # NOTE(review): response_field is computed but never used in this branch.
        response_field = _image_field(response_format)
        n = image_conf.n or 1
        size = image_conf.size or "1024x1024"
        # Map the pixel size to the aspect-ratio label used downstream.
        aspect_ratio_map = {
            "1280x720": "16:9",
            "720x1280": "9:16",
            "1792x1024": "3:2",
            "1024x1792": "2:3",
            "1024x1024": "1:1",
        }
        aspect_ratio = aspect_ratio_map.get(size, "2:3")

        token_mgr = await get_token_manager()
        await token_mgr.reload_if_stale()

        token = None
        for pool_name in ModelService.pool_candidates_for_model(request.model):
            token = token_mgr.get_token(pool_name)
            if token:
                break

        if not token:
            raise AppException(
                message="No available tokens. Please try again later.",
                error_type=ErrorType.RATE_LIMIT.value,
                code="rate_limit_exceeded",
                status_code=429,
            )

        result = await ImageGenerationService().generate(
            token_mgr=token_mgr,
            token=token,
            model_info=model_info,
            prompt=prompt,
            n=n,
            response_format=response_format,
            size=size,
            aspect_ratio=aspect_ratio,
            stream=bool(is_stream),
            chat_format=True,
        )

        if result.stream:
            return StreamingResponse(
                _safe_sse_stream(result.data),
                media_type="text/event-stream",
                headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
            )

        content = result.data[0] if result.data else ""
        usage = result.usage_override
        return JSONResponse(
            content=make_chat_response(request.model, content, usage=usage)
        )

    if model_info and model_info.is_video:
        # Extract the video config (defaults handled by the Pydantic model).
        v_conf = request.video_config or VideoConfig()

        try:
            result = await VideoService.completions(
                model=request.model,
                messages=[msg.model_dump() for msg in request.messages],
                stream=request.stream,
                reasoning_effort=request.reasoning_effort,
                aspect_ratio=v_conf.aspect_ratio,
                video_length=v_conf.video_length,
                resolution=v_conf.resolution_name,
                preset=v_conf.preset,
            )
        except Exception as e:
            # Unless the client explicitly disabled streaming, report the
            # failure as a one-shot SSE error stream instead of raising.
            if request.stream is not False:
                return _streaming_error_response(e)
            raise
    else:
        try:
            result = await ChatService.completions(
                model=request.model,
                messages=[msg.model_dump() for msg in request.messages],
                stream=request.stream,
                reasoning_effort=request.reasoning_effort,
                temperature=request.temperature,
                top_p=request.top_p,
                tools=request.tools,
                tool_choice=request.tool_choice,
                parallel_tool_calls=request.parallel_tool_calls,
            )
        except Exception as e:
            if request.stream is not False:
                return _streaming_error_response(e)
            raise

    # A dict result means a complete (non-streaming) response; anything
    # else is an async generator of SSE chunks.
    if isinstance(result, dict):
        return JSONResponse(content=result)
    else:
        return StreamingResponse(
            _safe_sse_stream(result),
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
        )
| 860 |
+
|
| 861 |
+
|
# Public module surface: only the router is exported.
__all__ = ["router"]
app/api/v1/files.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
文件服务 API 路由
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import aiofiles.os
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from fastapi import APIRouter, HTTPException
|
| 8 |
+
from fastapi.responses import FileResponse
|
| 9 |
+
|
| 10 |
+
from app.core.logger import logger
|
| 11 |
+
from app.core.storage import DATA_DIR
|
| 12 |
+
|
router = APIRouter(tags=["Files"])

# Cache root directory for locally stored media files.
BASE_DIR = DATA_DIR / "tmp"
IMAGE_DIR = BASE_DIR / "image"  # image files served by get_image
VIDEO_DIR = BASE_DIR / "video"  # video files served by get_video
| 19 |
+
|
| 20 |
+
|
@router.get("/image/{filename:path}")
async def get_image(filename: str):
    """Serve a cached image file from the local image cache directory.

    Args:
        filename: Requested file name; any "/" is flattened to "-", so the
            lookup stays inside IMAGE_DIR.

    Returns:
        FileResponse with a long-lived immutable cache header so browsers
        and CDNs can cache aggressively under high concurrency.

    Raises:
        HTTPException: 404 when the file does not exist or is not a
            regular file.
    """
    # Flatten path separators so "a/b.png" maps to the stored "a-b.png" name.
    if "/" in filename:
        filename = filename.replace("/", "-")

    file_path = IMAGE_DIR / filename

    # isfile() is False for missing paths, so one check covers both the
    # original exists() and isfile() tests.
    if await aiofiles.os.path.isfile(file_path):
        # Content type from the extension; default to JPEG.
        content_type = {
            ".png": "image/png",
            ".webp": "image/webp",
        }.get(file_path.suffix.lower(), "image/jpeg")

        return FileResponse(
            file_path,
            media_type=content_type,
            headers={"Cache-Control": "public, max-age=31536000, immutable"},
        )

    # Fix: log the actual resolved path (the old message logged a constant
    # "(unknown)" placeholder, making missing-file reports useless).
    logger.warning(f"Image not found: {file_path}")
    raise HTTPException(status_code=404, detail="Image not found")
| 48 |
+
|
| 49 |
+
|
@router.get("/video/{filename:path}")
async def get_video(filename: str):
    """Serve a cached video file from the local video cache directory.

    Args:
        filename: Requested file name; any "/" is flattened to "-", so the
            lookup stays inside VIDEO_DIR.

    Returns:
        FileResponse (video/mp4) with a long-lived immutable cache header.

    Raises:
        HTTPException: 404 when the file does not exist or is not a
            regular file.
    """
    # Flatten path separators so nested names map to the stored flat names.
    if "/" in filename:
        filename = filename.replace("/", "-")

    file_path = VIDEO_DIR / filename

    # isfile() is False for missing paths, so one check covers both the
    # original exists() and isfile() tests.
    if await aiofiles.os.path.isfile(file_path):
        return FileResponse(
            file_path,
            media_type="video/mp4",
            headers={"Cache-Control": "public, max-age=31536000, immutable"},
        )

    # Fix: log the actual resolved path (the old message logged a constant
    # "(unknown)" placeholder, making missing-file reports useless).
    logger.warning(f"Video not found: {file_path}")
    raise HTTPException(status_code=404, detail="Video not found")
app/api/v1/image.py
ADDED
|
@@ -0,0 +1,452 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Image Generation API 路由
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import base64
|
| 6 |
+
import time
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import List, Optional, Union
|
| 9 |
+
|
| 10 |
+
from fastapi import APIRouter, File, Form, UploadFile
|
| 11 |
+
from fastapi.responses import StreamingResponse, JSONResponse
|
| 12 |
+
from pydantic import BaseModel, Field, ValidationError
|
| 13 |
+
|
| 14 |
+
from app.services.grok.services.image import ImageGenerationService
|
| 15 |
+
from app.services.grok.services.image_edit import ImageEditService
|
| 16 |
+
from app.services.grok.services.model import ModelService
|
| 17 |
+
from app.services.token import get_token_manager
|
| 18 |
+
from app.core.exceptions import ValidationException, AppException, ErrorType
|
| 19 |
+
from app.core.config import get_config
|
| 20 |
+
|
| 21 |
+
|
router = APIRouter(tags=["Images"])

# Pixel sizes accepted by the image endpoints.
ALLOWED_IMAGE_SIZES = {
    "1280x720",
    "720x1280",
    "1792x1024",
    "1024x1792",
    "1024x1024",
}

# Aspect-ratio label corresponding to each allowed pixel size.
SIZE_TO_ASPECT = {
    "1280x720": "16:9",
    "720x1280": "9:16",
    "1792x1024": "3:2",
    "1024x1792": "2:3",
    "1024x1024": "1:1",
}
# Aspect-ratio strings accepted directly (the value set of SIZE_TO_ASPECT).
ALLOWED_ASPECT_RATIOS = {"1:1", "2:3", "3:2", "9:16", "16:9"}
| 40 |
+
|
| 41 |
+
|
class ImageGenerationRequest(BaseModel):
    """Image generation request - OpenAI compatible."""

    # Text description of the image to generate (required).
    prompt: str = Field(..., description="图片描述")
    # Model name; defaults to grok-imagine-1.0.
    model: Optional[str] = Field("grok-imagine-1.0", description="模型名称")
    # Number of images to generate (1-10).
    n: Optional[int] = Field(1, ge=1, le=10, description="生成数量 (1-10)")
    # Pixel size; must be one of the values listed in the description.
    size: Optional[str] = Field(
        "1024x1024",
        description="图片尺寸: 1280x720, 720x1280, 1792x1024, 1024x1792, 1024x1024",
    )
    # Accepted for OpenAI compatibility; not currently supported.
    quality: Optional[str] = Field("standard", description="图片质量 (暂不支持)")
    # Response encoding (e.g. b64_json / url); validated elsewhere.
    response_format: Optional[str] = Field(None, description="响应格式")
    # Accepted for OpenAI compatibility; not currently supported.
    style: Optional[str] = Field(None, description="风格 (暂不支持)")
    # Whether to stream the result.
    stream: Optional[bool] = Field(False, description="是否流式输出")
| 56 |
+
|
| 57 |
+
|
class ImageEditRequest(BaseModel):
    """Image edit request - OpenAI compatible."""

    # Text description of the edit to apply (required).
    prompt: str = Field(..., description="编辑描述")
    # Model name; defaults to grok-imagine-1.0-edit.
    model: Optional[str] = Field("grok-imagine-1.0-edit", description="模型名称")
    # Image(s) to edit; a single value or a list.
    image: Optional[Union[str, List[str]]] = Field(None, description="待编辑图片文件")
    # Number of images to generate (1-10).
    n: Optional[int] = Field(1, ge=1, le=10, description="生成数量 (1-10)")
    # Pixel size; must be one of the values listed in the description.
    size: Optional[str] = Field(
        "1024x1024",
        description="图片尺寸: 1280x720, 720x1280, 1792x1024, 1024x1792, 1024x1024",
    )
    # Accepted for OpenAI compatibility; not currently supported.
    quality: Optional[str] = Field("standard", description="图片质量 (暂不支持)")
    # Response encoding (e.g. b64_json / url); validated elsewhere.
    response_format: Optional[str] = Field(None, description="响应格式")
    # Accepted for OpenAI compatibility; not currently supported.
    style: Optional[str] = Field(None, description="风格 (暂不支持)")
    # Whether to stream the result.
    stream: Optional[bool] = Field(False, description="是否流式输出")
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def _validate_common_request(
    request: Union[ImageGenerationRequest, ImageEditRequest],
    *,
    allow_ws_stream: bool = False,
):
    """Validate the fields shared by generation and edit requests.

    Args:
        request: The incoming request model.
        allow_ws_stream: When True, also apply the streaming-specific
            response_format restriction.

    Raises:
        ValidationException: If any field fails validation.
    """
    # prompt must be non-empty after stripping whitespace
    if not request.prompt or not request.prompt.strip():
        raise ValidationException(
            message="Prompt cannot be empty", param="prompt", code="empty_prompt"
        )

    # n is declared Optional[int]; an explicit null previously crashed the
    # `request.n < 1` comparison with a TypeError (HTTP 500). Reject None as
    # an invalid count instead so the client gets a proper 4xx error.
    if request.n is None or request.n < 1 or request.n > 10:
        raise ValidationException(
            message="n must be between 1 and 10", param="n", code="invalid_n"
        )

    # streaming is limited to small batches
    if request.stream and request.n not in (1, 2):
        raise ValidationException(
            message="Streaming is only supported when n=1 or n=2",
            param="stream",
            code="invalid_stream_n",
        )

    allowed_formats = {"b64_json", "base64", "url"}

    # Streaming-specific format restriction (currently the same set as the
    # general one, but kept separate for the dedicated error message).
    if allow_ws_stream and request.stream and request.response_format:
        if request.response_format not in allowed_formats:
            raise ValidationException(
                message="Streaming only supports response_format=b64_json/base64/url",
                param="response_format",
                code="invalid_response_format",
            )

    if request.response_format and request.response_format not in allowed_formats:
        raise ValidationException(
            message=f"response_format must be one of {sorted(allowed_formats)}",
            param="response_format",
            code="invalid_response_format",
        )

    if request.size and request.size not in ALLOWED_IMAGE_SIZES:
        raise ValidationException(
            message=f"size must be one of {sorted(ALLOWED_IMAGE_SIZES)}",
            param="size",
            code="invalid_size",
        )
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def validate_generation_request(request: ImageGenerationRequest):
    """Validate the parameters of an image-generation request."""
    if request.model != "grok-imagine-1.0":
        raise ValidationException(
            message="The model `grok-imagine-1.0` is required for image generation.",
            param="model",
            code="model_not_supported",
        )
    # Confirm the registry flags the model as an image model.
    info = ModelService.get(request.model)
    if info is None or not info.is_image:
        # Collect the image-capable models for the error message.
        image_models = [m.model_id for m in ModelService.MODELS if m.is_image]
        raise ValidationException(
            message=(
                f"The model `{request.model}` is not supported for image generation. "
                f"Supported: {image_models}"
            ),
            param="model",
            code="model_not_supported",
        )
    _validate_common_request(request, allow_ws_stream=True)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def resolve_response_format(response_format: Optional[str]) -> str:
    """Normalize a response format, falling back to the configured default."""
    candidate = response_format if response_format else get_config("app.image_format")
    if isinstance(candidate, str):
        normalized = candidate.lower()
        if normalized in {"b64_json", "base64", "url"}:
            return normalized
    raise ValidationException(
        message="response_format must be one of b64_json, base64, url",
        param="response_format",
        code="invalid_response_format",
    )
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def response_field_name(response_format: str) -> str:
    """Return the JSON field name used to carry the image for this format."""
    if response_format == "url":
        return "url"
    if response_format == "base64":
        return "base64"
    # Everything else (including b64_json itself) uses the b64_json field.
    return "b64_json"
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def resolve_aspect_ratio(size: str) -> str:
    """Map an OpenAI "WxH" size (or a raw "W:H" ratio) to a Grok aspect ratio.

    Falls back to "2:3" whenever the input is empty or unrecognized.
    """
    text = (size or "").strip()
    if not text:
        return "2:3"
    mapped = SIZE_TO_ASPECT.get(text)
    if mapped is not None:
        return mapped
    if ":" in text:
        # Accept a literal ratio such as "16:9" if it is in the allowed set.
        width = height = 0
        try:
            w_raw, h_raw = text.split(":", 1)
            width = int(w_raw.strip())
            height = int(h_raw.strip())
        except (TypeError, ValueError):
            pass
        if width > 0 and height > 0:
            candidate = f"{width}:{height}"
            if candidate in ALLOWED_ASPECT_RATIOS:
                return candidate
    return "2:3"
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def validate_edit_request(request: ImageEditRequest, images: List[UploadFile]):
    """Validate the parameters and uploaded files of an image-edit request."""
    if request.model != "grok-imagine-1.0-edit":
        raise ValidationException(
            message="The model `grok-imagine-1.0-edit` is required for image edits.",
            param="model",
            code="model_not_supported",
        )
    info = ModelService.get(request.model)
    if info is None or not info.is_image_edit:
        # Collect the edit-capable models for the error message.
        edit_models = [m.model_id for m in ModelService.MODELS if m.is_image_edit]
        raise ValidationException(
            message=(
                f"The model `{request.model}` is not supported for image edits. "
                f"Supported: {edit_models}"
            ),
            param="model",
            code="model_not_supported",
        )
    _validate_common_request(request, allow_ws_stream=False)
    if not images:
        raise ValidationException(
            message="Image is required",
            param="image",
            code="missing_image",
        )
    if len(images) > 16:
        raise ValidationException(
            message="Too many images. Maximum is 16.",
            param="image",
            code="invalid_image_count",
        )
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
async def _get_token(model: str):
    """Pick an available token for the given model.

    Walks the model's candidate pools in order and returns the first token
    found.

    Returns:
        Tuple of (token manager, token).

    Raises:
        AppException: HTTP 429 when every candidate pool is exhausted.
    """
    manager = await get_token_manager()
    await manager.reload_if_stale()

    selected = None
    for pool in ModelService.pool_candidates_for_model(model):
        candidate = manager.get_token(pool)
        if candidate:
            selected = candidate
            break

    if not selected:
        raise AppException(
            message="No available tokens. Please try again later.",
            error_type=ErrorType.RATE_LIMIT.value,
            code="rate_limit_exceeded",
            status_code=429,
        )

    return manager, selected
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
@router.post("/images/generations")
async def create_image(request: ImageGenerationRequest):
    """
    Image Generation API (OpenAI compatible).

    Streaming response format:
    - event: image_generation.partial_image
    - event: image_generation.completed

    Non-streaming response format:
    - {"created": ..., "data": [{"b64_json": "..."}], "usage": {...}}
    """
    # stream defaults to False when the client sends an explicit null
    if request.stream is None:
        request.stream = False

    # Fill in the configured default response format before validation.
    if request.response_format is None:
        request.response_format = resolve_response_format(None)

    # Parameter validation (model, prompt, n, stream/format/size constraints).
    validate_generation_request(request)

    # Accept "base64" as an alias for "b64_json".
    if request.response_format == "base64":
        request.response_format = "b64_json"

    response_format = resolve_response_format(request.response_format)
    response_field = response_field_name(response_format)

    # Acquire a token and look up the model entry.
    token_mgr, token = await _get_token(request.model)
    model_info = ModelService.get(request.model)
    aspect_ratio = resolve_aspect_ratio(request.size)

    result = await ImageGenerationService().generate(
        token_mgr=token_mgr,
        token=token,
        model_info=model_info,
        prompt=request.prompt,
        n=request.n,
        response_format=response_format,
        size=request.size,
        aspect_ratio=aspect_ratio,
        stream=bool(request.stream),
    )

    # Streaming result: forward the service's SSE generator as-is.
    if result.stream:
        return StreamingResponse(
            result.data,
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
        )

    # Non-streaming: wrap each image under the format-specific field name.
    data = [{response_field: img} for img in result.data]
    usage = result.usage_override or {
        "total_tokens": 0,
        "input_tokens": 0,
        "output_tokens": 0,
        "input_tokens_details": {"text_tokens": 0, "image_tokens": 0},
    }

    return JSONResponse(
        content={
            "created": int(time.time()),
            "data": data,
            "usage": usage,
        }
    )
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
@router.post("/images/edits")
async def edit_image(
    prompt: str = Form(...),
    image: List[UploadFile] = File(...),
    model: Optional[str] = Form("grok-imagine-1.0-edit"),
    n: int = Form(1),
    size: str = Form("1024x1024"),
    quality: str = Form("standard"),
    response_format: Optional[str] = Form(None),
    style: Optional[str] = Form(None),
    stream: Optional[bool] = Form(False),
):
    """
    Image Edits API.

    Same shape as the official API; only multipart/form-data uploads are
    supported.
    """
    # Fill in the configured default response format before model validation.
    if response_format is None:
        response_format = resolve_response_format(None)

    try:
        edit_request = ImageEditRequest(
            prompt=prompt,
            model=model,
            n=n,
            size=size,
            quality=quality,
            response_format=response_format,
            style=style,
            stream=stream,
        )
    except ValidationError as exc:
        # Convert the first pydantic error into the project's ValidationException
        # so the client sees the standard error envelope.
        errors = exc.errors()
        if errors:
            first = errors[0]
            loc = first.get("loc", [])
            msg = first.get("msg", "Invalid request")
            code = first.get("type", "invalid_value")
            # Drop numeric indices from the location so param reads like
            # "field.subfield" rather than "field.0.subfield".
            param_parts = [
                str(x) for x in loc if not (isinstance(x, int) or str(x).isdigit())
            ]
            param = ".".join(param_parts) if param_parts else None
            raise ValidationException(message=msg, param=param, code=code)
        raise ValidationException(message="Invalid request", code="invalid_value")

    if edit_request.stream is None:
        edit_request.stream = False

    # Normalize the format and accept "base64" as an alias for "b64_json".
    response_format = resolve_response_format(edit_request.response_format)
    if response_format == "base64":
        response_format = "b64_json"
    edit_request.response_format = response_format
    response_field = response_field_name(response_format)

    # Parameter validation (model, prompt, n, file count).
    validate_edit_request(edit_request, image)

    max_image_bytes = 50 * 1024 * 1024  # 50MB cap per uploaded file
    allowed_types = {"image/png", "image/jpeg", "image/webp", "image/jpg"}

    # Read every upload, validate it, and convert to a data: URL for the service.
    images: List[str] = []
    for item in image:
        content = await item.read()
        await item.close()
        if not content:
            raise ValidationException(
                message="File content is empty",
                param="image",
                code="empty_file",
            )
        if len(content) > max_image_bytes:
            raise ValidationException(
                message="Image file too large. Maximum is 50MB.",
                param="image",
                code="file_too_large",
            )
        mime = (item.content_type or "").lower()
        if mime == "image/jpg":
            mime = "image/jpeg"
        ext = Path(item.filename or "").suffix.lower()
        if mime not in allowed_types:
            # Content type unknown/unsupported: fall back to the filename
            # extension before rejecting the upload.
            if ext in (".jpg", ".jpeg"):
                mime = "image/jpeg"
            elif ext == ".png":
                mime = "image/png"
            elif ext == ".webp":
                mime = "image/webp"
            else:
                raise ValidationException(
                    message="Unsupported image type. Supported: png, jpg, webp.",
                    param="image",
                    code="invalid_image_type",
                )
        b64 = base64.b64encode(content).decode()
        images.append(f"data:{mime};base64,{b64}")

    # Acquire a token and look up the model entry.
    token_mgr, token = await _get_token(edit_request.model)
    model_info = ModelService.get(edit_request.model)

    result = await ImageEditService().edit(
        token_mgr=token_mgr,
        token=token,
        model_info=model_info,
        prompt=edit_request.prompt,
        images=images,
        n=edit_request.n,
        response_format=response_format,
        stream=bool(edit_request.stream),
    )

    # Streaming result: forward the service's SSE generator as-is.
    if result.stream:
        return StreamingResponse(
            result.data,
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
        )

    data = [{response_field: img} for img in result.data]

    return JSONResponse(
        content={
            "created": int(time.time()),
            "data": data,
            "usage": {
                "total_tokens": 0,
                "input_tokens": 0,
                "output_tokens": 0,
                "input_tokens_details": {"text_tokens": 0, "image_tokens": 0},
            },
        }
    )
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
__all__ = ["router"]
|
app/api/v1/models.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Models API 路由
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from fastapi import APIRouter
|
| 6 |
+
|
| 7 |
+
from app.services.grok.services.model import ModelService
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
router = APIRouter(tags=["Models"])
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@router.get("/models")
async def list_models():
    """OpenAI-compatible model listing endpoint."""
    entries = []
    for model in ModelService.list():
        entries.append(
            {
                "id": model.model_id,
                "object": "model",
                "created": 0,
                "owned_by": "grok2api@chenyme",
            }
        )
    return {"object": "list", "data": entries}
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
__all__ = ["router"]
|
app/api/v1/public_api/__init__.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Public API router (public_key protected)."""
|
| 2 |
+
|
| 3 |
+
from fastapi import APIRouter, Depends
|
| 4 |
+
|
| 5 |
+
from app.api.v1.chat import router as chat_router
|
| 6 |
+
from app.api.v1.public_api.imagine import router as imagine_router
|
| 7 |
+
from app.api.v1.public_api.video import router as video_router
|
| 8 |
+
from app.api.v1.public_api.voice import router as voice_router
|
| 9 |
+
from app.core.auth import verify_public_key
|
| 10 |
+
|
| 11 |
+
router = APIRouter()

# Chat enforces the public key at the router level. The imagine/video/voice
# routers are mounted without the dependency; imagine performs its own key or
# session check inside each endpoint. NOTE(review): presumably video/voice do
# the same in-endpoint auth — verify before exposing publicly.
router.include_router(chat_router, dependencies=[Depends(verify_public_key)])
router.include_router(imagine_router)
router.include_router(video_router)
router.include_router(voice_router)
|
| 17 |
+
|
| 18 |
+
__all__ = ["router"]
|
app/api/v1/public_api/imagine.py
ADDED
|
@@ -0,0 +1,505 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import time
|
| 3 |
+
import uuid
|
| 4 |
+
from typing import Optional, List, Dict, Any
|
| 5 |
+
|
| 6 |
+
import orjson
|
| 7 |
+
from fastapi import APIRouter, Depends, HTTPException, Query, Request, WebSocket, WebSocketDisconnect
|
| 8 |
+
from fastapi.responses import StreamingResponse
|
| 9 |
+
from pydantic import BaseModel
|
| 10 |
+
|
| 11 |
+
from app.core.auth import verify_public_key, get_public_api_key, is_public_enabled
|
| 12 |
+
from app.core.config import get_config
|
| 13 |
+
from app.core.logger import logger
|
| 14 |
+
from app.api.v1.image import resolve_aspect_ratio
|
| 15 |
+
from app.services.grok.services.image import ImageGenerationService
|
| 16 |
+
from app.services.grok.services.model import ModelService
|
| 17 |
+
from app.services.token.manager import get_token_manager
|
| 18 |
+
|
| 19 |
+
router = APIRouter()

# Lifetime (seconds) of a stored imagine session before it is purged.
IMAGINE_SESSION_TTL = 600
# task_id -> {"prompt", "aspect_ratio", "nsfw", "created_at"}; in-process only.
_IMAGINE_SESSIONS: dict[str, dict] = {}
# Guards all reads/writes of _IMAGINE_SESSIONS.
_IMAGINE_SESSIONS_LOCK = asyncio.Lock()
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
async def _clean_sessions(now: float) -> None:
    """Purge sessions older than IMAGINE_SESSION_TTL.

    Caller must already hold _IMAGINE_SESSIONS_LOCK.
    """
    stale = []
    for key, info in _IMAGINE_SESSIONS.items():
        created = float(info.get("created_at") or 0)
        if now - created > IMAGINE_SESSION_TTL:
            stale.append(key)
    for key in stale:
        _IMAGINE_SESSIONS.pop(key, None)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _parse_sse_chunk(chunk: str) -> Optional[Dict[str, Any]]:
|
| 37 |
+
if not chunk:
|
| 38 |
+
return None
|
| 39 |
+
event = None
|
| 40 |
+
data_lines: List[str] = []
|
| 41 |
+
for raw in str(chunk).splitlines():
|
| 42 |
+
line = raw.strip()
|
| 43 |
+
if not line:
|
| 44 |
+
continue
|
| 45 |
+
if line.startswith("event:"):
|
| 46 |
+
event = line[6:].strip()
|
| 47 |
+
continue
|
| 48 |
+
if line.startswith("data:"):
|
| 49 |
+
data_lines.append(line[5:].strip())
|
| 50 |
+
if not data_lines:
|
| 51 |
+
return None
|
| 52 |
+
data_str = "\n".join(data_lines)
|
| 53 |
+
if data_str == "[DONE]":
|
| 54 |
+
return None
|
| 55 |
+
try:
|
| 56 |
+
payload = orjson.loads(data_str)
|
| 57 |
+
except orjson.JSONDecodeError:
|
| 58 |
+
return None
|
| 59 |
+
if event and isinstance(payload, dict) and "type" not in payload:
|
| 60 |
+
payload["type"] = event
|
| 61 |
+
return payload
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
async def _new_session(prompt: str, aspect_ratio: str, nsfw: Optional[bool]) -> str:
    """Store an imagine session and return its freshly generated task id."""
    session_key = uuid.uuid4().hex
    created = time.time()
    async with _IMAGINE_SESSIONS_LOCK:
        # Opportunistically purge expired entries while we hold the lock.
        await _clean_sessions(created)
        _IMAGINE_SESSIONS[session_key] = {
            "prompt": prompt,
            "aspect_ratio": aspect_ratio,
            "nsfw": nsfw,
            "created_at": created,
        }
    return session_key
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
async def _get_session(task_id: str) -> Optional[dict]:
|
| 79 |
+
if not task_id:
|
| 80 |
+
return None
|
| 81 |
+
now = time.time()
|
| 82 |
+
async with _IMAGINE_SESSIONS_LOCK:
|
| 83 |
+
await _clean_sessions(now)
|
| 84 |
+
info = _IMAGINE_SESSIONS.get(task_id)
|
| 85 |
+
if not info:
|
| 86 |
+
return None
|
| 87 |
+
created_at = float(info.get("created_at") or 0)
|
| 88 |
+
if now - created_at > IMAGINE_SESSION_TTL:
|
| 89 |
+
_IMAGINE_SESSIONS.pop(task_id, None)
|
| 90 |
+
return None
|
| 91 |
+
return dict(info)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
async def _drop_session(task_id: str) -> None:
|
| 95 |
+
if not task_id:
|
| 96 |
+
return
|
| 97 |
+
async with _IMAGINE_SESSIONS_LOCK:
|
| 98 |
+
_IMAGINE_SESSIONS.pop(task_id, None)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
async def _drop_sessions(task_ids: List[str]) -> int:
|
| 102 |
+
if not task_ids:
|
| 103 |
+
return 0
|
| 104 |
+
removed = 0
|
| 105 |
+
async with _IMAGINE_SESSIONS_LOCK:
|
| 106 |
+
for task_id in task_ids:
|
| 107 |
+
if task_id and task_id in _IMAGINE_SESSIONS:
|
| 108 |
+
_IMAGINE_SESSIONS.pop(task_id, None)
|
| 109 |
+
removed += 1
|
| 110 |
+
return removed
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
@router.websocket("/imagine/ws")
async def public_imagine_ws(websocket: WebSocket):
    """WebSocket endpoint for the imagine image "waterfall".

    Auth: either a valid stored session (?task_id=...) or, failing that, a
    matching ?public_key= query param (or unrestricted access when no public
    key is configured and public access is enabled).

    Protocol (JSON text frames):
    - client -> {"type": "start", "prompt": ..., "aspect_ratio": ..., "nsfw": ...}
      begins a generation loop that keeps producing image batches until stopped.
    - client -> {"type": "stop"} cancels the running loop.
    - server -> {"type": "status"|"image"|"error", ...} events, each tagged
      with a run_id.
    """
    session_id = None
    task_id = websocket.query_params.get("task_id")
    if task_id:
        info = await _get_session(task_id)
        if info:
            session_id = task_id

    # Fall back to key-based auth only when no valid session was supplied.
    ok = True
    if session_id is None:
        public_key = get_public_api_key()
        public_enabled = is_public_enabled()
        if not public_key:
            # No key configured: access hinges on the public-enabled flag.
            ok = public_enabled
        else:
            key = websocket.query_params.get("public_key")
            ok = key == public_key

    if not ok:
        # 1008 = policy violation (authentication failure).
        await websocket.close(code=1008)
        return

    await websocket.accept()
    stop_event = asyncio.Event()
    run_task: Optional[asyncio.Task] = None

    async def _send(payload: dict) -> bool:
        # Best-effort send; returns False when the socket is gone.
        try:
            await websocket.send_text(orjson.dumps(payload).decode())
            return True
        except Exception:
            return False

    async def _stop_run():
        # Cancel the active generation loop (if any) and reset for the next run.
        nonlocal run_task
        stop_event.set()
        if run_task and not run_task.done():
            run_task.cancel()
            try:
                await run_task
            except Exception:
                pass
        run_task = None
        stop_event.clear()

    async def _run(prompt: str, aspect_ratio: str, nsfw: Optional[bool]):
        # Generation loop: repeatedly produce batches of 6 images until stopped.
        model_id = "grok-imagine-1.0"
        model_info = ModelService.get(model_id)
        if not model_info or not model_info.is_image:
            await _send(
                {
                    "type": "error",
                    "message": "Image model is not available.",
                    "code": "model_not_supported",
                }
            )
            return

        token_mgr = await get_token_manager()
        run_id = uuid.uuid4().hex

        await _send(
            {
                "type": "status",
                "status": "running",
                "prompt": prompt,
                "aspect_ratio": aspect_ratio,
                "run_id": run_id,
            }
        )

        while not stop_event.is_set():
            try:
                await token_mgr.reload_if_stale()
                # Walk the model's candidate pools for the first usable token.
                token = None
                for pool_name in ModelService.pool_candidates_for_model(
                    model_info.model_id
                ):
                    token = token_mgr.get_token(pool_name)
                    if token:
                        break

                if not token:
                    # No token right now: report and retry after a short pause.
                    await _send(
                        {
                            "type": "error",
                            "message": "No available tokens. Please try again later.",
                            "code": "rate_limit_exceeded",
                        }
                    )
                    await asyncio.sleep(2)
                    continue

                result = await ImageGenerationService().generate(
                    token_mgr=token_mgr,
                    token=token,
                    model_info=model_info,
                    prompt=prompt,
                    n=6,
                    response_format="b64_json",
                    size="1024x1024",
                    aspect_ratio=aspect_ratio,
                    stream=True,
                    enable_nsfw=nsfw,
                )
                if result.stream:
                    # Relay each parsed SSE payload, tagging it with our run id.
                    async for chunk in result.data:
                        payload = _parse_sse_chunk(chunk)
                        if not payload:
                            continue
                        if isinstance(payload, dict):
                            payload.setdefault("run_id", run_id)
                        await _send(payload)
                else:
                    # Non-streaming fallback: emit each image individually.
                    images = [img for img in result.data if img and img != "error"]
                    if images:
                        for img_b64 in images:
                            await _send(
                                {
                                    "type": "image",
                                    "b64_json": img_b64,
                                    "created_at": int(time.time() * 1000),
                                    "aspect_ratio": aspect_ratio,
                                    "run_id": run_id,
                                }
                            )
                    else:
                        await _send(
                            {
                                "type": "error",
                                "message": "Image generation returned empty data.",
                                "code": "empty_image",
                            }
                        )

            except asyncio.CancelledError:
                # _stop_run cancelled us: leave the loop immediately.
                break
            except Exception as e:
                logger.warning(f"Imagine stream error: {e}")
                await _send(
                    {
                        "type": "error",
                        "message": str(e),
                        "code": "internal_error",
                    }
                )
            # Brief pause between batches before generating again.
            await asyncio.sleep(1.5)

        await _send({"type": "status", "status": "stopped", "run_id": run_id})

    try:
        # Message loop: dispatch start/stop commands from the client.
        while True:
            try:
                raw = await websocket.receive_text()
            except (RuntimeError, WebSocketDisconnect):
                break

            try:
                payload = orjson.loads(raw)
            except Exception:
                await _send(
                    {
                        "type": "error",
                        "message": "Invalid message format.",
                        "code": "invalid_payload",
                    }
                )
                continue

            action = payload.get("type")
            if action == "start":
                prompt = str(payload.get("prompt") or "").strip()
                if not prompt:
                    await _send(
                        {
                            "type": "error",
                            "message": "Prompt cannot be empty.",
                            "code": "invalid_prompt",
                        }
                    )
                    continue
                aspect_ratio = resolve_aspect_ratio(
                    str(payload.get("aspect_ratio") or "2:3").strip() or "2:3"
                )
                nsfw = payload.get("nsfw")
                if nsfw is not None:
                    nsfw = bool(nsfw)
                # Replace any in-flight run with the new one.
                await _stop_run()
                run_task = asyncio.create_task(_run(prompt, aspect_ratio, nsfw))
            elif action == "stop":
                await _stop_run()
            else:
                await _send(
                    {
                        "type": "error",
                        "message": "Unknown action.",
                        "code": "invalid_action",
                    }
                )

    except WebSocketDisconnect:
        logger.debug("WebSocket disconnected by client")
    except Exception as e:
        logger.warning(f"WebSocket error: {e}")
    finally:
        # Always cancel the generation loop before tearing the socket down.
        await _stop_run()

        try:
            from starlette.websockets import WebSocketState
            if websocket.client_state == WebSocketState.CONNECTED:
                await websocket.close(code=1000, reason="Server closing connection")
        except Exception as e:
            logger.debug(f"WebSocket close ignored: {e}")
        # Sessions are single-use: drop the one that authenticated us.
        if session_id:
            await _drop_session(session_id)
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
@router.get("/imagine/sse")
async def public_imagine_sse(
    request: Request,
    task_id: str = Query(""),
    prompt: str = Query(""),
    aspect_ratio: str = Query("2:3"),
):
    """Imagine image waterfall (SSE fallback for the WebSocket endpoint).

    Two entry modes:
    - With ``task_id``: parameters come from a session created via
      ``/imagine/start`` (which already enforced auth); 404 if the
      session is missing or expired.
    - Without ``task_id``: ad-hoc mode; since EventSource cannot send an
      Authorization header, the public key is read from the ``public_key``
      query parameter instead.
    """
    session = None
    if task_id:
        session = await _get_session(task_id)
        if not session:
            raise HTTPException(status_code=404, detail="Task not found")
    else:
        public_key = get_public_api_key()
        public_enabled = is_public_enabled()
        if not public_key:
            # No key configured: only allowed when public access is enabled.
            if not public_enabled:
                raise HTTPException(status_code=401, detail="Public access is disabled")
        else:
            key = request.query_params.get("public_key")
            if key != public_key:
                raise HTTPException(status_code=401, detail="Invalid authentication token")

    if session:
        # Session parameters were validated at /imagine/start time.
        prompt = str(session.get("prompt") or "").strip()
        ratio = str(session.get("aspect_ratio") or "2:3").strip() or "2:3"
        nsfw = session.get("nsfw")
    else:
        prompt = (prompt or "").strip()
        if not prompt:
            raise HTTPException(status_code=400, detail="Prompt cannot be empty")
        ratio = str(aspect_ratio or "2:3").strip() or "2:3"
        ratio = resolve_aspect_ratio(ratio)
        # nsfw stays None when absent; otherwise coerce common truthy strings.
        nsfw = request.query_params.get("nsfw")
        if nsfw is not None:
            nsfw = str(nsfw).lower() in ("1", "true", "yes", "on")

    async def event_stream():
        # Generation loop: keeps producing image batches until the client
        # disconnects or the backing session disappears.
        try:
            model_id = "grok-imagine-1.0"
            model_info = ModelService.get(model_id)
            if not model_info or not model_info.is_image:
                yield (
                    f"data: {orjson.dumps({'type': 'error', 'message': 'Image model is not available.', 'code': 'model_not_supported'}).decode()}\n\n"
                )
                return

            token_mgr = await get_token_manager()
            sequence = 0  # running counter, used only for non-stream (batch) results
            run_id = uuid.uuid4().hex

            yield (
                f"data: {orjson.dumps({'type': 'status', 'status': 'running', 'prompt': prompt, 'aspect_ratio': ratio, 'run_id': run_id}).decode()}\n\n"
            )

            while True:
                if await request.is_disconnected():
                    break
                if task_id:
                    # Stop when the session was dropped (/imagine/stop) or expired.
                    session_alive = await _get_session(task_id)
                    if not session_alive:
                        break

                try:
                    await token_mgr.reload_if_stale()
                    # Pick the first candidate pool that still has a token.
                    token = None
                    for pool_name in ModelService.pool_candidates_for_model(
                        model_info.model_id
                    ):
                        token = token_mgr.get_token(pool_name)
                        if token:
                            break

                    if not token:
                        # No capacity: report and retry after a short pause.
                        yield (
                            f"data: {orjson.dumps({'type': 'error', 'message': 'No available tokens. Please try again later.', 'code': 'rate_limit_exceeded'}).decode()}\n\n"
                        )
                        await asyncio.sleep(2)
                        continue

                    result = await ImageGenerationService().generate(
                        token_mgr=token_mgr,
                        token=token,
                        model_info=model_info,
                        prompt=prompt,
                        n=6,
                        response_format="b64_json",
                        size="1024x1024",
                        aspect_ratio=ratio,
                        stream=True,
                        enable_nsfw=nsfw,
                    )
                    if result.stream:
                        # Upstream streams SSE chunks; re-emit each parsed
                        # payload tagged with this run's id.
                        async for chunk in result.data:
                            payload = _parse_sse_chunk(chunk)
                            if not payload:
                                continue
                            if isinstance(payload, dict):
                                payload.setdefault("run_id", run_id)
                            yield f"data: {orjson.dumps(payload).decode()}\n\n"
                    else:
                        # Batch result: emit one event per valid image.
                        images = [img for img in result.data if img and img != "error"]
                        if images:
                            for img_b64 in images:
                                sequence += 1
                                payload = {
                                    "type": "image",
                                    "b64_json": img_b64,
                                    "sequence": sequence,
                                    "created_at": int(time.time() * 1000),
                                    "aspect_ratio": ratio,
                                    "run_id": run_id,
                                }
                                yield f"data: {orjson.dumps(payload).decode()}\n\n"
                        else:
                            yield (
                                f"data: {orjson.dumps({'type': 'error', 'message': 'Image generation returned empty data.', 'code': 'empty_image'}).decode()}\n\n"
                            )
                except asyncio.CancelledError:
                    break
                except Exception as e:
                    # Report the failure to the client but keep the loop alive.
                    logger.warning(f"Imagine SSE error: {e}")
                    yield (
                        f"data: {orjson.dumps({'type': 'error', 'message': str(e), 'code': 'internal_error'}).decode()}\n\n"
                    )
                await asyncio.sleep(1.5)

            yield (
                f"data: {orjson.dumps({'type': 'status', 'status': 'stopped', 'run_id': run_id}).decode()}\n\n"
            )
        finally:
            # One-shot session: always cleaned up when the stream ends.
            if task_id:
                await _drop_session(task_id)

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
    )
|
| 471 |
+
|
| 472 |
+
|
| 473 |
+
@router.get("/imagine/config")
async def public_imagine_config():
    """Expose client-side tuning values for the imagine waterfall UI."""
    final_min = get_config("image.final_min_bytes") or 0
    medium_min = get_config("image.medium_min_bytes") or 0
    return {
        "final_min_bytes": int(final_min),
        "medium_min_bytes": int(medium_min),
        "nsfw": bool(get_config("image.nsfw")),
    }
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
class ImagineStartRequest(BaseModel):
    # Request body for POST /imagine/start.
    prompt: str
    # Requested ratio; resolved via resolve_aspect_ratio before being stored.
    aspect_ratio: Optional[str] = "2:3"
    # None means "use the server default" rather than an explicit False.
    nsfw: Optional[bool] = None
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
@router.post("/imagine/start", dependencies=[Depends(verify_public_key)])
async def public_imagine_start(data: ImagineStartRequest):
    """Register an imagine session and return its task id for the SSE/WS stream."""
    prompt = (data.prompt or "").strip()
    if not prompt:
        raise HTTPException(status_code=400, detail="Prompt cannot be empty")
    requested = str(data.aspect_ratio or "2:3").strip() or "2:3"
    ratio = resolve_aspect_ratio(requested)
    task_id = await _new_session(prompt, ratio, data.nsfw)
    return {"task_id": task_id, "aspect_ratio": ratio}
|
| 496 |
+
|
| 497 |
+
|
| 498 |
+
class ImagineStopRequest(BaseModel):
    # Ids of imagine sessions to cancel (as returned by /imagine/start).
    task_ids: List[str]
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
@router.post("/imagine/stop", dependencies=[Depends(verify_public_key)])
async def public_imagine_stop(data: ImagineStopRequest):
    """Drop the listed imagine sessions; reports how many were removed."""
    count = await _drop_sessions(data.task_ids or [])
    return {"status": "success", "removed": count}
|
app/api/v1/public_api/video.py
ADDED
|
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import time
|
| 3 |
+
import uuid
|
| 4 |
+
from typing import Optional, List, Dict, Any
|
| 5 |
+
|
| 6 |
+
import orjson
|
| 7 |
+
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
| 8 |
+
from fastapi.responses import StreamingResponse
|
| 9 |
+
from pydantic import BaseModel
|
| 10 |
+
|
| 11 |
+
from app.core.auth import verify_public_key
|
| 12 |
+
from app.core.logger import logger
|
| 13 |
+
from app.services.grok.services.video import VideoService
|
| 14 |
+
from app.services.grok.services.model import ModelService
|
| 15 |
+
|
| 16 |
+
router = APIRouter()
|
| 17 |
+
|
| 18 |
+
VIDEO_SESSION_TTL = 600
|
| 19 |
+
_VIDEO_SESSIONS: dict[str, dict] = {}
|
| 20 |
+
_VIDEO_SESSIONS_LOCK = asyncio.Lock()
|
| 21 |
+
|
| 22 |
+
_VIDEO_RATIO_MAP = {
|
| 23 |
+
"1280x720": "16:9",
|
| 24 |
+
"720x1280": "9:16",
|
| 25 |
+
"1792x1024": "3:2",
|
| 26 |
+
"1024x1792": "2:3",
|
| 27 |
+
"1024x1024": "1:1",
|
| 28 |
+
"16:9": "16:9",
|
| 29 |
+
"9:16": "9:16",
|
| 30 |
+
"3:2": "3:2",
|
| 31 |
+
"2:3": "2:3",
|
| 32 |
+
"1:1": "1:1",
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
async def _clean_sessions(now: float) -> None:
    """Evict sessions older than VIDEO_SESSION_TTL.

    Callers invoke this while holding _VIDEO_SESSIONS_LOCK.
    """
    stale = []
    for sid, info in _VIDEO_SESSIONS.items():
        age = now - float(info.get("created_at") or 0)
        if age > VIDEO_SESSION_TTL:
            stale.append(sid)
    for sid in stale:
        _VIDEO_SESSIONS.pop(sid, None)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
async def _new_session(
    prompt: str,
    aspect_ratio: str,
    video_length: int,
    resolution_name: str,
    preset: str,
    image_url: Optional[str],
    reasoning_effort: Optional[str],
) -> str:
    """Register a pending video task and return its opaque task id."""
    created = time.time()
    session = {
        "prompt": prompt,
        "aspect_ratio": aspect_ratio,
        "video_length": video_length,
        "resolution_name": resolution_name,
        "preset": preset,
        "image_url": image_url,
        "reasoning_effort": reasoning_effort,
        "created_at": created,
    }
    task_id = uuid.uuid4().hex
    async with _VIDEO_SESSIONS_LOCK:
        # Opportunistically evict expired entries before inserting.
        await _clean_sessions(created)
        _VIDEO_SESSIONS[task_id] = session
    return task_id
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
async def _get_session(task_id: str) -> Optional[dict]:
    """Return a copy of the session for task_id, or None if absent/expired."""
    if not task_id:
        return None
    now = time.time()
    async with _VIDEO_SESSIONS_LOCK:
        await _clean_sessions(now)
        info = _VIDEO_SESSIONS.get(task_id)
        if not info:
            return None
        # Belt-and-braces TTL check for an entry aging right at the boundary.
        if now - float(info.get("created_at") or 0) > VIDEO_SESSION_TTL:
            _VIDEO_SESSIONS.pop(task_id, None)
            return None
        # Return a shallow copy so callers cannot mutate shared state.
        return dict(info)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
async def _drop_session(task_id: str) -> None:
    """Remove one session; unknown or empty ids are silently ignored."""
    if task_id:
        async with _VIDEO_SESSIONS_LOCK:
            _VIDEO_SESSIONS.pop(task_id, None)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
async def _drop_sessions(task_ids: List[str]) -> int:
    """Remove the given sessions and return how many actually existed."""
    if not task_ids:
        return 0
    removed = 0
    async with _VIDEO_SESSIONS_LOCK:
        for tid in task_ids:
            if tid and _VIDEO_SESSIONS.pop(tid, None) is not None:
                removed += 1
    return removed
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def _normalize_ratio(value: Optional[str]) -> str:
    """Map a resolution string or ratio to a canonical ratio; '' if unsupported."""
    key = (value or "").strip()
    normalized = _VIDEO_RATIO_MAP.get(key)
    return normalized if normalized is not None else ""
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def _validate_image_url(image_url: str) -> None:
    """Reject image_url values that are neither http(s) URLs nor data URIs.

    Empty values pass through silently (treated as "no image").

    Raises:
        HTTPException: 400 when the value has an unsupported scheme.
    """
    candidate = (image_url or "").strip()
    if not candidate:
        return
    if candidate.startswith(("data:", "http://", "https://")):
        return
    raise HTTPException(
        status_code=400,
        detail="image_url must be a URL or data URI (data:<mime>;base64,...)",
    )
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
class VideoStartRequest(BaseModel):
    # Request body for POST /video/start.
    prompt: str
    # Accepts either a ratio ("16:9") or a resolution string ("1280x720");
    # normalized via _VIDEO_RATIO_MAP.
    aspect_ratio: Optional[str] = "3:2"
    # Clip length in seconds; must be 6, 10, or 15.
    video_length: Optional[int] = 6
    # "480p" or "720p".
    resolution_name: Optional[str] = "480p"
    # One of "fun", "normal", "spicy", "custom".
    preset: Optional[str] = "normal"
    # Optional source image (http(s) URL or data URI) for image-to-video.
    image_url: Optional[str] = None
    # Optional effort hint; validated against a fixed allow-list.
    reasoning_effort: Optional[str] = None
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
@router.post("/video/start", dependencies=[Depends(verify_public_key)])
async def public_video_start(data: VideoStartRequest):
    """Validate the request and register a one-shot session for /video/sse."""

    def _reject(detail: str) -> None:
        # Single place for 400 responses keeps the guards compact.
        raise HTTPException(status_code=400, detail=detail)

    prompt = (data.prompt or "").strip()
    if not prompt:
        _reject("Prompt cannot be empty")

    aspect_ratio = _normalize_ratio(data.aspect_ratio)
    if not aspect_ratio:
        _reject("aspect_ratio must be one of ['16:9','9:16','3:2','2:3','1:1']")

    video_length = int(data.video_length or 6)
    if video_length not in (6, 10, 15):
        _reject("video_length must be 6, 10, or 15 seconds")

    resolution_name = str(data.resolution_name or "480p")
    if resolution_name not in ("480p", "720p"):
        _reject("resolution_name must be one of ['480p','720p']")

    preset = str(data.preset or "normal")
    if preset not in ("fun", "normal", "spicy", "custom"):
        _reject("preset must be one of ['fun','normal','spicy','custom']")

    image_url = (data.image_url or "").strip() or None
    if image_url:
        _validate_image_url(image_url)

    reasoning_effort = (data.reasoning_effort or "").strip() or None
    if reasoning_effort:
        allowed = {"none", "minimal", "low", "medium", "high", "xhigh"}
        if reasoning_effort not in allowed:
            _reject(f"reasoning_effort must be one of {sorted(allowed)}")

    task_id = await _new_session(
        prompt,
        aspect_ratio,
        video_length,
        resolution_name,
        preset,
        image_url,
        reasoning_effort,
    )
    return {"task_id": task_id, "aspect_ratio": aspect_ratio}
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
@router.get("/video/sse")
async def public_video_sse(request: Request, task_id: str = Query("")):
    """Video generation stream (SSE).

    Parameters come from a session created via /video/start (which enforced
    auth); the session is consumed — dropped — once the stream ends.
    """
    session = await _get_session(task_id)
    if not session:
        raise HTTPException(status_code=404, detail="Task not found")

    # Re-apply the same defaults used at session creation time.
    prompt = str(session.get("prompt") or "").strip()
    aspect_ratio = str(session.get("aspect_ratio") or "3:2")
    video_length = int(session.get("video_length") or 6)
    resolution_name = str(session.get("resolution_name") or "480p")
    preset = str(session.get("preset") or "normal")
    image_url = session.get("image_url")
    reasoning_effort = session.get("reasoning_effort")

    async def event_stream():
        try:
            model_id = "grok-imagine-1.0-video"
            model_info = ModelService.get(model_id)
            if not model_info or not model_info.is_video:
                payload = {
                    "error": "Video model is not available.",
                    "code": "model_not_supported",
                }
                yield f"data: {orjson.dumps(payload).decode()}\n\n"
                yield "data: [DONE]\n\n"
                return

            if image_url:
                # Image-to-video: multimodal message (text + image parts).
                messages: List[Dict[str, Any]] = [
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": prompt},
                            {"type": "image_url", "image_url": {"url": image_url}},
                        ],
                    }
                ]
            else:
                messages = [{"role": "user", "content": prompt}]

            stream = await VideoService.completions(
                model_id,
                messages,
                stream=True,
                reasoning_effort=reasoning_effort,
                aspect_ratio=aspect_ratio,
                video_length=video_length,
                resolution=resolution_name,
                preset=preset,
            )

            # Relay upstream chunks verbatim; stop early on client disconnect.
            async for chunk in stream:
                if await request.is_disconnected():
                    break
                yield chunk
        except Exception as e:
            logger.warning(f"Public video SSE error: {e}")
            payload = {"error": str(e), "code": "internal_error"}
            yield f"data: {orjson.dumps(payload).decode()}\n\n"
            yield "data: [DONE]\n\n"
        finally:
            # One-shot session: always dropped when the stream finishes.
            await _drop_session(task_id)

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
    )
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
class VideoStopRequest(BaseModel):
    # Ids of video sessions to cancel (as returned by /video/start).
    task_ids: List[str]
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
@router.post("/video/stop", dependencies=[Depends(verify_public_key)])
async def public_video_stop(data: VideoStopRequest):
    """Drop the listed pending video sessions; reports how many were removed."""
    count = await _drop_sessions(data.task_ids or [])
    return {"status": "success", "removed": count}
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
__all__ = ["router"]
|
app/api/v1/public_api/voice.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, Depends
|
| 2 |
+
from pydantic import BaseModel
|
| 3 |
+
|
| 4 |
+
from app.core.auth import verify_public_key
|
| 5 |
+
from app.core.exceptions import AppException
|
| 6 |
+
from app.services.grok.services.voice import VoiceService
|
| 7 |
+
from app.services.token.manager import get_token_manager
|
| 8 |
+
|
| 9 |
+
router = APIRouter()
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class VoiceTokenResponse(BaseModel):
    # LiveKit access token issued by the upstream Grok voice service.
    token: str
    # LiveKit server URL the client should connect to.
    url: str
    # Always empty in the current implementation.
    participant_name: str = ""
    room_name: str = ""
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@router.get(
    "/voice/token",
    dependencies=[Depends(verify_public_key)],
    response_model=VoiceTokenResponse,
)
async def public_voice_token(
    voice: str = "ara",
    personality: str = "assistant",
    speed: float = 1.0,
):
    """Fetch a Grok Voice Mode (LiveKit) token.

    Picks the first available SSO token from the basic/super pools, asks the
    upstream voice service for a LiveKit token, and wraps it for the client.

    Raises:
        AppException: 503 when no SSO token is available, 502 when upstream
            returns no token, 500 for any other failure.
    """
    token_mgr = await get_token_manager()
    sso_token = None
    for pool_name in ("ssoBasic", "ssoSuper"):
        sso_token = token_mgr.get_token(pool_name)
        if sso_token:
            break

    if not sso_token:
        raise AppException(
            "No available tokens for voice mode",
            code="no_token",
            status_code=503,
        )

    service = VoiceService()
    try:
        data = await service.get_token(
            token=sso_token,
            voice=voice,
            personality=personality,
            speed=speed,
        )
        token = data.get("token")
        if not token:
            raise AppException(
                "Upstream returned no voice token",
                code="upstream_error",
                status_code=502,
            )

        return VoiceTokenResponse(
            token=token,
            url="wss://livekit.grok.com",
            participant_name="",
            room_name="",
        )
    except AppException:
        # Let our own structured errors propagate unchanged (clearer than the
        # previous broad `except Exception` + isinstance check).
        raise
    except Exception as e:
        # Wrap unexpected failures, chaining the cause so the original
        # traceback is preserved for debugging.
        raise AppException(
            f"Voice token error: {str(e)}",
            code="voice_error",
            status_code=500,
        ) from e
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
@router.get("/verify", dependencies=[Depends(verify_public_key)])
async def public_verify_api():
    """Verify the public key; reaching this handler means auth already passed."""
    return {"status": "success"}
|
app/api/v1/response.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Responses API 路由 (OpenAI compatible).
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from typing import Any, Dict, List, Optional, Union
|
| 6 |
+
|
| 7 |
+
from fastapi import APIRouter
|
| 8 |
+
from fastapi.responses import JSONResponse, StreamingResponse
|
| 9 |
+
from pydantic import BaseModel, Field
|
| 10 |
+
|
| 11 |
+
from app.core.exceptions import ValidationException
|
| 12 |
+
from app.services.grok.services.responses import ResponsesService
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
router = APIRouter(tags=["Responses"])
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class ResponseCreateRequest(BaseModel):
    """Request body for POST /responses (OpenAI Responses API compatible)."""

    model: str = Field(..., description="Model name")
    input: Optional[Any] = Field(None, description="Input content")
    instructions: Optional[str] = Field(None, description="System instructions")
    stream: Optional[bool] = Field(False, description="Stream response")
    max_output_tokens: Optional[int] = Field(None, description="Max output tokens")
    temperature: Optional[float] = Field(None, description="Sampling temperature")
    top_p: Optional[float] = Field(None, description="Nucleus sampling")
    tools: Optional[List[Dict[str, Any]]] = Field(None, description="Tool definitions")
    tool_choice: Optional[Union[str, Dict[str, Any]]] = Field(None, description="Tool choice")
    parallel_tool_calls: Optional[bool] = Field(True, description="Allow parallel tool calls")
    reasoning: Optional[Dict[str, Any]] = Field(None, description="Reasoning options")
    metadata: Optional[Dict[str, Any]] = Field(None, description="Metadata")
    user: Optional[str] = Field(None, description="User identifier")
    store: Optional[bool] = Field(None, description="Store response")
    previous_response_id: Optional[str] = Field(None, description="Previous response id")
    truncation: Optional[str] = Field(None, description="Truncation behavior")

    class Config:
        # Accept (and ignore/forward) unknown fields for forward compatibility
        # with newer OpenAI client payloads.
        extra = "allow"
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@router.post("/responses")
async def create_response(request: ResponseCreateRequest):
    """Create a model response (OpenAI Responses API compatible).

    Returns an SSE stream when ``stream`` is true, otherwise the full JSON body.

    Raises:
        ValidationException: when ``model`` or ``input`` is missing.
    """
    if not request.model:
        raise ValidationException(
            message="model is required", param="model", code="invalid_request_error"
        )
    if request.input is None:
        raise ValidationException(
            message="input is required", param="input", code="invalid_request_error"
        )

    # The OpenAI-style ``reasoning`` object may spell the effort field either way.
    effort = None
    reasoning = request.reasoning
    if isinstance(reasoning, dict):
        effort = reasoning.get("effort") or reasoning.get("reasoning_effort")

    result = await ResponsesService.create(
        model=request.model,
        input_value=request.input,
        instructions=request.instructions,
        stream=bool(request.stream),
        temperature=request.temperature,
        top_p=request.top_p,
        tools=request.tools,
        tool_choice=request.tool_choice,
        parallel_tool_calls=request.parallel_tool_calls,
        reasoning_effort=effort,
        max_output_tokens=request.max_output_tokens,
        metadata=request.metadata,
        user=request.user,
        store=request.store,
        previous_response_id=request.previous_response_id,
        truncation=request.truncation,
    )

    if not request.stream:
        return JSONResponse(content=result)
    return StreamingResponse(
        result,
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
    )
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
__all__ = ["router"]
|
app/api/v1/video.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TODO:Video Generation API 路由
|
| 3 |
+
"""
|
app/core/auth.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
API 认证模块
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import hashlib
|
| 6 |
+
from typing import Optional, Iterable
|
| 7 |
+
from fastapi import HTTPException, status, Security
|
| 8 |
+
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
| 9 |
+
|
| 10 |
+
from app.core.config import get_config
|
| 11 |
+
|
| 12 |
+
DEFAULT_API_KEY = ""
|
| 13 |
+
DEFAULT_APP_KEY = "grok2api"
|
| 14 |
+
DEFAULT_PUBLIC_KEY = ""
|
| 15 |
+
DEFAULT_PUBLIC_ENABLED = False
|
| 16 |
+
|
| 17 |
+
# 定义 Bearer Scheme
|
| 18 |
+
security = HTTPBearer(
|
| 19 |
+
auto_error=False,
|
| 20 |
+
scheme_name="API Key",
|
| 21 |
+
description="Enter your API Key in the format: Bearer <key>",
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def get_admin_api_key() -> str:
    """Return the admin API key; an empty string disables admin-API auth."""
    return get_config("app.api_key", DEFAULT_API_KEY) or ""
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _normalize_api_keys(value: Optional[object]) -> list[str]:
    """Normalize a key setting into a list of non-empty strings.

    Accepts a comma-separated string or an iterable of strings; anything
    else (including None / empty values) yields an empty list.
    """
    if not value:
        return []
    if isinstance(value, str):
        parts = value.strip().split(",")
        return [part.strip() for part in parts if part.strip()]
    if isinstance(value, Iterable):
        return [
            item.strip()
            for item in value
            if item and isinstance(item, str) and item.strip()
        ]
    return []
|
| 54 |
+
|
| 55 |
+
def get_app_key() -> str:
    """Return the admin-console login password (app_key)."""
    return get_config("app.app_key", DEFAULT_APP_KEY) or ""
|
| 61 |
+
|
| 62 |
+
def get_public_api_key() -> str:
    """Return the public API key; an empty string disables public-API auth."""
    return get_config("app.public_key", DEFAULT_PUBLIC_KEY) or ""
|
| 70 |
+
|
| 71 |
+
def is_public_enabled() -> bool:
    """Whether the public feature entry point is switched on."""
    enabled = get_config("app.public_enabled", DEFAULT_PUBLIC_ENABLED)
    return bool(enabled)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def _hash_public_key(key: str) -> str:
    """SHA-256 of the salted public key; must match the frontend hashPublicKey."""
    digest = hashlib.sha256(f"grok2api-public:{key}".encode())
    return digest.hexdigest()
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def _match_public_key(credentials: str, public_key: str) -> bool:
    """Check credentials against public_key (raw value or 'public-<sha256>' form).

    Uses constant-time comparison (hmac.compare_digest) instead of ``==`` so
    the check does not leak key material through timing differences.

    Args:
        credentials: the value presented by the client.
        public_key: the configured key; empty/blank never matches.

    Returns:
        True when the credentials match either accepted form.
    """
    import hmac  # local import keeps this fix self-contained

    if not public_key:
        return False
    normalized = public_key.strip()
    if not normalized:
        return False
    # Encode to bytes: compare_digest on str requires ASCII-only input.
    presented = credentials.encode()
    if hmac.compare_digest(presented, normalized.encode()):
        return True
    if credentials.startswith("public-"):
        expected = f"public-{_hash_public_key(normalized)}"
        if hmac.compare_digest(presented, expected.encode()):
            return True
    return False
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
async def verify_api_key(
    auth: Optional[HTTPAuthorizationCredentials] = Security(security),
) -> Optional[str]:
    """Validate the Bearer token against the configured admin API key(s).

    When no api_key is configured, authentication is disabled and None is
    returned. Comma-separated keys are supported via _normalize_api_keys.

    Raises:
        HTTPException: 401 when the token is missing or invalid.
    """
    import hmac  # constant-time comparison for secret tokens

    api_key = get_admin_api_key()
    api_keys = _normalize_api_keys(api_key)
    if not api_keys:
        # Auth disabled: accept every request.
        return None

    if not auth:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing authentication token",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Constant-time membership check (plain `in` would be a timing side channel).
    presented = auth.credentials.encode()
    if any(hmac.compare_digest(presented, key.encode()) for key in api_keys):
        return auth.credentials

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid authentication token",
        headers={"WWW-Authenticate": "Bearer"},
    )
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
async def verify_app_key(
    auth: Optional[HTTPAuthorizationCredentials] = Security(security),
) -> Optional[str]:
    """Validate the admin-console login key (app_key).

    app_key must be configured; login is rejected otherwise.

    Raises:
        HTTPException: 401 when app_key is unset, the token is missing,
            or the token does not match.
    """
    import hmac  # constant-time comparison for the secret

    app_key = get_app_key()

    if not app_key:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="App key is not configured",
            headers={"WWW-Authenticate": "Bearer"},
        )

    if not auth:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing authentication token",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Use compare_digest instead of `!=` to avoid a timing side channel.
    if not hmac.compare_digest(auth.credentials.encode(), app_key.encode()):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authentication token",
            headers={"WWW-Authenticate": "Bearer"},
        )

    return auth.credentials
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
async def verify_public_key(
    auth: Optional[HTTPAuthorizationCredentials] = Security(security),
) -> Optional[str]:
    """Validate the Public Key (used by the public endpoints).

    Closed by default: a public_key must be configured for access. If
    public_enabled is on while no public_key is configured, access is open
    (returns None without checking credentials).

    Raises:
        HTTPException: 401 when public access is disabled, the token is
            missing, or the token does not match.
    """
    public_key = get_public_api_key()
    public_enabled = is_public_enabled()

    if not public_key:
        # No key configured: open access only when explicitly enabled.
        if public_enabled:
            return None
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Public access is disabled",
            headers={"WWW-Authenticate": "Bearer"},
        )

    if not auth:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing authentication token",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Accepts the raw key or its "public-<sha256>" hashed form.
    if _match_public_key(auth.credentials, public_key):
        return auth.credentials

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid authentication token",
        headers={"WWW-Authenticate": "Bearer"},
    )
|
app/core/batch.py
ADDED
|
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Batch utilities.
|
| 3 |
+
|
| 4 |
+
- run_batch: generic batch concurrency runner
|
| 5 |
+
- BatchTask: SSE task manager for admin batch operations
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import asyncio
|
| 9 |
+
import time
|
| 10 |
+
import uuid
|
| 11 |
+
from typing import Any, Awaitable, Callable, Dict, List, Optional, TypeVar
|
| 12 |
+
|
| 13 |
+
from app.core.logger import logger
|
| 14 |
+
|
| 15 |
+
T = TypeVar("T")
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
async def run_batch(
    items: List[str],
    worker: Callable[[str], Awaitable[T]],
    *,
    batch_size: int = 50,
    task: Optional["BatchTask"] = None,
    on_item: Optional[Callable[[str, Dict[str, Any]], Awaitable[None]]] = None,
    should_cancel: Optional[Callable[[], bool]] = None,
) -> Dict[str, Dict[str, Any]]:
    """
    Run *worker* over *items* in concurrent chunks; a single failed item does
    not abort the whole run.

    Args:
        items: items to process
        worker: async handler for a single item
        batch_size: number of items launched concurrently per chunk
        task: optional BatchTask used for progress reporting / cancellation
        on_item: optional async callback invoked with (item, result)
        should_cancel: optional predicate polled before each chunk/item

    Returns:
        {item: {"ok": bool, "data": ..., "error": ...}}
    """
    try:
        batch_size = int(batch_size)
    except Exception:
        batch_size = 50
    batch_size = max(1, batch_size)

    def _is_cancelled() -> bool:
        # Cancellation may come from an explicit predicate or the task object.
        return bool((should_cancel and should_cancel()) or (task and task.cancelled))

    async def _notify(item: str, result: Dict[str, Any]) -> None:
        # Shared tail for success and failure paths (previously duplicated).
        # Progress/observer hooks are best-effort and must never break the batch.
        if task:
            task.record(result["ok"], error=result.get("error", ""))
        if on_item:
            try:
                await on_item(item, result)
            except Exception:
                pass

    async def _one(item: str) -> tuple[str, Dict[str, Any]]:
        # Items skipped due to cancellation are not recorded as progress.
        if _is_cancelled():
            return item, {"ok": False, "error": "cancelled", "cancelled": True}
        try:
            data = await worker(item)
            result: Dict[str, Any] = {"ok": True, "data": data}
        except Exception as e:
            logger.warning(f"Batch item failed: {item[:16]}... - {e}")
            result = {"ok": False, "error": str(e)}
        await _notify(item, result)
        return item, result

    results: Dict[str, Dict[str, Any]] = {}

    # Process in chunks to avoid creating all tasks at once.
    for i in range(0, len(items), batch_size):
        if _is_cancelled():
            break
        chunk = items[i : i + batch_size]
        pairs = await asyncio.gather(*(_one(x) for x in chunk))
        results.update(dict(pairs))

    return results
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class BatchTask:
    """Progress tracker for one admin batch job; broadcasts events to SSE subscribers."""

    def __init__(self, total: int):
        # Random hex id clients use to poll/subscribe to this task.
        self.id = uuid.uuid4().hex
        self.total = int(total)
        self.processed = 0
        self.ok = 0
        self.fail = 0
        self.status = "running"
        self.warning: Optional[str] = None
        self.result: Optional[Dict[str, Any]] = None
        self.error: Optional[str] = None
        self.created_at = time.time()
        # One bounded queue per attached SSE subscriber.
        self._queues: List[asyncio.Queue] = []
        # Terminal event (done/error/cancelled) retained for late subscribers.
        self._final_event: Optional[Dict[str, Any]] = None
        self.cancelled = False

    def snapshot(self) -> Dict[str, Any]:
        """Return the current counters/status as a plain dict."""
        return {
            "task_id": self.id,
            "status": self.status,
            "total": self.total,
            "processed": self.processed,
            "ok": self.ok,
            "fail": self.fail,
            "warning": self.warning,
        }

    def attach(self) -> asyncio.Queue:
        """Register a new subscriber queue and return it."""
        q: asyncio.Queue = asyncio.Queue(maxsize=200)
        self._queues.append(q)
        return q

    def detach(self, q: asyncio.Queue) -> None:
        """Remove a subscriber queue; unknown queues are ignored."""
        if q in self._queues:
            self._queues.remove(q)

    def _publish(self, event: Dict[str, Any]) -> None:
        """Fan *event* out to all subscriber queues without blocking."""
        for q in list(self._queues):
            try:
                q.put_nowait(event)
            except Exception:
                # Drop if queue is full or closed
                pass

    def record(
        self, ok: bool, *, item: Any = None, detail: Any = None, error: str = ""
    ) -> None:
        """Count one processed item and publish a "progress" event."""
        self.processed += 1
        if ok:
            self.ok += 1
        else:
            self.fail += 1
        event: Dict[str, Any] = {
            "type": "progress",
            "task_id": self.id,
            "total": self.total,
            "processed": self.processed,
            "ok": self.ok,
            "fail": self.fail,
        }
        # Optional fields are only attached when provided.
        if item is not None:
            event["item"] = item
        if detail is not None:
            event["detail"] = detail
        if error:
            event["error"] = error
        self._publish(event)

    def finish(self, result: Dict[str, Any], *, warning: Optional[str] = None) -> None:
        """Mark the task done and publish the terminal "done" event."""
        self.status = "done"
        self.result = result
        self.warning = warning
        event = {
            "type": "done",
            "task_id": self.id,
            "total": self.total,
            "processed": self.processed,
            "ok": self.ok,
            "fail": self.fail,
            "warning": self.warning,
            "result": result,
        }
        self._final_event = event
        self._publish(event)

    def fail_task(self, error: str) -> None:
        """Mark the task failed and publish the terminal "error" event."""
        self.status = "error"
        self.error = error
        event = {
            "type": "error",
            "task_id": self.id,
            "total": self.total,
            "processed": self.processed,
            "ok": self.ok,
            "fail": self.fail,
            "error": error,
        }
        self._final_event = event
        self._publish(event)

    def cancel(self) -> None:
        """Request cooperative cancellation; workers poll `cancelled`."""
        self.cancelled = True

    def finish_cancelled(self) -> None:
        """Mark the task cancelled and publish the terminal "cancelled" event."""
        self.status = "cancelled"
        event = {
            "type": "cancelled",
            "task_id": self.id,
            "total": self.total,
            "processed": self.processed,
            "ok": self.ok,
            "fail": self.fail,
        }
        self._final_event = event
        self._publish(event)

    def final_event(self) -> Optional[Dict[str, Any]]:
        """Return the terminal event, or None while the task is still running."""
        return self._final_event
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
# Module-level registry of live batch tasks, keyed by task id.
_TASKS: Dict[str, BatchTask] = {}


def create_task(total: int) -> BatchTask:
    """Create a BatchTask, register it in the in-memory registry, and return it."""
    new_task = BatchTask(total)
    _TASKS[new_task.id] = new_task
    return new_task
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def get_task(task_id: str) -> Optional[BatchTask]:
    """Look up a registered batch task by id; None when unknown."""
    return _TASKS.get(task_id)
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def delete_task(task_id: str) -> None:
    """Drop a task from the registry; missing ids are silently ignored."""
    _TASKS.pop(task_id, None)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
async def expire_task(task_id: str, delay: int = 300) -> None:
    """Garbage-collect a finished task from the registry after *delay* seconds."""
    await asyncio.sleep(delay)
    delete_task(task_id)
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
# Public API of this module.
__all__ = [
    "run_batch",
    "BatchTask",
    "create_task",
    "get_task",
    "delete_task",
    "expire_task",
]
|
app/core/config.py
ADDED
|
@@ -0,0 +1,326 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
配置管理
|
| 3 |
+
|
| 4 |
+
- config.toml: 运行时配置
|
| 5 |
+
- config.defaults.toml: 默认配置基线
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from copy import deepcopy
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Any, Dict
|
| 11 |
+
import tomllib
|
| 12 |
+
|
| 13 |
+
from app.core.logger import logger
|
| 14 |
+
|
| 15 |
+
DEFAULT_CONFIG_FILE = Path(__file__).parent.parent.parent / "config.defaults.toml"
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _deep_merge(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
|
| 19 |
+
"""深度合并字典: override 覆盖 base."""
|
| 20 |
+
if not isinstance(base, dict):
|
| 21 |
+
return deepcopy(override) if isinstance(override, dict) else deepcopy(base)
|
| 22 |
+
|
| 23 |
+
result = deepcopy(base)
|
| 24 |
+
if not isinstance(override, dict):
|
| 25 |
+
return result
|
| 26 |
+
|
| 27 |
+
for key, val in override.items():
|
| 28 |
+
if isinstance(val, dict) and isinstance(result.get(key), dict):
|
| 29 |
+
result[key] = _deep_merge(result[key], val)
|
| 30 |
+
else:
|
| 31 |
+
result[key] = val
|
| 32 |
+
return result
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _migrate_deprecated_config(
|
| 36 |
+
config: Dict[str, Any], valid_sections: set
|
| 37 |
+
) -> tuple[Dict[str, Any], set]:
|
| 38 |
+
"""
|
| 39 |
+
迁移废弃的配置节到新配置结构
|
| 40 |
+
|
| 41 |
+
Returns:
|
| 42 |
+
(迁移后的配置, 废弃的配置节集合)
|
| 43 |
+
"""
|
| 44 |
+
# 配置映射规则:旧配置 -> 新配置
|
| 45 |
+
MIGRATION_MAP = {
|
| 46 |
+
# grok.* -> 对应的新配置节
|
| 47 |
+
"grok.temporary": "app.temporary",
|
| 48 |
+
"grok.disable_memory": "app.disable_memory",
|
| 49 |
+
"grok.stream": "app.stream",
|
| 50 |
+
"grok.thinking": "app.thinking",
|
| 51 |
+
"grok.dynamic_statsig": "app.dynamic_statsig",
|
| 52 |
+
"grok.filter_tags": "app.filter_tags",
|
| 53 |
+
"grok.timeout": "voice.timeout",
|
| 54 |
+
"grok.base_proxy_url": "proxy.base_proxy_url",
|
| 55 |
+
"grok.asset_proxy_url": "proxy.asset_proxy_url",
|
| 56 |
+
"network.base_proxy_url": "proxy.base_proxy_url",
|
| 57 |
+
"network.asset_proxy_url": "proxy.asset_proxy_url",
|
| 58 |
+
"grok.cf_clearance": "proxy.cf_clearance",
|
| 59 |
+
"grok.browser": "proxy.browser",
|
| 60 |
+
"grok.user_agent": "proxy.user_agent",
|
| 61 |
+
"security.cf_clearance": "proxy.cf_clearance",
|
| 62 |
+
"security.browser": "proxy.browser",
|
| 63 |
+
"security.user_agent": "proxy.user_agent",
|
| 64 |
+
"grok.max_retry": "retry.max_retry",
|
| 65 |
+
"grok.retry_status_codes": "retry.retry_status_codes",
|
| 66 |
+
"grok.retry_backoff_base": "retry.retry_backoff_base",
|
| 67 |
+
"grok.retry_backoff_factor": "retry.retry_backoff_factor",
|
| 68 |
+
"grok.retry_backoff_max": "retry.retry_backoff_max",
|
| 69 |
+
"grok.retry_budget": "retry.retry_budget",
|
| 70 |
+
"grok.video_idle_timeout": "video.stream_timeout",
|
| 71 |
+
"grok.image_ws_nsfw": "image.nsfw",
|
| 72 |
+
"grok.image_ws_blocked_seconds": "image.final_timeout",
|
| 73 |
+
"grok.image_ws_final_min_bytes": "image.final_min_bytes",
|
| 74 |
+
"grok.image_ws_medium_min_bytes": "image.medium_min_bytes",
|
| 75 |
+
# legacy sections
|
| 76 |
+
"network.base_proxy_url": "proxy.base_proxy_url",
|
| 77 |
+
"network.asset_proxy_url": "proxy.asset_proxy_url",
|
| 78 |
+
"network.timeout": [
|
| 79 |
+
"chat.timeout",
|
| 80 |
+
"image.timeout",
|
| 81 |
+
"video.timeout",
|
| 82 |
+
"voice.timeout",
|
| 83 |
+
],
|
| 84 |
+
"security.cf_clearance": "proxy.cf_clearance",
|
| 85 |
+
"security.browser": "proxy.browser",
|
| 86 |
+
"security.user_agent": "proxy.user_agent",
|
| 87 |
+
"timeout.stream_idle_timeout": [
|
| 88 |
+
"chat.stream_timeout",
|
| 89 |
+
"image.stream_timeout",
|
| 90 |
+
"video.stream_timeout",
|
| 91 |
+
],
|
| 92 |
+
"timeout.video_idle_timeout": "video.stream_timeout",
|
| 93 |
+
"image.image_ws_nsfw": "image.nsfw",
|
| 94 |
+
"image.image_ws_blocked_seconds": "image.final_timeout",
|
| 95 |
+
"image.image_ws_final_min_bytes": "image.final_min_bytes",
|
| 96 |
+
"image.image_ws_medium_min_bytes": "image.medium_min_bytes",
|
| 97 |
+
"performance.assets_max_concurrent": [
|
| 98 |
+
"asset.upload_concurrent",
|
| 99 |
+
"asset.download_concurrent",
|
| 100 |
+
"asset.list_concurrent",
|
| 101 |
+
"asset.delete_concurrent",
|
| 102 |
+
],
|
| 103 |
+
"performance.assets_delete_batch_size": "asset.delete_batch_size",
|
| 104 |
+
"performance.assets_batch_size": "asset.list_batch_size",
|
| 105 |
+
"performance.media_max_concurrent": ["chat.concurrent", "video.concurrent"],
|
| 106 |
+
"performance.usage_max_concurrent": "usage.concurrent",
|
| 107 |
+
"performance.usage_batch_size": "usage.batch_size",
|
| 108 |
+
"performance.nsfw_max_concurrent": "nsfw.concurrent",
|
| 109 |
+
"performance.nsfw_batch_size": "nsfw.batch_size",
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
deprecated_sections = set(config.keys()) - valid_sections
|
| 113 |
+
if not deprecated_sections:
|
| 114 |
+
return config, set()
|
| 115 |
+
|
| 116 |
+
result = {k: deepcopy(v) for k, v in config.items() if k in valid_sections}
|
| 117 |
+
migrated_count = 0
|
| 118 |
+
|
| 119 |
+
# 处理废弃配置节或旧配置键
|
| 120 |
+
for old_section, old_values in config.items():
|
| 121 |
+
if not isinstance(old_values, dict):
|
| 122 |
+
continue
|
| 123 |
+
for old_key, old_value in old_values.items():
|
| 124 |
+
old_path = f"{old_section}.{old_key}"
|
| 125 |
+
new_paths = MIGRATION_MAP.get(old_path)
|
| 126 |
+
if not new_paths:
|
| 127 |
+
continue
|
| 128 |
+
if isinstance(new_paths, str):
|
| 129 |
+
new_paths = [new_paths]
|
| 130 |
+
for new_path in new_paths:
|
| 131 |
+
try:
|
| 132 |
+
new_section, new_key = new_path.split(".", 1)
|
| 133 |
+
if new_section not in result:
|
| 134 |
+
result[new_section] = {}
|
| 135 |
+
if new_key not in result[new_section]:
|
| 136 |
+
result[new_section][new_key] = old_value
|
| 137 |
+
migrated_count += 1
|
| 138 |
+
logger.debug(
|
| 139 |
+
f"Migrated config: {old_path} -> {new_path} = {old_value}"
|
| 140 |
+
)
|
| 141 |
+
except Exception as e:
|
| 142 |
+
logger.warning(
|
| 143 |
+
f"Skip config migration for {old_path}: {e}"
|
| 144 |
+
)
|
| 145 |
+
continue
|
| 146 |
+
if isinstance(result.get(old_section), dict):
|
| 147 |
+
result[old_section].pop(old_key, None)
|
| 148 |
+
|
| 149 |
+
# 兼容旧 chat.* 配置键迁移到 app.*
|
| 150 |
+
legacy_chat_map = {
|
| 151 |
+
"temporary": "temporary",
|
| 152 |
+
"disable_memory": "disable_memory",
|
| 153 |
+
"stream": "stream",
|
| 154 |
+
"thinking": "thinking",
|
| 155 |
+
"dynamic_statsig": "dynamic_statsig",
|
| 156 |
+
"filter_tags": "filter_tags",
|
| 157 |
+
}
|
| 158 |
+
chat_section = config.get("chat")
|
| 159 |
+
if isinstance(chat_section, dict):
|
| 160 |
+
app_section = result.setdefault("app", {})
|
| 161 |
+
for old_key, new_key in legacy_chat_map.items():
|
| 162 |
+
if old_key in chat_section and new_key not in app_section:
|
| 163 |
+
app_section[new_key] = chat_section[old_key]
|
| 164 |
+
if isinstance(result.get("chat"), dict):
|
| 165 |
+
result["chat"].pop(old_key, None)
|
| 166 |
+
migrated_count += 1
|
| 167 |
+
logger.debug(
|
| 168 |
+
f"Migrated config: chat.{old_key} -> app.{new_key} = {chat_section[old_key]}"
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
+
if migrated_count > 0:
|
| 172 |
+
logger.info(
|
| 173 |
+
f"Migrated {migrated_count} config items from deprecated/legacy sections"
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
return result, deprecated_sections
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def _load_defaults() -> Dict[str, Any]:
    """Load the baseline defaults TOML file; return {} when absent or unreadable."""
    if not DEFAULT_CONFIG_FILE.exists():
        return {}
    try:
        with DEFAULT_CONFIG_FILE.open("rb") as fh:
            return tomllib.load(fh)
    except Exception as e:
        logger.warning(f"Failed to load defaults from {DEFAULT_CONFIG_FILE}: {e}")
        return {}
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
class Config:
    """Configuration manager: merges code defaults, file defaults, and stored config."""

    # NOTE(review): _instance is never assigned in the visible code; looks like a
    # leftover singleton slot — confirm before removing.
    _instance = None
    _config = {}

    def __init__(self):
        self._config = {}
        self._defaults = {}
        self._code_defaults = {}
        self._defaults_loaded = False

    def register_defaults(self, defaults: Dict[str, Any]):
        """Register default values declared in code (lowest priority layer)."""
        self._code_defaults = _deep_merge(self._code_defaults, defaults)

    def _ensure_defaults(self):
        # Lazily merge file defaults over code defaults, once.
        if self._defaults_loaded:
            return
        file_defaults = _load_defaults()
        # File defaults win over code defaults.
        self._defaults = _deep_merge(self._code_defaults, file_defaults)
        self._defaults_loaded = True

    async def load(self):
        """Explicitly load configuration from the storage backend."""
        try:
            from app.core.storage import get_storage, LocalStorage

            self._ensure_defaults()

            storage = get_storage()
            config_data = await storage.load_config()
            from_remote = True

            # Seed the backend from local data/config.toml when remote is empty.
            if config_data is None:
                local_storage = LocalStorage()
                from_remote = False
                try:
                    # Try to read the local config file.
                    config_data = await local_storage.load_config()
                except Exception as e:
                    logger.info(f"Failed to auto-init config from local: {e}")
                    config_data = {}

            config_data = config_data or {}

            # Detect and migrate deprecated config sections.
            valid_sections = set(self._defaults.keys())
            config_data, deprecated_sections = _migrate_deprecated_config(
                config_data, valid_sections
            )
            if deprecated_sections:
                logger.info(
                    f"Cleaned deprecated config sections: {deprecated_sections}"
                )

            merged = _deep_merge(self._defaults, config_data)

            # Backfill missing keys to storage, or persist after a migration.
            # Guard: when remote returned None and local has nothing to migrate,
            # do NOT overwrite remote config, to avoid an accidental reset.
            has_local_seed = bool(config_data)
            allow_bootstrap_empty_remote = (
                (not from_remote) and has_local_seed
            )
            should_persist = (
                allow_bootstrap_empty_remote
                or (merged != config_data and bool(config_data))
                or deprecated_sections
            )
            if should_persist:
                async with storage.acquire_lock("config_save", timeout=10):
                    await storage.save_config(merged)
                if not from_remote and has_local_seed:
                    logger.info(
                        f"Initialized remote storage ({storage.__class__.__name__}) with config baseline."
                    )
                if deprecated_sections:
                    logger.info("Configuration automatically migrated and cleaned.")
            elif not from_remote and not has_local_seed:
                logger.warning(
                    "Skip persisting defaults: empty config source detected, keep runtime merged config only."
                )

            self._config = merged
        except Exception as e:
            # Fail soft: run with an empty config rather than crash at startup.
            logger.error(f"Error loading config: {e}")
            self._config = {}

    def get(self, key: str, default: Any = None) -> Any:
        """
        Read a configuration value.

        Args:
            key: config key, either a bare section name or dotted "section.key"
            default: value returned when the key is absent
        """
        if "." in key:
            try:
                section, attr = key.split(".", 1)
                return self._config.get(section, {}).get(attr, default)
            except (ValueError, AttributeError):
                return default

        return self._config.get(key, default)

    async def update(self, new_config: dict):
        """Merge *new_config* over defaults + current config and persist it."""
        from app.core.storage import get_storage

        storage = get_storage()
        async with storage.acquire_lock("config_save", timeout=10):
            self._ensure_defaults()
            base = _deep_merge(self._defaults, self._config or {})
            merged = _deep_merge(base, new_config or {})
            await storage.save_config(merged)
            self._config = merged
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
# Global configuration singleton used throughout the app.
config = Config()
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def get_config(key: str, default: Any = None) -> Any:
    """Read a value from the global config (dotted "section.key" supported)."""
    return config.get(key, default)
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def register_defaults(defaults: Dict[str, Any]):
    """Register code-level default values on the global config instance."""
    config.register_defaults(defaults)
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
# Public API of this module.
__all__ = ["Config", "config", "get_config", "register_defaults"]
|
app/core/exceptions.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
全局异常处理 - OpenAI 兼容错误格式
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from typing import Any
|
| 6 |
+
from enum import Enum
|
| 7 |
+
from fastapi import Request, HTTPException
|
| 8 |
+
from fastapi.responses import JSONResponse
|
| 9 |
+
from fastapi.exceptions import RequestValidationError
|
| 10 |
+
|
| 11 |
+
from app.core.logger import logger
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# ============= 错误类型 =============
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class ErrorType(str, Enum):
    """OpenAI-compatible error type strings (serialized verbatim in responses)."""

    INVALID_REQUEST = "invalid_request_error"
    AUTHENTICATION = "authentication_error"
    PERMISSION = "permission_error"
    NOT_FOUND = "not_found_error"
    RATE_LIMIT = "rate_limit_error"
    SERVER = "server_error"
    SERVICE_UNAVAILABLE = "service_unavailable_error"
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# ============= 辅助函数 =============
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def error_response(
    message: str,
    error_type: str = ErrorType.INVALID_REQUEST.value,
    param: str = None,
    code: str = None,
) -> dict:
    """Build an OpenAI-compatible error payload."""
    payload = {
        "message": message,
        "type": error_type,
        "param": param,
        "code": code,
    }
    return {"error": payload}
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# ============= 异常类 =============
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class AppException(Exception):
    """Base application exception carrying OpenAI-style error metadata."""

    def __init__(
        self,
        message: str,
        error_type: str = ErrorType.SERVER.value,
        code: str = None,
        param: str = None,
        status_code: int = 500,
    ):
        super().__init__(message)
        self.message = message
        self.error_type = error_type
        self.code = code
        self.param = param
        self.status_code = status_code
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class ValidationException(AppException):
    """Request validation failure (HTTP 400)."""

    def __init__(self, message: str, param: str = None, code: str = None):
        effective_code = code if code else "invalid_value"
        super().__init__(
            message=message,
            error_type=ErrorType.INVALID_REQUEST.value,
            code=effective_code,
            param=param,
            status_code=400,
        )
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class AuthenticationException(AppException):
    """Authentication failure (HTTP 401)."""

    def __init__(self, message: str = "Invalid API key"):
        super().__init__(
            message=message,
            error_type=ErrorType.AUTHENTICATION.value,
            code="invalid_api_key",
            status_code=401,
        )
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class UpstreamException(AppException):
    """Upstream service failure (HTTP 502); optional details kept on the instance."""

    def __init__(self, message: str, details: Any = None):
        super().__init__(
            message=message,
            error_type=ErrorType.SERVER.value,
            code="upstream_error",
            status_code=502,
        )
        self.details = details
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class StreamIdleTimeoutError(Exception):
    """Raised when a stream produces no data for *idle_seconds* seconds."""

    def __init__(self, idle_seconds: float):
        super().__init__(f"Stream idle timeout after {idle_seconds}s")
        self.idle_seconds = idle_seconds
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
# ============= 异常处理器 =============
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
async def app_exception_handler(request: Request, exc: AppException) -> JSONResponse:
    """Render an AppException in the OpenAI error format with its own status code."""
    logger.warning(f"AppException: {exc.error_type} - {exc.message}")

    return JSONResponse(
        status_code=exc.status_code,
        content=error_response(
            message=exc.message,
            error_type=exc.error_type,
            param=exc.param,
            code=exc.code,
        ),
    )
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
    """Render an HTTPException as an OpenAI-style error body."""
    # status code -> OpenAI error type
    status_to_type = {
        400: ErrorType.INVALID_REQUEST.value,
        401: ErrorType.AUTHENTICATION.value,
        403: ErrorType.PERMISSION.value,
        404: ErrorType.NOT_FOUND.value,
        429: ErrorType.RATE_LIMIT.value,
    }
    # status code -> default machine-readable error code
    status_to_code = {
        401: "invalid_api_key",
        403: "insufficient_quota",
        404: "model_not_found",
        429: "rate_limit_exceeded",
    }
    err_type = status_to_type.get(exc.status_code, ErrorType.SERVER.value)
    err_code = status_to_code.get(exc.status_code, None)

    logger.warning(f"HTTPException: {exc.status_code} - {exc.detail}")

    body = error_response(message=str(exc.detail), error_type=err_type, code=err_code)
    return JSONResponse(status_code=exc.status_code, content=body)
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
async def validation_exception_handler(
    request: Request, exc: RequestValidationError
) -> JSONResponse:
    """Translate FastAPI/pydantic validation errors into the OpenAI error format."""
    errors = exc.errors()

    if errors:
        # Only the first validation error is surfaced to the client.
        first = errors[0]
        loc = first.get("loc", [])
        msg = first.get("msg", "Invalid request")
        code = first.get("type", "invalid_value")

        # JSON parse errors get a friendlier, fixed message.
        if code == "json_invalid" or "JSON" in msg:
            message = "Invalid JSON in request body. Please check for trailing commas or syntax errors."
            param = "body"
        else:
            # Drop numeric components (array indices) from the param path.
            param_parts = [
                str(x) for x in loc if not (isinstance(x, int) or str(x).isdigit())
            ]
            param = ".".join(param_parts) if param_parts else None
            message = msg
    else:
        param, message, code = None, "Invalid request", "invalid_value"

    logger.warning(f"ValidationError: {param} - {message}")

    return JSONResponse(
        status_code=400,
        content=error_response(
            message=message,
            error_type=ErrorType.INVALID_REQUEST.value,
            param=param,
            code=code,
        ),
    )
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
async def generic_exception_handler(request: Request, exc: Exception) -> JSONResponse:
    """Last-resort handler: log the full traceback and return a generic 500."""
    logger.exception(f"Unhandled: {type(exc).__name__}: {str(exc)}")

    content = error_response(
        message="Internal server error",
        error_type=ErrorType.SERVER.value,
        code="internal_error",
    )
    return JSONResponse(status_code=500, content=content)
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
# ============= 注册 =============
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def register_exception_handlers(app):
    """Attach all custom exception handlers to the FastAPI application."""
    handler_map = (
        (AppException, app_exception_handler),
        (HTTPException, http_exception_handler),
        (RequestValidationError, validation_exception_handler),
        (Exception, generic_exception_handler),
    )
    for exc_type, handler in handler_map:
        app.add_exception_handler(exc_type, handler)
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
# Public API of this module.
__all__ = [
    "ErrorType",
    "AppException",
    "ValidationException",
    "AuthenticationException",
    "UpstreamException",
    "StreamIdleTimeoutError",
    "error_response",
    "register_exception_handlers",
]
|
app/core/logger.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
结构化 JSON 日志 - 极简格式
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import sys
|
| 6 |
+
import os
|
| 7 |
+
import json
|
| 8 |
+
import traceback
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from loguru import logger
|
| 11 |
+
|
| 12 |
+
# Provide logging.Logger compatibility for legacy calls
# (loguru's logger lacks isEnabledFor; stub it to always report enabled).
if not hasattr(logger, "isEnabledFor"):
    logger.isEnabledFor = lambda _level: True

# Log directory (overridable via the LOG_DIR environment variable).
DEFAULT_LOG_DIR = Path(__file__).parent.parent.parent / "logs"
LOG_DIR = Path(os.getenv("LOG_DIR", str(DEFAULT_LOG_DIR)))
# Set lazily by _prepare_log_dir() once the directory has been created.
_LOG_DIR_READY = False
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def _prepare_log_dir() -> bool:
    """Ensure the log directory exists; result is cached in _LOG_DIR_READY."""
    global LOG_DIR, _LOG_DIR_READY
    if _LOG_DIR_READY:
        return True
    try:
        LOG_DIR.mkdir(parents=True, exist_ok=True)
    except Exception:
        # Unwritable location (read-only fs, permissions, ...).
        _LOG_DIR_READY = False
        return False
    _LOG_DIR_READY = True
    return True
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _format_json(record) -> str:
    """Render a loguru record as a single-line JSON string."""
    # ISO-8601 timestamp with millisecond precision.
    ts = record["time"]
    stamp = ts.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
    offset = ts.strftime("%z")
    if offset:
        # Turn "+0800" into "+08:00".
        stamp = f"{stamp}{offset[:3]}:{offset[3:]}"

    entry = {
        "time": stamp,
        "level": record["level"].name.lower(),
        "msg": record["message"],
        "caller": f"{record['file'].name}:{record['line']}",
    }

    extra = record["extra"]
    # Trace context first, when present.
    for ctx_key in ("traceID", "spanID"):
        if extra.get(ctx_key):
            entry[ctx_key] = extra[ctx_key]

    # Remaining extra fields, excluding private (underscore-prefixed) ones.
    for key, value in extra.items():
        if key in ("traceID", "spanID") or key.startswith("_"):
            continue
        entry[key] = value

    # Attach a stack trace for error-level (>= 40) records with an exception.
    if record["level"].no >= 40 and record["exception"]:
        exc = record["exception"]
        entry["stacktrace"] = "".join(
            traceback.format_exception(exc.type, exc.value, exc.traceback)
        )

    return json.dumps(entry, ensure_ascii=False)
|
| 74 |
+
|
| 75 |
+
def _env_flag(name: str, default: bool) -> bool:
    """Interpret environment variable *name* as a boolean, else *default*."""
    value = os.getenv(name)
    if value is None:
        return default
    normalized = value.strip().lower()
    return normalized in ("1", "true", "yes", "on", "y")
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def _make_json_sink(output):
    """Build a loguru sink that writes one JSON line per record to *output*."""

    def sink(message):
        print(_format_json(message.record), file=output, flush=True)

    return sink
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def _file_json_sink(message):
    """Append the formatted record to the per-day log file."""
    record = message.record
    day = record["time"].strftime("%Y-%m-%d")
    target = LOG_DIR / f"app_{day}.log"
    with open(target, "a", encoding="utf-8") as handle:
        handle.write(_format_json(record) + "\n")
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def setup_logging(
    level: str = "DEBUG",
    json_console: bool = True,
    file_logging: bool = True,
):
    """Reset and configure loguru sinks.

    Adds either a JSON or a colorized console sink, plus an optional
    per-day JSON file sink. The LOG_FILE_ENABLED environment variable
    overrides the *file_logging* argument.
    """
    logger.remove()
    file_logging = _env_flag("LOG_FILE_ENABLED", file_logging)

    # Console sink.
    if json_console:
        logger.add(
            _make_json_sink(sys.stdout),
            level=level,
            format="{message}",
            colorize=False,
        )
    else:
        pretty = (
            "<green>{time:YYYY-MM-DD HH:mm:ss}</green> | "
            "<level>{level: <8}</level> | "
            "<cyan>{file.name}:{line}</cyan> - <level>{message}</level>"
        )
        logger.add(sys.stdout, level=level, format=pretty, colorize=True)

    # File sink (only if the log directory is writable).
    if not file_logging:
        return logger
    if _prepare_log_dir():
        logger.add(
            _file_json_sink,
            level=level,
            format="{message}",
            enqueue=True,
        )
    else:
        logger.warning("File logging disabled: no writable log directory.")

    return logger
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def get_logger(trace_id: str = "", span_id: str = ""):
    """Return a logger bound with trace context; bare logger when none given."""
    context = {
        key: value
        for key, value in (("traceID", trace_id), ("spanID", span_id))
        if value
    }
    if not context:
        return logger
    return logger.bind(**context)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
__all__ = ["logger", "setup_logging", "get_logger", "LOG_DIR"]
|
app/core/response_middleware.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
响应中间件
|
| 3 |
+
Response Middleware
|
| 4 |
+
|
| 5 |
+
用于记录请求日志、生成 TraceID 和计算请求耗时
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import time
|
| 9 |
+
import uuid
|
| 10 |
+
from starlette.middleware.base import BaseHTTPMiddleware
|
| 11 |
+
from starlette.requests import Request
|
| 12 |
+
|
| 13 |
+
from app.core.logger import logger
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class ResponseLoggerMiddleware(BaseHTTPMiddleware):
    """
    Request logging and response tracking middleware.

    Generates a per-request trace ID, stores it on ``request.state``, logs
    request/response metadata with timing, and skips static assets and the
    HTML page routes to keep access logs quiet.
    """

    # HTML page routes that should not produce access-log entries.
    _SKIP_PATHS = frozenset(
        (
            "/",
            "/login",
            "/imagine",
            "/voice",
            "/admin",
            "/admin/login",
            "/admin/config",
            "/admin/cache",
            "/admin/token",
        )
    )

    async def dispatch(self, request: Request, call_next):
        # Generate a request-scoped trace ID.
        trace_id = str(uuid.uuid4())
        request.state.trace_id = trace_id

        path = request.url.path
        if path.startswith("/static/") or path in self._SKIP_PATHS:
            return await call_next(request)

        # monotonic clock: immune to system-clock adjustments during timing.
        start_time = time.monotonic()

        # Bug fix: loguru has no stdlib-style `extra=` keyword — the previous
        # `logger.info(msg, extra={...})` never placed traceID/method/path into
        # record["extra"], so the JSON formatter silently dropped them. Fields
        # must be attached via bind() to reach record["extra"].
        log = logger.bind(traceID=trace_id, method=request.method, path=path)

        log.info(f"Request: {request.method} {path}")

        try:
            response = await call_next(request)
        except Exception as e:
            duration = (time.monotonic() - start_time) * 1000
            log.bind(duration_ms=round(duration, 2), error=str(e)).error(
                f"Response Error: {request.method} {path} - {str(e)} ({duration:.2f}ms)"
            )
            # Bare raise preserves the original traceback.
            raise

        duration = (time.monotonic() - start_time) * 1000
        log.bind(status=response.status_code, duration_ms=round(duration, 2)).info(
            f"Response: {request.method} {path} - {response.status_code} ({duration:.2f}ms)"
        )
        return response
|
app/core/storage.py
ADDED
|
@@ -0,0 +1,1478 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
统一存储服务 (Professional Storage Service)
|
| 3 |
+
支持 Local (TOML), Redis, MySQL, PostgreSQL
|
| 4 |
+
|
| 5 |
+
特性:
|
| 6 |
+
- 全异步 I/O (Async I/O)
|
| 7 |
+
- 连接池管理 (Connection Pooling)
|
| 8 |
+
- 分布式/本地锁 (Distributed/Local Locking)
|
| 9 |
+
- 内存优化 (序列化性能优化)
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import abc
|
| 13 |
+
import os
|
| 14 |
+
import asyncio
|
| 15 |
+
import hashlib
|
| 16 |
+
import time
|
| 17 |
+
import tomllib
|
| 18 |
+
from typing import Any, ClassVar, Dict, Optional
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
from enum import Enum
|
| 21 |
+
|
| 22 |
+
try:
|
| 23 |
+
import fcntl
|
| 24 |
+
except ImportError: # pragma: no cover - non-posix platforms
|
| 25 |
+
fcntl = None
|
| 26 |
+
from contextlib import asynccontextmanager
|
| 27 |
+
|
| 28 |
+
import orjson
|
| 29 |
+
import aiofiles
|
| 30 |
+
from app.core.logger import logger
|
| 31 |
+
|
| 32 |
+
# 数据目录(支持通过环境变量覆盖)
|
| 33 |
+
DEFAULT_DATA_DIR = Path(__file__).parent.parent.parent / "data"
|
| 34 |
+
DATA_DIR = Path(os.getenv("DATA_DIR", str(DEFAULT_DATA_DIR))).expanduser()
|
| 35 |
+
|
| 36 |
+
# 配置文件路径
|
| 37 |
+
CONFIG_FILE = DATA_DIR / "config.toml"
|
| 38 |
+
TOKEN_FILE = DATA_DIR / "token.json"
|
| 39 |
+
LOCK_DIR = DATA_DIR / ".locks"
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# JSON serialization helpers built on orjson for speed.
def json_dumps(obj: Any) -> str:
    """Serialize *obj* to a compact UTF-8 JSON string."""
    raw = orjson.dumps(obj)
    return raw.decode("utf-8")
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def json_loads(obj: str | bytes) -> Any:
    """Parse a JSON document from a string or bytes payload."""
    parsed = orjson.loads(obj)
    return parsed
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def json_dumps_sorted(obj: Any) -> str:
    """Serialize *obj* to JSON with keys emitted in sorted order."""
    raw = orjson.dumps(obj, option=orjson.OPT_SORT_KEYS)
    return raw.decode("utf-8")
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class StorageError(Exception):
    """Base exception for storage-service failures."""
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class BaseStorage(abc.ABC):
    """Abstract interface shared by all storage backends."""

    @abc.abstractmethod
    async def load_config(self) -> Dict[str, Any]:
        """Load the configuration mapping."""

    @abc.abstractmethod
    async def save_config(self, data: Dict[str, Any]):
        """Persist the configuration mapping."""

    @abc.abstractmethod
    async def load_tokens(self) -> Dict[str, Any]:
        """Load every token pool."""

    @abc.abstractmethod
    async def save_tokens(self, data: Dict[str, Any]):
        """Persist every token pool."""

    @staticmethod
    def _entry_token(entry) -> Optional[str]:
        """Extract the token string from a pool entry (plain str or dict)."""
        if isinstance(entry, str):
            return entry
        if isinstance(entry, dict):
            return entry.get("token")
        return None

    async def save_tokens_delta(
        self, updated: list[Dict[str, Any]], deleted: Optional[list[str]] = None
    ):
        """Apply an incremental token update (falls back to a full save)."""
        pools = await self.load_tokens() or {}

        # Phase 1: drop every entry whose token appears in *deleted*.
        removals = set(deleted or [])
        if removals:
            for pool_name, tokens in list(pools.items()):
                if not isinstance(tokens, list):
                    continue
                kept = []
                for entry in tokens:
                    token_value = self._entry_token(entry)
                    if token_value and token_value in removals:
                        continue
                    kept.append(entry)
                pools[pool_name] = kept

        # Phase 2: upsert each updated record into its pool.
        for record in updated or []:
            if not isinstance(record, dict):
                continue
            pool_name = record.get("pool_name")
            token_value = record.get("token")
            if not pool_name or not token_value:
                continue
            bucket = pools.setdefault(pool_name, [])
            # Strip bookkeeping fields before persisting.
            payload = {
                key: val
                for key, val in record.items()
                if key not in ("pool_name", "_update_kind")
            }
            for idx, current in enumerate(bucket):
                if self._entry_token(current) == token_value:
                    bucket[idx] = payload
                    break
            else:
                bucket.append(payload)

        await self.save_tokens(pools)

    @abc.abstractmethod
    async def close(self):
        """Release backend resources."""

    @asynccontextmanager
    async def acquire_lock(self, name: str, timeout: int = 10):
        """
        Acquire a mutual-exclusion lock for a critical section.

        Args:
            name: lock identifier
            timeout: seconds to wait before giving up

        The base implementation is a no-op fallback; backends override it.
        """
        yield

    async def verify_connection(self) -> bool:
        """Health check; backends may override with a real probe."""
        return True
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
class LocalStorage(BaseStorage):
    """
    Local file-based storage.

    - Uses aiofiles for async I/O.
    - Uses asyncio.Lock for in-process concurrency control.
    - For multi-process safety an OS-level file lock (fcntl) is used
      when available.
    """

    def __init__(self):
        # In-process mutex guarding file access within this event loop.
        self._lock = asyncio.Lock()

    @asynccontextmanager
    async def acquire_lock(self, name: str, timeout: int = 10):
        # Fallback path: no fcntl (e.g. Windows) -> in-process lock only.
        # NOTE: asyncio.timeout requires Python 3.11+.
        if fcntl is None:
            try:
                async with asyncio.timeout(timeout):
                    async with self._lock:
                        yield
            except asyncio.TimeoutError:
                logger.warning(f"LocalStorage: 获取锁 '{name}' 超时 ({timeout}s)")
                raise StorageError(f"无法获取锁 '{name}'")
            return

        # POSIX path: per-name lock file plus an fcntl advisory lock.
        lock_path = LOCK_DIR / f"{name}.lock"
        lock_path.parent.mkdir(parents=True, exist_ok=True)
        fd = None
        locked = False
        start = time.monotonic()

        # The asyncio.Lock is held for the whole critical section, so only
        # one coroutine per process ever polls for the file lock.
        async with self._lock:
            try:
                # NOTE(review): open() is a blocking call on the event loop;
                # presumably acceptable for tiny lock files — confirm.
                fd = open(lock_path, "a+")
                while True:
                    try:
                        # Non-blocking exclusive lock, polled every 50 ms.
                        fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
                        locked = True
                        break
                    except BlockingIOError:
                        if time.monotonic() - start >= timeout:
                            raise StorageError(f"无法获取锁 '{name}'")
                        await asyncio.sleep(0.05)
                yield
            except StorageError:
                logger.warning(f"LocalStorage: 获取锁 '{name}' 超时 ({timeout}s)")
                raise
            finally:
                # Best-effort unlock and close; failures here are ignored.
                if fd:
                    if locked:
                        try:
                            fcntl.flock(fd, fcntl.LOCK_UN)
                        except Exception:
                            pass
                    try:
                        fd.close()
                    except Exception:
                        pass

    async def load_config(self) -> Dict[str, Any]:
        """Read CONFIG_FILE as TOML; returns {} when missing or unreadable."""
        if not CONFIG_FILE.exists():
            return {}
        try:
            async with aiofiles.open(CONFIG_FILE, "rb") as f:
                content = await f.read()
                return tomllib.loads(content.decode("utf-8"))
        except Exception as e:
            logger.error(f"LocalStorage: 加载配置失败: {e}")
            return {}

    async def save_config(self, data: Dict[str, Any]):
        """Serialize *data* ({section: {key: value}}) to TOML and write it.

        Raises:
            StorageError: when serialization or the write fails.
        """
        try:
            lines = []
            for section, items in data.items():
                # Non-dict sections cannot be represented as TOML tables.
                if not isinstance(items, dict):
                    continue
                lines.append(f"[{section}]")
                for key, val in items.items():
                    if isinstance(val, bool):
                        val_str = "true" if val else "false"
                    elif isinstance(val, str):
                        # NOTE(review): only double quotes are escaped here;
                        # a value containing backslashes or newlines would
                        # produce invalid TOML — confirm values stay simple.
                        escaped = val.replace('"', '\\"')
                        val_str = f'"{escaped}"'
                    elif isinstance(val, (int, float)):
                        val_str = str(val)
                    elif isinstance(val, (list, dict)):
                        # Lists/dicts are emitted as JSON text.
                        val_str = json_dumps(val)
                    else:
                        val_str = f'"{str(val)}"'
                    lines.append(f"{key} = {val_str}")
                lines.append("")

            content = "\n".join(lines)

            CONFIG_FILE.parent.mkdir(parents=True, exist_ok=True)
            async with aiofiles.open(CONFIG_FILE, "w", encoding="utf-8") as f:
                await f.write(content)
        except Exception as e:
            logger.error(f"LocalStorage: 保存配置失败: {e}")
            raise StorageError(f"保存配置失败: {e}")

    async def load_tokens(self) -> Dict[str, Any]:
        """Read TOKEN_FILE as JSON; returns {} when missing or unreadable."""
        if not TOKEN_FILE.exists():
            return {}
        try:
            async with aiofiles.open(TOKEN_FILE, "rb") as f:
                content = await f.read()
                return json_loads(content)
        except Exception as e:
            logger.error(f"LocalStorage: 加载 Token 失败: {e}")
            return {}

    async def save_tokens(self, data: Dict[str, Any]):
        """Atomically persist the token pools as indented JSON.

        Raises:
            StorageError: when the write or the rename fails.
        """
        try:
            TOKEN_FILE.parent.mkdir(parents=True, exist_ok=True)
            temp_path = TOKEN_FILE.with_suffix(".tmp")

            # Atomic write: dump to a temp file, then rename over the target.
            async with aiofiles.open(temp_path, "wb") as f:
                await f.write(orjson.dumps(data, option=orjson.OPT_INDENT_2))

            # os.replace guarantees an atomic swap (same filesystem).
            os.replace(temp_path, TOKEN_FILE)

        except Exception as e:
            logger.error(f"LocalStorage: 保存 Token 失败: {e}")
            raise StorageError(f"保存 Token 失败: {e}")

    async def close(self):
        # No persistent resources to release for local files.
        pass
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
class RedisStorage(BaseStorage):
|
| 291 |
+
"""
|
| 292 |
+
Redis 存储
|
| 293 |
+
- 使用 redis-py 异步客户端 (自带连接池)
|
| 294 |
+
- 支持分布式锁 (redis.lock)
|
| 295 |
+
- 扁平化数据结构优化性能
|
| 296 |
+
"""
|
| 297 |
+
|
| 298 |
+
def __init__(self, url: str):
|
| 299 |
+
try:
|
| 300 |
+
from redis import asyncio as aioredis
|
| 301 |
+
except ImportError:
|
| 302 |
+
raise ImportError("需要安装 redis 包: pip install redis")
|
| 303 |
+
|
| 304 |
+
# 显式配置连接池
|
| 305 |
+
# 使用 decode_responses=True 简化字符串处理,但在处理复杂对象时使用 orjson
|
| 306 |
+
self.redis = aioredis.from_url(
|
| 307 |
+
url, decode_responses=True, health_check_interval=30
|
| 308 |
+
)
|
| 309 |
+
self.config_key = "grok2api:config" # Hash: section.key -> value_json
|
| 310 |
+
self.key_pools = "grok2api:pools" # Set: pool_names
|
| 311 |
+
self.prefix_pool_set = "grok2api:pool:" # Set: pool -> token_ids
|
| 312 |
+
self.prefix_token_hash = "grok2api:token:" # Hash: token_id -> token_data
|
| 313 |
+
self.lock_prefix = "grok2api:lock:"
|
| 314 |
+
|
| 315 |
+
@asynccontextmanager
|
| 316 |
+
async def acquire_lock(self, name: str, timeout: int = 10):
|
| 317 |
+
# 使用 Redis 分布式锁
|
| 318 |
+
lock_key = f"{self.lock_prefix}{name}"
|
| 319 |
+
lock = self.redis.lock(lock_key, timeout=timeout, blocking_timeout=5)
|
| 320 |
+
acquired = False
|
| 321 |
+
try:
|
| 322 |
+
acquired = await lock.acquire()
|
| 323 |
+
if not acquired:
|
| 324 |
+
raise StorageError(f"RedisStorage: 无法获取锁 '{name}'")
|
| 325 |
+
yield
|
| 326 |
+
finally:
|
| 327 |
+
if acquired:
|
| 328 |
+
try:
|
| 329 |
+
await lock.release()
|
| 330 |
+
except Exception:
|
| 331 |
+
# 锁可能已过期或被意外释放,忽略异常
|
| 332 |
+
pass
|
| 333 |
+
|
| 334 |
+
async def verify_connection(self) -> bool:
|
| 335 |
+
try:
|
| 336 |
+
return await self.redis.ping()
|
| 337 |
+
except Exception:
|
| 338 |
+
return False
|
| 339 |
+
|
| 340 |
+
async def load_config(self) -> Dict[str, Any]:
|
| 341 |
+
"""从 Redis Hash 加载配置"""
|
| 342 |
+
try:
|
| 343 |
+
raw_data = await self.redis.hgetall(self.config_key)
|
| 344 |
+
if not raw_data:
|
| 345 |
+
return None
|
| 346 |
+
|
| 347 |
+
config = {}
|
| 348 |
+
for composite_key, val_str in raw_data.items():
|
| 349 |
+
if "." not in composite_key:
|
| 350 |
+
continue
|
| 351 |
+
section, key = composite_key.split(".", 1)
|
| 352 |
+
|
| 353 |
+
if section not in config:
|
| 354 |
+
config[section] = {}
|
| 355 |
+
|
| 356 |
+
try:
|
| 357 |
+
val = json_loads(val_str)
|
| 358 |
+
except Exception:
|
| 359 |
+
val = val_str
|
| 360 |
+
config[section][key] = val
|
| 361 |
+
return config
|
| 362 |
+
except Exception as e:
|
| 363 |
+
logger.error(f"RedisStorage: 加载配置失败: {e}")
|
| 364 |
+
return None
|
| 365 |
+
|
| 366 |
+
async def save_config(self, data: Dict[str, Any]):
|
| 367 |
+
"""保存配置到 Redis Hash"""
|
| 368 |
+
try:
|
| 369 |
+
mapping = {}
|
| 370 |
+
for section, items in data.items():
|
| 371 |
+
if not isinstance(items, dict):
|
| 372 |
+
continue
|
| 373 |
+
for key, val in items.items():
|
| 374 |
+
composite_key = f"{section}.{key}"
|
| 375 |
+
mapping[composite_key] = json_dumps(val)
|
| 376 |
+
|
| 377 |
+
await self.redis.delete(self.config_key)
|
| 378 |
+
if mapping:
|
| 379 |
+
await self.redis.hset(self.config_key, mapping=mapping)
|
| 380 |
+
except Exception as e:
|
| 381 |
+
logger.error(f"RedisStorage: 保存配置失败: {e}")
|
| 382 |
+
raise
|
| 383 |
+
|
| 384 |
+
    async def load_tokens(self) -> Optional[Dict[str, Any]]:
        """Load every token, grouped by pool name.

        Returns ``{pool_name: [token_dict, ...]}``, or ``None`` when no pools
        exist or on any Redis error (callers treat ``None`` as "no data").
        Uses two pipelined round-trips: one to fetch the token-id sets of all
        pools, one to fetch every token hash in bulk.
        """
        try:
            pool_names = await self.redis.smembers(self.key_pools)
            if not pool_names:
                return None

            pools = {}
            async with self.redis.pipeline() as pipe:
                for pool_name in pool_names:
                    # Queue fetch of all token IDs belonging to this pool.
                    pipe.smembers(f"{self.prefix_pool_set}{pool_name}")
                pool_tokens_res = await pipe.execute()

            # Collect every token ID so the hashes can be fetched in one batch.
            all_token_ids = []
            pool_map = {}  # pool_name -> list[token_id]

            for i, pool_name in enumerate(pool_names):
                tids = list(pool_tokens_res[i])
                pool_map[pool_name] = tids
                all_token_ids.extend(tids)

            if not all_token_ids:
                return {name: [] for name in pool_names}

            # Bulk-fetch the per-token detail hashes.
            async with self.redis.pipeline() as pipe:
                for tid in all_token_ids:
                    pipe.hgetall(f"{self.prefix_token_hash}{tid}")
                token_data_list = await pipe.execute()

            # Rebuild typed token dicts from the flat Redis hashes.
            token_lookup = {}
            for i, tid in enumerate(all_token_ids):
                t_data = token_data_list[i]
                if not t_data:
                    # Hash missing/expired: skip this ID.
                    continue

                # Restore tags (stored as a JSON string) back into a list.
                if "tags" in t_data:
                    try:
                        t_data["tags"] = json_loads(t_data["tags"])
                    except Exception:
                        t_data["tags"] = []

                # Type coercion: Redis returns every field as a string.
                # "None" is compared as a literal because save_tokens stringifies values.
                for int_field in [
                    "quota",
                    "created_at",
                    "use_count",
                    "fail_count",
                    "last_used_at",
                    "last_fail_at",
                    "last_sync_at",
                ]:
                    if t_data.get(int_field) and t_data[int_field] != "None":
                        try:
                            t_data[int_field] = int(t_data[int_field])
                        except Exception:
                            pass

                token_lookup[tid] = t_data

            # Group the hydrated tokens back under their pools.
            for pool_name in pool_names:
                pools[pool_name] = []
                for tid in pool_map[pool_name]:
                    if tid in token_lookup:
                        pools[pool_name].append(token_lookup[tid])

            return pools

        except Exception as e:
            logger.error(f"RedisStorage: 加载 Token 失败: {e}")
            return None
| 461 |
+
    async def save_tokens(self, data: Dict[str, Any]):
        """Persist the complete token snapshot to Redis (full, destructive sync).

        ``data`` is ``{pool_name: [token_dict, ...]}``.  Pool index and pool
        sets are rebuilt from scratch; token hashes that no longer appear in
        ``data`` are deleted.  All mutations run in one pipeline.
        Raises on Redis failure after logging.
        """
        if data is None:
            return
        try:
            new_pools = set(data.keys()) if isinstance(data, dict) else set()
            pool_tokens_map = {}
            new_token_ids = set()

            for pool_name, tokens in (data or {}).items():
                tids_in_pool = []
                for t in tokens:
                    token_str = t.get("token")
                    if not token_str:
                        continue
                    tids_in_pool.append(token_str)
                    new_token_ids.add(token_str)
                pool_tokens_map[pool_name] = tids_in_pool

            existing_pools = await self.redis.smembers(self.key_pools)
            existing_pools = set(existing_pools) if existing_pools else set()

            # Gather the IDs currently stored so stale hashes can be removed.
            existing_token_ids = set()
            if existing_pools:
                async with self.redis.pipeline() as pipe:
                    for pool_name in existing_pools:
                        pipe.smembers(f"{self.prefix_pool_set}{pool_name}")
                    pool_tokens_res = await pipe.execute()
                for tokens in pool_tokens_res:
                    existing_token_ids.update(list(tokens or []))

            tokens_to_delete = existing_token_ids - new_token_ids
            all_pools = existing_pools.union(new_pools)

            async with self.redis.pipeline() as pipe:
                # Reset pool index
                pipe.delete(self.key_pools)
                if new_pools:
                    pipe.sadd(self.key_pools, *new_pools)

                # Reset pool sets
                for pool_name in all_pools:
                    pipe.delete(f"{self.prefix_pool_set}{pool_name}")
                for pool_name, tids_in_pool in pool_tokens_map.items():
                    if tids_in_pool:
                        pipe.sadd(f"{self.prefix_pool_set}{pool_name}", *tids_in_pool)

                # Remove deleted token hashes
                for token_str in tokens_to_delete:
                    pipe.delete(f"{self.prefix_token_hash}{token_str}")

                # Upsert token hashes
                for pool_name, tokens in (data or {}).items():
                    for t in tokens:
                        token_str = t.get("token")
                        if not token_str:
                            continue
                        t_flat = t.copy()
                        # Tags must be serialized — Redis hashes hold flat strings.
                        if "tags" in t_flat:
                            t_flat["tags"] = json_dumps(t_flat["tags"])
                        # Normalize status: "TokenStatus.ACTIVE" -> "active",
                        # Enum members -> their .value.
                        status = t_flat.get("status")
                        if isinstance(status, str) and status.startswith(
                            "TokenStatus."
                        ):
                            t_flat["status"] = status.split(".", 1)[1].lower()
                        elif isinstance(status, Enum):
                            t_flat["status"] = status.value
                        # Redis requires string values; drop None fields entirely.
                        t_flat = {k: str(v) for k, v in t_flat.items() if v is not None}
                        pipe.hset(
                            f"{self.prefix_token_hash}{token_str}", mapping=t_flat
                        )

                await pipe.execute()

        except Exception as e:
            logger.error(f"RedisStorage: 保存 Token 失败: {e}")
            raise
| 539 |
+
async def close(self):
|
| 540 |
+
try:
|
| 541 |
+
await self.redis.close()
|
| 542 |
+
except (RuntimeError, asyncio.CancelledError, Exception):
|
| 543 |
+
# 忽略关闭时的 Event loop is closed 错误
|
| 544 |
+
pass
|
| 545 |
+
|
| 546 |
+
|
| 547 |
+
class SQLStorage(BaseStorage):
    """
    SQL database storage backend (MySQL / PostgreSQL).

    - Uses the SQLAlchemy async engine
    - Automatic schema initialization (see ``_ensure_schema``)
    - Built-in connection pooling (QueuePool)
    """

    def __init__(self, url: str, connect_args: dict | None = None):
        # Import lazily so the module loads even without sqlalchemy installed;
        # only SQL-backed deployments need the dependency.
        try:
            from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
        except ImportError:
            raise ImportError(
                "需要安装 sqlalchemy 和 async 驱动: pip install sqlalchemy[asyncio]"
            )

        # Extract the bare dialect, e.g. "mysql+aiomysql://..." -> "mysql".
        self.dialect = url.split(":", 1)[0].split("+", 1)[0].lower()

        # Robust connection-pool configuration: recycle hourly and
        # pre-ping so stale connections are detected before use.
        self.engine = create_async_engine(
            url,
            echo=False,
            pool_size=20,
            max_overflow=10,
            pool_recycle=3600,
            pool_pre_ping=True,
            **({"connect_args": connect_args} if connect_args else {}),
        )
        self.async_session = async_sessionmaker(self.engine, expire_on_commit=False)
        # Schema creation is deferred until first use (see _ensure_schema).
        self._initialized = False
| 578 |
+
    async def _ensure_schema(self):
        """Ensure all required tables/columns/indexes exist (runs once).

        Idempotent: guarded by ``self._initialized``.  Uses dialect-aware DDL
        because MySQL lacks ``IF NOT EXISTS`` for indexes/columns — those
        branches swallow "already exists" errors instead.
        Raises on unrecoverable schema failure after logging.
        """
        if self._initialized:
            return
        try:
            async with self.engine.begin() as conn:
                from sqlalchemy import text

                # Tokens table (portable SQL)
                await conn.execute(
                    text("""
                    CREATE TABLE IF NOT EXISTS tokens (
                        token VARCHAR(512) PRIMARY KEY,
                        pool_name VARCHAR(64) NOT NULL,
                        status VARCHAR(16),
                        quota INT,
                        created_at BIGINT,
                        last_used_at BIGINT,
                        use_count INT,
                        fail_count INT,
                        last_fail_at BIGINT,
                        last_fail_reason TEXT,
                        last_sync_at BIGINT,
                        tags TEXT,
                        note TEXT,
                        last_asset_clear_at BIGINT,
                        data TEXT,
                        data_hash CHAR(64),
                        updated_at BIGINT
                    )
                    """)
                )

                # Config table
                await conn.execute(
                    text("""
                    CREATE TABLE IF NOT EXISTS app_config (
                        section VARCHAR(64) NOT NULL,
                        key_name VARCHAR(64) NOT NULL,
                        value TEXT,
                        PRIMARY KEY (section, key_name)
                    )
                    """)
                )

                # Index on pool_name; MySQL has no IF NOT EXISTS for indexes,
                # so the non-PG branch ignores "duplicate index" errors.
                if self.dialect in ("postgres", "postgresql", "pgsql"):
                    await conn.execute(
                        text(
                            "CREATE INDEX IF NOT EXISTS idx_tokens_pool ON tokens (pool_name)"
                        )
                    )
                else:
                    try:
                        await conn.execute(
                            text("CREATE INDEX idx_tokens_pool ON tokens (pool_name)")
                        )
                    except Exception:
                        pass

                # Backfill columns that older table versions lack.
                columns = [
                    ("status", "VARCHAR(16)"),
                    ("quota", "INT"),
                    ("created_at", "BIGINT"),
                    ("last_used_at", "BIGINT"),
                    ("use_count", "INT"),
                    ("fail_count", "INT"),
                    ("last_fail_at", "BIGINT"),
                    ("last_fail_reason", "TEXT"),
                    ("last_sync_at", "BIGINT"),
                    ("tags", "TEXT"),
                    ("note", "TEXT"),
                    ("last_asset_clear_at", "BIGINT"),
                    ("data", "TEXT"),
                    ("data_hash", "CHAR(64)"),
                    ("updated_at", "BIGINT"),
                ]
                if self.dialect in ("postgres", "postgresql", "pgsql"):
                    for col_name, col_type in columns:
                        await conn.execute(
                            text(
                                f"ALTER TABLE tokens ADD COLUMN IF NOT EXISTS {col_name} {col_type}"
                            )
                        )
                else:
                    for col_name, col_type in columns:
                        try:
                            await conn.execute(
                                text(
                                    f"ALTER TABLE tokens ADD COLUMN {col_name} {col_type}"
                                )
                            )
                        except Exception:
                            pass

                # Best-effort widening of legacy column types; failures
                # (e.g. insufficient privileges) are non-fatal.
                try:
                    if self.dialect in ("mysql", "mariadb"):
                        await conn.execute(
                            text("ALTER TABLE tokens MODIFY token VARCHAR(512)")
                        )
                        await conn.execute(text("ALTER TABLE tokens MODIFY data TEXT"))
                    elif self.dialect in ("postgres", "postgresql", "pgsql"):
                        await conn.execute(
                            text(
                                "ALTER TABLE tokens ALTER COLUMN token TYPE VARCHAR(512)"
                            )
                        )
                        await conn.execute(
                            text("ALTER TABLE tokens ALTER COLUMN data TYPE TEXT")
                        )
                except Exception:
                    pass

            # NOTE(review): migration runs after DDL commits; it opens its
            # own session internally.
            await self._migrate_legacy_tokens()
            self._initialized = True
        except Exception as e:
            logger.error(f"SQLStorage: Schema 初始化失败: {e}")
            raise
| 699 |
+
def _normalize_status(self, status: Any) -> Any:
|
| 700 |
+
if isinstance(status, str) and status.startswith("TokenStatus."):
|
| 701 |
+
return status.split(".", 1)[1].lower()
|
| 702 |
+
if isinstance(status, Enum):
|
| 703 |
+
return status.value
|
| 704 |
+
return status
|
| 705 |
+
|
| 706 |
+
def _normalize_tags(self, tags: Any) -> Optional[str]:
|
| 707 |
+
if tags is None:
|
| 708 |
+
return None
|
| 709 |
+
if isinstance(tags, str):
|
| 710 |
+
try:
|
| 711 |
+
parsed = json_loads(tags)
|
| 712 |
+
if isinstance(parsed, list):
|
| 713 |
+
return tags
|
| 714 |
+
except Exception:
|
| 715 |
+
pass
|
| 716 |
+
return json_dumps([tags])
|
| 717 |
+
return json_dumps(tags)
|
| 718 |
+
|
| 719 |
+
def _parse_tags(self, tags: Any) -> Optional[list]:
|
| 720 |
+
if tags is None:
|
| 721 |
+
return None
|
| 722 |
+
if isinstance(tags, str):
|
| 723 |
+
try:
|
| 724 |
+
parsed = json_loads(tags)
|
| 725 |
+
if isinstance(parsed, list):
|
| 726 |
+
return parsed
|
| 727 |
+
except Exception:
|
| 728 |
+
return []
|
| 729 |
+
if isinstance(tags, list):
|
| 730 |
+
return tags
|
| 731 |
+
return []
|
| 732 |
+
|
| 733 |
+
    def _token_to_row(self, token_data: Dict[str, Any], pool_name: str) -> Dict[str, Any]:
        """Flatten one token dict into a parameter row for the tokens table.

        Strips a leading "sso=" from the token value, normalizes status and
        tags, and stores the full original dict as canonical JSON in ``data``
        together with its SHA-256 in ``data_hash`` (used for change detection).
        """
        token_str = token_data.get("token")
        if isinstance(token_str, str) and token_str.startswith("sso="):
            token_str = token_str[4:]

        status = self._normalize_status(token_data.get("status"))
        tags_json = self._normalize_tags(token_data.get("tags"))
        # Sorted-key JSON so the hash is stable across dict orderings.
        data_json = json_dumps_sorted(token_data)
        data_hash = hashlib.sha256(data_json.encode("utf-8")).hexdigest()
        note = token_data.get("note")
        if note is None:
            note = ""

        return {
            "token": token_str,
            "pool_name": pool_name,
            "status": status,
            "quota": token_data.get("quota"),
            "created_at": token_data.get("created_at"),
            "last_used_at": token_data.get("last_used_at"),
            "use_count": token_data.get("use_count"),
            "fail_count": token_data.get("fail_count"),
            "last_fail_at": token_data.get("last_fail_at"),
            "last_fail_reason": token_data.get("last_fail_reason"),
            "last_sync_at": token_data.get("last_sync_at"),
            "tags": tags_json,
            "note": note,
            "last_asset_clear_at": token_data.get("last_asset_clear_at"),
            "data": data_json,
            "data_hash": data_hash,
            # NOTE(review): updated_at is always written as 0 here —
            # presumably overwritten by callers; confirm before relying on it.
            "updated_at": 0,
        }
| 766 |
+
    async def _migrate_legacy_tokens(self):
        """Backfill flattened columns from the legacy ``data`` JSON blob.

        Older rows stored everything inside ``data``; this probes for rows
        whose flat columns are still NULL and rewrites them via
        ``_token_to_row``.  Failures are logged and swallowed — migration is
        best-effort and must not block startup.
        """
        from sqlalchemy import text

        try:
            async with self.async_session() as session:
                # Cheap existence probe first, so the common (fully migrated)
                # case costs a single LIMIT 1 query.
                try:
                    res = await session.execute(
                        text(
                            "SELECT token FROM tokens "
                            "WHERE data IS NOT NULL AND "
                            "(status IS NULL OR quota IS NULL OR created_at IS NULL) "
                            "LIMIT 1"
                        )
                    )
                    if not res.first():
                        return
                except Exception as e:
                    # A missing column means a pre-flat schema that
                    # _ensure_schema will extend on the next pass; skip.
                    msg = str(e).lower()
                    if "undefinedcolumn" in msg or "undefined column" in msg:
                        return
                    raise

                res = await session.execute(
                    text(
                        "SELECT token, pool_name, data FROM tokens "
                        "WHERE data IS NOT NULL AND "
                        "(status IS NULL OR quota IS NULL OR created_at IS NULL)"
                    )
                )
                rows = res.fetchall()
                if not rows:
                    return

                params = []
                for token_str, pool_name, data_json in rows:
                    if not data_json:
                        continue
                    try:
                        if isinstance(data_json, str):
                            t_data = json_loads(data_json)
                        else:
                            t_data = data_json
                        if not isinstance(t_data, dict):
                            continue
                        t_data = dict(t_data)
                        # The primary key wins over whatever the blob says.
                        t_data["token"] = token_str
                        row = self._token_to_row(t_data, pool_name)
                        params.append(row)
                    except Exception:
                        # One corrupt row must not abort the whole migration.
                        continue

                if not params:
                    return

                # executemany-style bulk UPDATE keyed on token.
                await session.execute(
                    text(
                        "UPDATE tokens SET "
                        "pool_name=:pool_name, "
                        "status=:status, "
                        "quota=:quota, "
                        "created_at=:created_at, "
                        "last_used_at=:last_used_at, "
                        "use_count=:use_count, "
                        "fail_count=:fail_count, "
                        "last_fail_at=:last_fail_at, "
                        "last_fail_reason=:last_fail_reason, "
                        "last_sync_at=:last_sync_at, "
                        "tags=:tags, "
                        "note=:note, "
                        "last_asset_clear_at=:last_asset_clear_at, "
                        "data=:data, "
                        "data_hash=:data_hash, "
                        "updated_at=:updated_at "
                        "WHERE token=:token"
                    ),
                    params,
                )
                await session.commit()
        except Exception as e:
            logger.warning(f"SQLStorage: 旧数据回填失败: {e}")
| 848 |
+
    @asynccontextmanager
    async def acquire_lock(self, name: str, timeout: int = 10):
        # Distributed SQL lock: MySQL GET_LOCK / PostgreSQL advisory lock.
        # Unknown dialects degrade to a no-op (no cross-process locking).
        from sqlalchemy import text

        # Lock identifier is derived from the name hash so arbitrary names
        # fit MySQL's lock-name length limits.
        lock_name = f"g2a:{hashlib.sha1(name.encode('utf-8')).hexdigest()[:24]}"
        if self.dialect in ("mysql", "mariadb"):
            async with self.async_session() as session:
                # GET_LOCK blocks server-side up to `timeout` seconds;
                # returns 1 on success.
                res = await session.execute(
                    text("SELECT GET_LOCK(:name, :timeout)"),
                    {"name": lock_name, "timeout": timeout},
                )
                got = res.scalar()
                if got != 1:
                    raise StorageError(f"SQLStorage: 无法获取锁 '{name}'")
                try:
                    yield
                finally:
                    # Best-effort release; the lock also dies with the session.
                    try:
                        await session.execute(
                            text("SELECT RELEASE_LOCK(:name)"), {"name": lock_name}
                        )
                        await session.commit()
                    except Exception:
                        pass
        elif self.dialect in ("postgres", "postgresql", "pgsql"):
            # pg advisory locks take a signed 64-bit key.
            lock_key = int.from_bytes(
                hashlib.sha256(name.encode("utf-8")).digest()[:8], "big", signed=True
            )
            async with self.async_session() as session:
                # pg_try_advisory_lock is non-blocking, so poll with a
                # 100 ms backoff until timeout.
                start = time.monotonic()
                while True:
                    res = await session.execute(
                        text("SELECT pg_try_advisory_lock(:key)"), {"key": lock_key}
                    )
                    if res.scalar():
                        break
                    if time.monotonic() - start >= timeout:
                        raise StorageError(f"SQLStorage: 无法获取锁 '{name}'")
                    await asyncio.sleep(0.1)
                try:
                    yield
                finally:
                    # Must unlock on the SAME session that acquired the lock.
                    try:
                        await session.execute(
                            text("SELECT pg_advisory_unlock(:key)"), {"key": lock_key}
                        )
                        await session.commit()
                    except Exception:
                        pass
        else:
            yield
| 901 |
+
async def load_config(self) -> Dict[str, Any]:
|
| 902 |
+
await self._ensure_schema()
|
| 903 |
+
from sqlalchemy import text
|
| 904 |
+
|
| 905 |
+
try:
|
| 906 |
+
async with self.async_session() as session:
|
| 907 |
+
res = await session.execute(
|
| 908 |
+
text("SELECT section, key_name, value FROM app_config")
|
| 909 |
+
)
|
| 910 |
+
rows = res.fetchall()
|
| 911 |
+
if not rows:
|
| 912 |
+
return None
|
| 913 |
+
|
| 914 |
+
config = {}
|
| 915 |
+
for section, key, val_str in rows:
|
| 916 |
+
if section not in config:
|
| 917 |
+
config[section] = {}
|
| 918 |
+
try:
|
| 919 |
+
val = json_loads(val_str)
|
| 920 |
+
except Exception:
|
| 921 |
+
val = val_str
|
| 922 |
+
config[section][key] = val
|
| 923 |
+
return config
|
| 924 |
+
except Exception as e:
|
| 925 |
+
logger.error(f"SQLStorage: 加载配置失败: {e}")
|
| 926 |
+
return None
|
| 927 |
+
|
| 928 |
+
async def save_config(self, data: Dict[str, Any]):
|
| 929 |
+
await self._ensure_schema()
|
| 930 |
+
from sqlalchemy import text
|
| 931 |
+
|
| 932 |
+
try:
|
| 933 |
+
async with self.async_session() as session:
|
| 934 |
+
await session.execute(text("DELETE FROM app_config"))
|
| 935 |
+
|
| 936 |
+
params = []
|
| 937 |
+
for section, items in data.items():
|
| 938 |
+
if not isinstance(items, dict):
|
| 939 |
+
continue
|
| 940 |
+
for key, val in items.items():
|
| 941 |
+
params.append(
|
| 942 |
+
{
|
| 943 |
+
"s": section,
|
| 944 |
+
"k": key,
|
| 945 |
+
"v": json_dumps(val),
|
| 946 |
+
}
|
| 947 |
+
)
|
| 948 |
+
|
| 949 |
+
if params:
|
| 950 |
+
await session.execute(
|
| 951 |
+
text(
|
| 952 |
+
"INSERT INTO app_config (section, key_name, value) VALUES (:s, :k, :v)"
|
| 953 |
+
),
|
| 954 |
+
params,
|
| 955 |
+
)
|
| 956 |
+
await session.commit()
|
| 957 |
+
except Exception as e:
|
| 958 |
+
logger.error(f"SQLStorage: 保存配置失败: {e}")
|
| 959 |
+
raise
|
| 960 |
+
|
| 961 |
+
    async def load_tokens(self) -> Optional[Dict[str, Any]]:
        """Load every token row, grouped by pool name.

        Returns ``{pool_name: [token_dict, ...]}`` or ``None`` when the table
        is empty / on error.  Flat columns take precedence; the legacy
        ``data`` JSON blob only fills keys the flat columns did not provide.
        A malformed row is skipped rather than failing the whole load.
        """
        await self._ensure_schema()
        from sqlalchemy import text

        try:
            async with self.async_session() as session:
                res = await session.execute(
                    text(
                        "SELECT token, pool_name, status, quota, created_at, "
                        "last_used_at, use_count, fail_count, last_fail_at, "
                        "last_fail_reason, last_sync_at, tags, note, "
                        "last_asset_clear_at, data "
                        "FROM tokens"
                    )
                )
                rows = res.fetchall()
                if not rows:
                    return None

                pools = {}
                for (
                    token_str,
                    pool_name,
                    status,
                    quota,
                    created_at,
                    last_used_at,
                    use_count,
                    fail_count,
                    last_fail_at,
                    last_fail_reason,
                    last_sync_at,
                    tags,
                    note,
                    last_asset_clear_at,
                    data_json,
                ) in rows:
                    if pool_name not in pools:
                        pools[pool_name] = []

                    try:
                        # Rebuild the token dict from flat columns, coercing
                        # numeric columns back to int (drivers may return
                        # Decimal/str depending on dialect).
                        token_data = {}
                        if token_str:
                            token_data["token"] = token_str
                        if status is not None:
                            token_data["status"] = self._normalize_status(status)
                        if quota is not None:
                            token_data["quota"] = int(quota)
                        if created_at is not None:
                            token_data["created_at"] = int(created_at)
                        if last_used_at is not None:
                            token_data["last_used_at"] = int(last_used_at)
                        if use_count is not None:
                            token_data["use_count"] = int(use_count)
                        if fail_count is not None:
                            token_data["fail_count"] = int(fail_count)
                        if last_fail_at is not None:
                            token_data["last_fail_at"] = int(last_fail_at)
                        if last_fail_reason is not None:
                            token_data["last_fail_reason"] = last_fail_reason
                        if last_sync_at is not None:
                            token_data["last_sync_at"] = int(last_sync_at)
                        if tags is not None:
                            token_data["tags"] = self._parse_tags(tags)
                        if note is not None:
                            token_data["note"] = note
                        if last_asset_clear_at is not None:
                            token_data["last_asset_clear_at"] = int(
                                last_asset_clear_at
                            )

                        # Legacy blob fills only the gaps — flat columns win.
                        legacy_data = None
                        if data_json:
                            if isinstance(data_json, str):
                                legacy_data = json_loads(data_json)
                            else:
                                legacy_data = data_json
                            if isinstance(legacy_data, dict):
                                for key, val in legacy_data.items():
                                    if key not in token_data or token_data[key] is None:
                                        token_data[key] = val

                        pools[pool_name].append(token_data)
                    except Exception:
                        # Skip unparsable rows instead of failing the load.
                        pass
                return pools
        except Exception as e:
            logger.error(f"SQLStorage: 加载 Token 失败: {e}")
            return None
| 1051 |
+
    async def save_tokens(self, data: Dict[str, Any]):
        """Full token sync: upsert everything in ``data``, delete the rest.

        Normalizes each entry (string entries become ``{"token": ...}``,
        an "sso=" prefix is stripped) and then delegates the actual write,
        including deletion of rows absent from ``data``, to
        :meth:`save_tokens_delta`.  Raises on database failure after logging.
        """
        await self._ensure_schema()
        from sqlalchemy import text

        if data is None:
            return

        updates = []
        new_tokens = set()
        for pool_name, tokens in (data or {}).items():
            for t in tokens:
                if isinstance(t, dict):
                    token_data = dict(t)
                elif isinstance(t, str):
                    # Bare string entries are promoted to minimal dicts.
                    token_data = {"token": t}
                else:
                    continue
                token_str = token_data.get("token")
                if not token_str:
                    continue
                # Store tokens without the cookie prefix.
                if token_str.startswith("sso="):
                    token_str = token_str[4:]
                token_data["token"] = token_str
                token_data["pool_name"] = pool_name
                # Full sync always writes complete state rows.
                token_data["_update_kind"] = "state"
                updates.append(token_data)
                new_tokens.add(token_str)

        try:
            existing_tokens = set()
            async with self.async_session() as session:
                res = await session.execute(text("SELECT token FROM tokens"))
                rows = res.fetchall()
                existing_tokens = {row[0] for row in rows}
            # Rows present in the DB but absent from the snapshot get removed.
            tokens_to_delete = list(existing_tokens - new_tokens)
            await self.save_tokens_delta(updates, tokens_to_delete)
        except Exception as e:
            logger.error(f"SQLStorage: 保存 Token 失败: {e}")
            raise
| 1091 |
+
    async def save_tokens_delta(
        self, updated: list[Dict[str, Any]], deleted: Optional[list[str]] = None
    ):
        """Incrementally apply token changes: delete ``deleted``, upsert ``updated``.

        Each item in ``updated`` must carry ``pool_name`` and ``token``; an
        optional ``_update_kind`` of ``"usage"`` routes the row through a
        narrower upsert that leaves created_at/tags/note/data untouched on
        conflict.  Everything runs in one session/transaction.
        Raises on database failure after logging.
        """
        await self._ensure_schema()
        from sqlalchemy import bindparam, text

        try:
            async with self.async_session() as session:
                deleted_set = set(deleted or [])
                if deleted_set:
                    # Expanding bind param -> "IN (...)"; chunked to stay
                    # under driver placeholder limits.
                    delete_stmt = text(
                        "DELETE FROM tokens WHERE token IN :tokens"
                    ).bindparams(bindparam("tokens", expanding=True))
                    chunk_size = 500
                    deleted_list = list(deleted_set)
                    for i in range(0, len(deleted_list), chunk_size):
                        chunk = deleted_list[i : i + chunk_size]
                        await session.execute(delete_stmt, {"tokens": chunk})

                updates = []
                usage_updates = []

                for item in updated or []:
                    if not isinstance(item, dict):
                        continue
                    pool_name = item.get("pool_name")
                    token_str = item.get("token")
                    if not pool_name or not token_str:
                        continue
                    if token_str in deleted_set:
                        # Deletion wins over a simultaneous update.
                        continue
                    update_kind = item.get("_update_kind", "state")
                    # Strip routing metadata before building the row.
                    token_data = {
                        k: v
                        for k, v in item.items()
                        if k not in ("pool_name", "_update_kind")
                    }
                    row = self._token_to_row(token_data, pool_name)
                    if update_kind == "usage":
                        usage_updates.append(row)
                    else:
                        updates.append(row)

                # Full-state upsert: every column is overwritten on conflict.
                if updates:
                    if self.dialect in ("mysql", "mariadb"):
                        upsert_stmt = text(
                            "INSERT INTO tokens (token, pool_name, status, quota, created_at, "
                            "last_used_at, use_count, fail_count, last_fail_at, "
                            "last_fail_reason, last_sync_at, tags, note, "
                            "last_asset_clear_at, data, data_hash, updated_at) "
                            "VALUES (:token, :pool_name, :status, :quota, :created_at, "
                            ":last_used_at, :use_count, :fail_count, :last_fail_at, "
                            ":last_fail_reason, :last_sync_at, :tags, :note, "
                            ":last_asset_clear_at, :data, :data_hash, :updated_at) "
                            "ON DUPLICATE KEY UPDATE "
                            "pool_name=VALUES(pool_name), "
                            "status=VALUES(status), "
                            "quota=VALUES(quota), "
                            "created_at=VALUES(created_at), "
                            "last_used_at=VALUES(last_used_at), "
                            "use_count=VALUES(use_count), "
                            "fail_count=VALUES(fail_count), "
                            "last_fail_at=VALUES(last_fail_at), "
                            "last_fail_reason=VALUES(last_fail_reason), "
                            "last_sync_at=VALUES(last_sync_at), "
                            "tags=VALUES(tags), "
                            "note=VALUES(note), "
                            "last_asset_clear_at=VALUES(last_asset_clear_at), "
                            "data=VALUES(data), "
                            "data_hash=VALUES(data_hash), "
                            "updated_at=VALUES(updated_at)"
                        )
                    elif self.dialect in ("postgres", "postgresql", "pgsql"):
                        upsert_stmt = text(
                            "INSERT INTO tokens (token, pool_name, status, quota, created_at, "
                            "last_used_at, use_count, fail_count, last_fail_at, "
                            "last_fail_reason, last_sync_at, tags, note, "
                            "last_asset_clear_at, data, data_hash, updated_at) "
                            "VALUES (:token, :pool_name, :status, :quota, :created_at, "
                            ":last_used_at, :use_count, :fail_count, :last_fail_at, "
                            ":last_fail_reason, :last_sync_at, :tags, :note, "
                            ":last_asset_clear_at, :data, :data_hash, :updated_at) "
                            "ON CONFLICT (token) DO UPDATE SET "
                            "pool_name=EXCLUDED.pool_name, "
                            "status=EXCLUDED.status, "
                            "quota=EXCLUDED.quota, "
                            "created_at=EXCLUDED.created_at, "
                            "last_used_at=EXCLUDED.last_used_at, "
                            "use_count=EXCLUDED.use_count, "
                            "fail_count=EXCLUDED.fail_count, "
                            "last_fail_at=EXCLUDED.last_fail_at, "
                            "last_fail_reason=EXCLUDED.last_fail_reason, "
                            "last_sync_at=EXCLUDED.last_sync_at, "
                            "tags=EXCLUDED.tags, "
                            "note=EXCLUDED.note, "
                            "last_asset_clear_at=EXCLUDED.last_asset_clear_at, "
                            "data=EXCLUDED.data, "
                            "data_hash=EXCLUDED.data_hash, "
                            "updated_at=EXCLUDED.updated_at"
                        )
                    else:
                        # Unknown dialect: plain INSERT (no upsert support).
                        upsert_stmt = text(
                            "INSERT INTO tokens (token, pool_name, status, quota, created_at, "
                            "last_used_at, use_count, fail_count, last_fail_at, "
                            "last_fail_reason, last_sync_at, tags, note, "
                            "last_asset_clear_at, data, data_hash, updated_at) "
                            "VALUES (:token, :pool_name, :status, :quota, :created_at, "
                            ":last_used_at, :use_count, :fail_count, :last_fail_at, "
                            ":last_fail_reason, :last_sync_at, :tags, :note, "
                            ":last_asset_clear_at, :data, :data_hash, :updated_at)"
                        )
                    await session.execute(upsert_stmt, updates)

                # Usage-only upsert: on conflict only counters/timestamps are
                # updated; created_at, tags, note, data etc. are preserved.
                if usage_updates:
                    if self.dialect in ("mysql", "mariadb"):
                        usage_stmt = text(
                            "INSERT INTO tokens (token, pool_name, status, quota, created_at, "
                            "last_used_at, use_count, fail_count, last_fail_at, "
                            "last_fail_reason, last_sync_at, tags, note, "
                            "last_asset_clear_at, data, data_hash, updated_at) "
                            "VALUES (:token, :pool_name, :status, :quota, :created_at, "
                            ":last_used_at, :use_count, :fail_count, :last_fail_at, "
                            ":last_fail_reason, :last_sync_at, :tags, :note, "
                            ":last_asset_clear_at, :data, :data_hash, :updated_at) "
                            "ON DUPLICATE KEY UPDATE "
                            "pool_name=VALUES(pool_name), "
                            "status=VALUES(status), "
                            "quota=VALUES(quota), "
                            "last_used_at=VALUES(last_used_at), "
                            "use_count=VALUES(use_count), "
                            "fail_count=VALUES(fail_count), "
                            "last_fail_at=VALUES(last_fail_at), "
                            "last_fail_reason=VALUES(last_fail_reason), "
                            "last_sync_at=VALUES(last_sync_at), "
                            "updated_at=VALUES(updated_at)"
                        )
                    elif self.dialect in ("postgres", "postgresql", "pgsql"):
                        usage_stmt = text(
                            "INSERT INTO tokens (token, pool_name, status, quota, created_at, "
                            "last_used_at, use_count, fail_count, last_fail_at, "
                            "last_fail_reason, last_sync_at, tags, note, "
                            "last_asset_clear_at, data, data_hash, updated_at) "
                            "VALUES (:token, :pool_name, :status, :quota, :created_at, "
                            ":last_used_at, :use_count, :fail_count, :last_fail_at, "
                            ":last_fail_reason, :last_sync_at, :tags, :note, "
                            ":last_asset_clear_at, :data, :data_hash, :updated_at) "
                            "ON CONFLICT (token) DO UPDATE SET "
                            "pool_name=EXCLUDED.pool_name, "
                            "status=EXCLUDED.status, "
                            "quota=EXCLUDED.quota, "
                            "last_used_at=EXCLUDED.last_used_at, "
                            "use_count=EXCLUDED.use_count, "
                            "fail_count=EXCLUDED.fail_count, "
                            "last_fail_at=EXCLUDED.last_fail_at, "
                            "last_fail_reason=EXCLUDED.last_fail_reason, "
                            "last_sync_at=EXCLUDED.last_sync_at, "
                            "updated_at=EXCLUDED.updated_at"
                        )
                    else:
                        usage_stmt = text(
                            "INSERT INTO tokens (token, pool_name, status, quota, created_at, "
                            "last_used_at, use_count, fail_count, last_fail_at, "
                            "last_fail_reason, last_sync_at, tags, note, "
                            "last_asset_clear_at, data, data_hash, updated_at) "
                            "VALUES (:token, :pool_name, :status, :quota, :created_at, "
                            ":last_used_at, :use_count, :fail_count, :last_fail_at, "
                            ":last_fail_reason, :last_sync_at, :tags, :note, "
                            ":last_asset_clear_at, :data, :data_hash, :updated_at)"
                        )
                    await session.execute(usage_stmt, usage_updates)

                await session.commit()
        except Exception as e:
            logger.error(f"SQLStorage: 增量保存 Token 失败: {e}")
            raise
| 1267 |
+
    async def close(self):
        """Dispose the SQLAlchemy async engine and its connection pool."""
        await self.engine.dispose()
|
| 1269 |
+
|
| 1270 |
+
|
| 1271 |
+
class StorageFactory:
    """Storage backend factory (process-wide singleton provider).

    Chooses the backend from ``SERVER_STORAGE_TYPE`` (redis / mysql / pgsql /
    local) and, for SQL backends, normalizes the connection URL and moves SSL
    query parameters into driver-level ``connect_args``.
    """

    # Cached singleton instance; created lazily in get_storage().
    _instance: Optional[BaseStorage] = None

    # SSL-related query parameters that async drivers (asyncpg, aiomysql)
    # cannot accept via the URL and must be passed as connect_args instead.
    _SQL_SSL_PARAM_KEYS = ("sslmode", "ssl-mode", "ssl")

    # Canonical postgres ssl modes (asyncpg accepts libpq-style mode strings).
    # Keys are the accepted user spellings; values are the canonical modes.
    _PG_SSL_MODE_ALIASES: ClassVar[dict[str, str]] = {
        "disable": "disable",
        "disabled": "disable",
        "false": "disable",
        "0": "disable",
        "no": "disable",
        "off": "disable",
        "prefer": "prefer",
        "preferred": "prefer",
        "allow": "allow",
        "require": "require",
        "required": "require",
        "true": "require",
        "1": "require",
        "yes": "require",
        "on": "require",
        "verify-ca": "verify-ca",
        "verify_ca": "verify-ca",
        "verify-full": "verify-full",
        "verify_full": "verify-full",
        "verify-identity": "verify-full",
        "verify_identity": "verify-full",
    }

    # Canonical mysql ssl modes (aiomysql accepts SSLContext, not mode strings).
    # Note "allow" maps to "preferred" and "verify-full" to "verify_identity",
    # following MySQL's own mode vocabulary.
    _MY_SSL_MODE_ALIASES: ClassVar[dict[str, str]] = {
        "disable": "disabled",
        "disabled": "disabled",
        "false": "disabled",
        "0": "disabled",
        "no": "disabled",
        "off": "disabled",
        "prefer": "preferred",
        "preferred": "preferred",
        "allow": "preferred",
        "require": "required",
        "required": "required",
        "true": "required",
        "1": "required",
        "yes": "required",
        "on": "required",
        "verify-ca": "verify_ca",
        "verify_ca": "verify_ca",
        "verify-full": "verify_identity",
        "verify_full": "verify_identity",
        "verify-identity": "verify_identity",
        "verify_identity": "verify_identity",
    }

    @classmethod
    def _normalize_ssl_mode(cls, storage_type: str, mode: str) -> str:
        """Normalize SSL mode aliases for the target storage backend.

        Raises:
            ValueError: if *mode* is empty or not a recognized alias for
                the given *storage_type*.
        """
        if not mode:
            raise ValueError("SSL mode cannot be empty")

        # Case- and whitespace-insensitive lookup against the alias tables.
        normalized = mode.strip().lower().replace(" ", "")
        if storage_type == "pgsql":
            canonical = cls._PG_SSL_MODE_ALIASES.get(normalized)
        elif storage_type == "mysql":
            canonical = cls._MY_SSL_MODE_ALIASES.get(normalized)
        else:
            canonical = None

        if not canonical:
            raise ValueError(
                f"Unsupported SSL mode '{mode}' for storage type '{storage_type}'"
            )
        return canonical

    @classmethod
    def _build_mysql_ssl_context(cls, mode: str):
        """Build SSLContext for aiomysql according to normalized mysql mode.

        Note: aiomysql enforces SSL whenever an SSLContext is provided — there
        is no "try SSL, fall back to plaintext" behaviour. As a result the
        ``preferred`` mode is treated identically to ``required`` (encrypted,
        no cert verification). Connections to MySQL servers that do not
        support SSL will fail rather than degrade gracefully.

        Returns ``None`` for mode ``disabled`` (plaintext connection).
        """
        import ssl as _ssl

        if mode == "disabled":
            return None

        ctx = _ssl.create_default_context()
        if mode in ("preferred", "required"):
            # Encrypt only: disable both hostname and certificate checks.
            # (check_hostname must be cleared before verify_mode=CERT_NONE.)
            ctx.check_hostname = False
            ctx.verify_mode = _ssl.CERT_NONE
        elif mode == "verify_ca":
            # verify CA, but do not enforce hostname match.
            ctx.check_hostname = False
        # verify_identity keeps defaults: verify cert + hostname.
        return ctx

    @classmethod
    def _build_sql_connect_args(
        cls, storage_type: str, raw_ssl_mode: Optional[str]
    ) -> Optional[dict]:
        """Build SQLAlchemy connect_args for SQL SSL modes.

        Returns ``None`` when no SSL mode was requested, the mode resolves to
        "disabled", or the storage type is not a SQL backend.
        """
        if not raw_ssl_mode:
            return None

        mode = cls._normalize_ssl_mode(storage_type, raw_ssl_mode)
        if storage_type == "pgsql":
            # asyncpg accepts libpq-style ssl mode strings via ssl=...
            return {"ssl": mode}
        if storage_type == "mysql":
            ctx = cls._build_mysql_ssl_context(mode)
            if ctx is None:
                return None
            return {"ssl": ctx}
        return None

    @classmethod
    def _normalize_sql_url(cls, storage_type: str, url: str) -> str:
        """Rewrite scheme prefix to the SQLAlchemy async dialect form.

        e.g. ``mysql://`` -> ``mysql+aiomysql://``,
        ``postgres://`` / ``pgsql://`` -> ``postgresql+asyncpg://``.
        URLs without a scheme separator are returned unchanged.
        """
        if not url or "://" not in url:
            return url
        if storage_type == "mysql":
            if url.startswith("mysql://"):
                url = f"mysql+aiomysql://{url[len('mysql://') :]}"
            elif url.startswith("mariadb://"):
                # Use mysql+aiomysql for both MySQL and MariaDB endpoints.
                # The mariadb dialect enforces strict MariaDB server detection.
                url = f"mysql+aiomysql://{url[len('mariadb://') :]}"
            elif url.startswith("mariadb+aiomysql://"):
                url = f"mysql+aiomysql://{url[len('mariadb+aiomysql://') :]}"
        elif storage_type == "pgsql":
            if url.startswith("postgres://"):
                url = f"postgresql+asyncpg://{url[len('postgres://') :]}"
            elif url.startswith("postgresql://"):
                url = f"postgresql+asyncpg://{url[len('postgresql://') :]}"
            elif url.startswith("pgsql://"):
                url = f"postgresql+asyncpg://{url[len('pgsql://') :]}"
        return url

    @classmethod
    def _prepare_sql_url_and_connect_args(
        cls, storage_type: str, url: str
    ) -> tuple[str, Optional[dict]]:
        """Normalize SQL URL and build connect_args from SSL query params.

        SSL-related query parameters are stripped from the URL (the async
        drivers reject them) and translated into ``connect_args``. When
        several SSL params are present, the first non-empty one wins.
        """
        from urllib.parse import urlparse, parse_qsl, urlencode, urlunparse

        normalized_url = cls._normalize_sql_url(storage_type, url)
        if "://" not in normalized_url:
            return normalized_url, None

        parsed = urlparse(normalized_url)
        ssl_mode: Optional[str] = None
        filtered_query_items = []
        ssl_param_keys = {k.lower() for k in cls._SQL_SSL_PARAM_KEYS}
        for key, value in parse_qsl(parsed.query, keep_blank_values=True):
            if key.lower() in ssl_param_keys:
                if ssl_mode is None and value:
                    ssl_mode = value
                # SSL params are consumed here, never forwarded in the URL.
                continue
            filtered_query_items.append((key, value))

        cleaned_url = urlunparse(
            parsed._replace(query=urlencode(filtered_query_items, doseq=True))
        )
        connect_args = cls._build_sql_connect_args(storage_type, ssl_mode)
        return cleaned_url, connect_args

    @classmethod
    def get_storage(cls) -> BaseStorage:
        """Return the global storage instance (singleton).

        Backend selection comes from the ``SERVER_STORAGE_TYPE`` env var
        (default "local"); redis/mysql/pgsql additionally require
        ``SERVER_STORAGE_URL``.

        NOTE(review): not guarded by a lock — assumed to be first called from
        a single startup path / event loop; confirm if called concurrently.
        """
        if cls._instance:
            return cls._instance

        storage_type = os.getenv("SERVER_STORAGE_TYPE", "local").lower()
        storage_url = os.getenv("SERVER_STORAGE_URL", "")

        logger.info(f"StorageFactory: 初始化存储后端: {storage_type}")

        if storage_type == "redis":
            if not storage_url:
                raise ValueError("Redis 存储需要设置 SERVER_STORAGE_URL")
            cls._instance = RedisStorage(storage_url)

        elif storage_type in ("mysql", "pgsql"):
            if not storage_url:
                raise ValueError("SQL 存储需要设置 SERVER_STORAGE_URL")
            # Drivers reject SSL query params in URL. Normalize URL and pass
            # backend-specific SSL handling through connect_args.
            storage_url, connect_args = cls._prepare_sql_url_and_connect_args(
                storage_type, storage_url
            )
            cls._instance = SQLStorage(storage_url, connect_args=connect_args)

        else:
            # Any unrecognized type falls back to the local file backend.
            cls._instance = LocalStorage()

        return cls._instance
|
| 1475 |
+
|
| 1476 |
+
|
| 1477 |
+
def get_storage() -> BaseStorage:
    """Module-level convenience accessor for the storage singleton."""
    backend = StorageFactory.get_storage()
    return backend
|
app/services/cf_refresh/README.md
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# cf_refresh - Cloudflare cf_clearance 自动刷新
|
| 2 |
+
|
| 3 |
+
通过 [FlareSolverr](https://github.com/FlareSolverr/FlareSolverr) 自动获取 Cloudflare `cf_clearance` cookie 和 `user_agent`,并更新到 Grok2API 服务配置中。
|
| 4 |
+
|
| 5 |
+
全自动、无需 GUI、服务器友好。
|
| 6 |
+
|
| 7 |
+
## 工作原理
|
| 8 |
+
|
| 9 |
+
1. FlareSolverr(独立 Docker 容器)内部运行 Chrome,自动通过 CF 挑战
|
| 10 |
+
2. cf_refresh 作为 grok2api 的后台任务,调用 FlareSolverr HTTP API 获取 `cf_clearance` 和 `user_agent`
|
| 11 |
+
3. 直接在进程内调用 `config.update()` 更新运行时配置并持久化到 `data/config.toml`
|
| 12 |
+
4. 按设定间隔重复以上步骤
|
| 13 |
+
|
| 14 |
+
## 配置方式
|
| 15 |
+
|
| 16 |
+
所有配置均可在管理面板 `/admin/config` 的 **CF 自动刷新** 区域中设置,也可通过环境变量初始化:
|
| 17 |
+
|
| 18 |
+
| 配置项 | 环境变量 | 默认值 | 说明 |
|
| 19 |
+
|--------|----------|--------|------|
|
| 20 |
+
| 启用自动刷新 | `FLARESOLVERR_URL`(非空即启用) | `false` | 是否开启自动刷新 |
|
| 21 |
+
| FlareSolverr 地址 | `FLARESOLVERR_URL` | — | FlareSolverr 服务的 HTTP 地址 |
|
| 22 |
+
| 刷新间隔(秒) | `CF_REFRESH_INTERVAL` | `600` | 定期刷新间隔 |
|
| 23 |
+
| 挑战超时(秒) | `CF_TIMEOUT` | `60` | CF 挑战等待超时 |
|
| 24 |
+
|
| 25 |
+
> **代理**:自动使用「代理配置 → 基础代理 URL」,无需单独设置,保证出口 IP 一致。
|
| 26 |
+
|
| 27 |
+
## 使用方式
|
| 28 |
+
|
| 29 |
+
### Docker Compose 部署
|
| 30 |
+
|
| 31 |
+
已集成在项目根目录 `docker-compose.yml` 中。只需在 grok2api 服务的环境变量中设置 `FLARESOLVERR_URL`,并添加 `flaresolverr` 服务即可:
|
| 32 |
+
|
| 33 |
+
```yaml
|
| 34 |
+
services:
|
| 35 |
+
grok2api:
|
| 36 |
+
environment:
|
| 37 |
+
FLARESOLVERR_URL: http://flaresolverr:8191
|
| 38 |
+
|
| 39 |
+
flaresolverr:
|
| 40 |
+
image: ghcr.io/flaresolverr/flaresolverr:latest
|
| 41 |
+
restart: unless-stopped
|
| 42 |
+
```
|
| 43 |
+
|
| 44 |
+
## 注意事项
|
| 45 |
+
|
| 46 |
+
- `cf_clearance` 与请求来源 IP 绑定,FlareSolverr 自动使用代理配置中的基础代理 URL 保证出口 IP 一致
|
| 47 |
+
- 启用自动刷新后,代理配置中的 CF Clearance、浏览器指纹和 User-Agent 由系统自动管理(面板中变灰)
|
| 48 |
+
- 建议刷新间隔不低于 5 分钟,避免触发 Cloudflare 频率限制
|
| 49 |
+
- FlareSolverr 需要约 500MB 内存(内部运行 Chrome)
|
app/services/cf_refresh/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""cf_refresh - Cloudflare cf_clearance 自动刷新模块"""
|
| 2 |
+
|
| 3 |
+
from .scheduler import start, stop
|
| 4 |
+
|
| 5 |
+
__all__ = ["start", "stop"]
|
app/services/cf_refresh/config.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""配置管理 — 从 app config 的 proxy.* 读取,支持面板修改实时生效"""
|
| 2 |
+
|
| 3 |
+
GROK_URL = "https://grok.com"
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def _get(key: str, default=None):
    """Look up a ``proxy.*`` entry from the application config."""
    # Imported lazily to avoid a circular import at module load time.
    from app.core.config import get_config

    return get_config(f"proxy.{key}", default)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def get_flaresolverr_url() -> str:
    """Configured FlareSolverr endpoint, or "" when unset."""
    url = _get("flaresolverr_url", "")
    return url or ""
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _get_int(key: str, default: int, min_value: int) -> int:
    """Read an integer config value, clamped to *min_value*.

    Unparseable values fall back to ``max(default, min_value)``.
    """
    raw = _get(key, default)
    try:
        parsed = int(raw)
    except (TypeError, ValueError):
        return max(default, min_value)
    return parsed if parsed >= min_value else min_value
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def get_refresh_interval() -> int:
    """Refresh period in seconds (default 600, floor 60)."""
    interval = _get_int("refresh_interval", 600, 60)
    return interval
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def get_timeout() -> int:
    """Challenge timeout in seconds (default 60, floor 60)."""
    timeout = _get_int("timeout", 60, 60)
    return timeout
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def get_proxy() -> str:
    """Base proxy URL, reused so the egress IP matches the main service."""
    proxy = _get("base_proxy_url", "")
    return proxy or ""
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def is_enabled() -> bool:
    """Whether automatic cf_clearance refreshing is switched on."""
    flag = _get("enabled", False)
    return bool(flag)
|
app/services/cf_refresh/scheduler.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""定时调度:周期性刷新 cf_clearance(集成到 grok2api 进程内)"""
|
| 2 |
+
|
| 3 |
+
import asyncio
|
| 4 |
+
|
| 5 |
+
from loguru import logger
|
| 6 |
+
|
| 7 |
+
from .config import get_refresh_interval, get_flaresolverr_url, is_enabled
|
| 8 |
+
from .solver import solve_cf_challenge
|
| 9 |
+
|
| 10 |
+
_task: asyncio.Task | None = None
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
async def _update_app_config(
    cf_cookies: str,
    user_agent: str = "",
    browser: str = "",
    cf_clearance: str = "",
) -> bool:
    """Push refreshed CF values straight into the runtime config.

    Only non-empty optional fields are written; returns True on success.
    """
    try:
        # Imported lazily to avoid a circular import at module load time.
        from app.core.config import config

        proxy_update: dict = {"cf_cookies": cf_cookies}
        for field, value in (
            ("cf_clearance", cf_clearance),
            ("user_agent", user_agent),
            ("browser", browser),
        ):
            if value:
                proxy_update[field] = value

        await config.update({"proxy": proxy_update})

        logger.info(f"配置已更新: cf_cookies (长度 {len(cf_cookies)}), 指纹: {browser}")
        if user_agent:
            logger.info(f"配置已更新: user_agent = {user_agent}")
        return True
    except Exception as e:
        logger.error(f"更新配置失败: {e}")
        return False
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
async def refresh_once() -> bool:
    """Run a single refresh cycle: solve the CF challenge, then persist it."""
    logger.info("=" * 50)
    logger.info("开始刷新 cf_clearance...")

    solved = await solve_cf_challenge()
    if not solved:
        logger.error("刷新失败:无法获取 cf_clearance")
        return False

    persisted = await _update_app_config(
        cf_cookies=solved["cookies"],
        cf_clearance=solved.get("cf_clearance", ""),
        user_agent=solved.get("user_agent", ""),
        browser=solved.get("browser", ""),
    )

    if persisted:
        logger.info("刷新完成")
    else:
        logger.error("刷新失败: 更新配置失败")

    return persisted
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
async def _scheduler_loop():
    """Background loop: refresh on a timer, re-reading config every pass."""
    logger.info(
        f"cf_refresh scheduler started (FlareSolverr: {get_flaresolverr_url()}, interval: {get_refresh_interval()}s)"
    )

    # Config is re-read each iteration so panel edits take effect live.
    while True:
        if not is_enabled():
            logger.debug("cf_refresh disabled, skip refresh")
        else:
            await refresh_once()
        await asyncio.sleep(get_refresh_interval())
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def start():
    """Start the background refresh task (idempotent).

    Must be called while an event loop is running (e.g. app startup hook).
    """
    global _task
    # A task that already finished (crashed or cancelled) may be restarted.
    if _task is not None and not _task.done():
        return
    # asyncio.get_event_loop() is deprecated for this use since Python 3.10
    # and its behaviour changed in 3.12; asyncio.create_task() correctly
    # binds the task to the currently running loop.
    _task = asyncio.create_task(_scheduler_loop())
    logger.info("cf_refresh background task started")
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def stop():
    """Cancel the background refresh task, if one is running."""
    global _task
    if _task is None:
        return
    _task.cancel()
    _task = None
    logger.info("cf_refresh background task stopped")
|
app/services/cf_refresh/solver.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
通过 FlareSolverr 自动获取 cf_clearance
|
| 3 |
+
|
| 4 |
+
FlareSolverr 是一个 Docker 服务,内部运行 Chrome 浏览器,
|
| 5 |
+
自动处理 Cloudflare 挑战(包括 Turnstile),无需 GUI。
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import asyncio
|
| 9 |
+
import json
|
| 10 |
+
from typing import Optional, Dict
|
| 11 |
+
from urllib import request as urllib_request
|
| 12 |
+
from urllib.error import HTTPError, URLError
|
| 13 |
+
|
| 14 |
+
from loguru import logger
|
| 15 |
+
|
| 16 |
+
from .config import GROK_URL, get_timeout, get_proxy, get_flaresolverr_url
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def _extract_all_cookies(cookies: list[dict]) -> str:
|
| 20 |
+
"""将 FlareSolverr 返回 of cookie 列表转换为字符串格式"""
|
| 21 |
+
return "; ".join([f"{c.get('name')}={c.get('value')}" for c in cookies])
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _extract_cookie_value(cookies: list[dict], name: str) -> str:
|
| 25 |
+
for cookie in cookies:
|
| 26 |
+
if cookie.get("name") == name:
|
| 27 |
+
return cookie.get("value") or ""
|
| 28 |
+
return ""
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _extract_user_agent(solution: dict) -> str:
|
| 32 |
+
"""从 FlareSolverr 的 solution 中提取 User-Agent"""
|
| 33 |
+
return solution.get("userAgent", "")
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _extract_browser_profile(user_agent: str) -> str:
|
| 37 |
+
"""从 User-Agent 提取 chromeXXX 格式的指纹识别号"""
|
| 38 |
+
import re
|
| 39 |
+
match = re.search(r"Chrome/(\d+)", user_agent)
|
| 40 |
+
if match:
|
| 41 |
+
return f"chrome{match.group(1)}"
|
| 42 |
+
return "chrome120"
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
async def solve_cf_challenge() -> Optional[Dict[str, str]]:
    """
    Visit grok.com through FlareSolverr, passing the Cloudflare challenge,
    and extract cf_clearance plus related browser identity data.

    Returns:
        On success, a dict with keys ``cookies`` (full Cookie header string),
        ``cf_clearance``, ``user_agent`` and ``browser`` (fingerprint id);
        ``None`` on any failure (unconfigured URL, upstream error, no
        cookies, network error).
    """
    flaresolverr_url = get_flaresolverr_url()
    cf_timeout = get_timeout()
    proxy = get_proxy()

    if not flaresolverr_url:
        logger.error("FlareSolverr 地址未配置,无法刷新 cf_clearance")
        return None

    # FlareSolverr's single JSON endpoint.
    url = f"{flaresolverr_url.rstrip('/')}/v1"

    payload = {
        "cmd": "request.get",
        "url": GROK_URL,
        # FlareSolverr expects milliseconds.
        "maxTimeout": cf_timeout * 1000,
    }

    # Route through the same proxy as the main service so the egress IP
    # matches (cf_clearance is IP-bound).
    if proxy:
        payload["proxy"] = {"url": proxy}

    body = json.dumps(payload).encode("utf-8")
    headers = {"Content-Type": "application/json"}

    logger.info(f"正在通过 FlareSolverr 访问 {GROK_URL} ...")
    logger.debug(f"FlareSolverr 地址: {url}")

    req = urllib_request.Request(url, data=body, method="POST", headers=headers)

    try:
        # Blocking urllib call; run in a worker thread to keep the event
        # loop free. Extra 30s margin over the challenge timeout covers
        # FlareSolverr's own overhead.
        def _post():
            with urllib_request.urlopen(req, timeout=cf_timeout + 30) as resp:
                return json.loads(resp.read().decode("utf-8"))

        result = await asyncio.to_thread(_post)

        status = result.get("status", "")
        if status != "ok":
            message = result.get("message", "unknown error")
            logger.error(f"FlareSolverr 返回错误: {status} - {message}")
            return None

        solution = result.get("solution", {})
        cookies = solution.get("cookies", [])

        if not cookies:
            logger.error("FlareSolverr 成功访问但没有返回 cookies")
            return None

        cookie_str = _extract_all_cookies(cookies)
        clearance = _extract_cookie_value(cookies, "cf_clearance")
        ua = _extract_user_agent(solution)
        browser = _extract_browser_profile(ua)
        logger.info(f"成功获取 cookies (数量: {len(cookies)}), 指纹: {browser}")

        return {
            "cookies": cookie_str,
            "cf_clearance": clearance,
            "user_agent": ua,
            "browser": browser,
        }

    except HTTPError as e:
        # Truncate upstream body to keep log lines bounded.
        body_text = e.read().decode("utf-8", "replace")[:300]
        logger.error(f"FlareSolverr 请求失败: {e.code} - {body_text}")
        return None
    except URLError as e:
        logger.error(f"无法连接 FlareSolverr ({flaresolverr_url}): {e.reason}")
        logger.info("请确认 FlareSolverr 服务已启动: docker run -p 8191:8191 ghcr.io/flaresolverr/flaresolverr:latest")
        return None
    except Exception as e:
        # Catch-all boundary: a refresh failure must never crash the scheduler.
        logger.error(f"请求异常: {e}")
        return None
|
app/services/grok/batch_services/assets.py
ADDED
|
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Batch assets service.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import asyncio
|
| 6 |
+
from typing import Dict, List, Optional
|
| 7 |
+
|
| 8 |
+
from app.core.config import get_config
|
| 9 |
+
from app.core.logger import logger
|
| 10 |
+
from app.services.reverse.assets_list import AssetsListReverse
|
| 11 |
+
from app.services.reverse.assets_delete import AssetsDeleteReverse
|
| 12 |
+
from app.services.reverse.utils.session import ResettableSession
|
| 13 |
+
from app.core.batch import run_batch
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class BaseAssetsService:
    """Shared lazy-session management for the asset list/delete services."""

    def __init__(self):
        # Created on first use in _get_session(); reset to None on close().
        self._session: Optional[ResettableSession] = None

    async def _get_session(self) -> ResettableSession:
        """Return the cached session, creating it with the configured fingerprint."""
        if self._session is not None:
            return self._session
        fingerprint = get_config("proxy.browser")
        if fingerprint:
            self._session = ResettableSession(impersonate=fingerprint)
        else:
            self._session = ResettableSession()
        return self._session

    async def close(self):
        """Close and drop the underlying session, if one was created."""
        if self._session:
            await self._session.close()
            self._session = None
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
_LIST_SEMAPHORE = None
|
| 38 |
+
_LIST_SEM_VALUE = None
|
| 39 |
+
_DELETE_SEMAPHORE = None
|
| 40 |
+
_DELETE_SEM_VALUE = None
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _get_list_semaphore() -> asyncio.Semaphore:
    """Shared list-concurrency semaphore, rebuilt when the config limit changes."""
    global _LIST_SEMAPHORE, _LIST_SEM_VALUE
    limit = max(1, int(get_config("asset.list_concurrent")))
    if _LIST_SEMAPHORE is None or limit != _LIST_SEM_VALUE:
        _LIST_SEM_VALUE = limit
        _LIST_SEMAPHORE = asyncio.Semaphore(limit)
    return _LIST_SEMAPHORE
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def _get_delete_semaphore() -> asyncio.Semaphore:
    """Shared delete-concurrency semaphore, rebuilt when the config limit changes."""
    global _DELETE_SEMAPHORE, _DELETE_SEM_VALUE
    limit = max(1, int(get_config("asset.delete_concurrent")))
    if _DELETE_SEMAPHORE is None or limit != _DELETE_SEM_VALUE:
        _DELETE_SEM_VALUE = limit
        _DELETE_SEMAPHORE = asyncio.Semaphore(limit)
    return _DELETE_SEMAPHORE
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class ListService(BaseAssetsService):
    """Assets list service: enumerate a token's asset ids via the reverse API."""

    async def list(self, token: str) -> Dict[str, List[str] | int]:
        """Page through all assets for *token* and collect their ids.

        Returns a dict with ``asset_ids`` (list) and ``count`` (int).
        Pagination stops on an empty next-page token or on a repeated
        page token (loop guard against a misbehaving upstream).
        """
        params = {
            "pageSize": 50,
            "orderBy": "ORDER_BY_LAST_USE_TIME",
            "source": "SOURCE_ANY",
            "isLatest": "true",
        }
        page_token = None
        # Tokens already requested; used to detect pagination loops.
        seen_tokens = set()
        asset_ids: List[str] = []
        session = await self._get_session()
        while True:
            if page_token:
                if page_token in seen_tokens:
                    logger.warning("Pagination stopped: repeated page token")
                    break
                seen_tokens.add(page_token)
                params["pageToken"] = page_token
            else:
                # First page: make sure no stale pageToken is sent.
                params.pop("pageToken", None)

            # Global concurrency cap across all list requests.
            async with _get_list_semaphore():
                response = await AssetsListReverse.request(
                    session,
                    token,
                    params,
                )

            result = response.json()
            page_assets = result.get("assets", [])
            if page_assets:
                for asset in page_assets:
                    asset_id = asset.get("assetId")
                    if asset_id:
                        asset_ids.append(asset_id)

            page_token = result.get("nextPageToken")
            if not page_token:
                break

        logger.info(f"List success: {len(asset_ids)} files")
        return {"asset_ids": asset_ids, "count": len(asset_ids)}

    @staticmethod
    async def fetch_assets_details(
        tokens: List[str],
        account_map: dict,
        *,
        include_ok: bool = False,
        on_item=None,
        should_cancel=None,
    ) -> dict:
        """Batch fetch assets details for tokens.

        Each per-token result carries a ``detail`` dict (masked token,
        count, status, last clear time). Errors are captured per token
        rather than aborting the batch. ``include_ok`` adds an ``ok``
        flag to each result; ``on_item``/``should_cancel`` are forwarded
        to run_batch.
        """
        account_map = account_map or {}
        # One shared session/service for every token in the batch.
        shared_service = ListService()
        batch_size = max(1, int(get_config("asset.list_batch_size")))

        async def _fetch_detail(token: str):
            # Per-token worker: never raises; failures become an error detail.
            account = account_map.get(token)
            try:
                result = await shared_service.list(token)
                asset_ids = result.get("asset_ids", [])
                count = result.get("count", len(asset_ids))
                detail = {
                    "token": token,
                    "token_masked": account["token_masked"] if account else token,
                    "count": count,
                    "status": "ok",
                    "last_asset_clear_at": account["last_asset_clear_at"]
                    if account
                    else None,
                }
                if include_ok:
                    return {"ok": True, "detail": detail, "count": count}
                return {"detail": detail, "count": count}
            except Exception as e:
                detail = {
                    "token": token,
                    "token_masked": account["token_masked"] if account else token,
                    "count": 0,
                    "status": f"error: {str(e)}",
                    "last_asset_clear_at": account["last_asset_clear_at"]
                    if account
                    else None,
                }
                if include_ok:
                    return {"ok": False, "detail": detail, "count": 0}
                return {"detail": detail, "count": 0}

        try:
            return await run_batch(
                tokens,
                _fetch_detail,
                batch_size=batch_size,
                on_item=on_item,
                should_cancel=should_cancel,
            )
        finally:
            # Always release the shared session, even on cancellation.
            await shared_service.close()
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
class DeleteService(BaseAssetsService):
    """Assets delete service: remove assets individually or clear whole tokens."""

    async def delete(self, token: str, asset_ids: List[str]) -> Dict[str, int]:
        """Delete the given assets for *token* concurrently.

        Returns counters ``{total, success, failed}`` (plus ``skipped``
        when the input list is empty). Individual failures are counted,
        not raised.
        """
        if not asset_ids:
            logger.info("No assets to delete")
            return {"total": 0, "success": 0, "failed": 0, "skipped": True}

        total = len(asset_ids)
        success = 0
        failed = 0
        session = await self._get_session()

        async def _delete_one(asset_id: str):
            # Global concurrency cap across all delete requests.
            async with _get_delete_semaphore():
                await AssetsDeleteReverse.request(session, token, asset_id)

        tasks = [_delete_one(asset_id) for asset_id in asset_ids if asset_id]
        # return_exceptions=True: collect per-asset failures instead of raising.
        results = await asyncio.gather(*tasks, return_exceptions=True)
        for res in results:
            if isinstance(res, Exception):
                failed += 1
            else:
                success += 1

        logger.info(f"Delete all: total={total}, success={success}, failed={failed}")
        return {"total": total, "success": success, "failed": failed}

    @staticmethod
    async def clear_assets(
        tokens: List[str],
        mgr,
        *,
        include_ok: bool = False,
        on_item=None,
        should_cancel=None,
    ) -> dict:
        """Batch clear assets for tokens.

        For each token: list all assets, delete them, then mark the clear
        time via *mgr*. Per-token errors are captured in the result rather
        than aborting the batch. ``include_ok`` switches the result shape
        to ``{ok, ...}``; ``on_item``/``should_cancel`` are forwarded to
        run_batch.
        """
        delete_service = DeleteService()
        list_service = ListService()
        batch_size = max(1, int(get_config("asset.delete_batch_size")))

        async def _clear_one(token: str):
            try:
                result = await list_service.list(token)
                asset_ids = result.get("asset_ids", [])
                result = await delete_service.delete(token, asset_ids)
                # Record the clear timestamp on the token manager.
                await mgr.mark_asset_clear(token)
                if include_ok:
                    return {"ok": True, "result": result}
                return {"status": "success", "result": result}
            except Exception as e:
                if include_ok:
                    return {"ok": False, "error": str(e)}
                return {"status": "error", "error": str(e)}

        try:
            return await run_batch(
                tokens,
                _clear_one,
                batch_size=batch_size,
                on_item=on_item,
                should_cancel=should_cancel,
            )
        finally:
            # Always release both shared sessions, even on cancellation.
            await delete_service.close()
            await list_service.close()
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
__all__ = ["ListService", "DeleteService"]
|
app/services/grok/batch_services/nsfw.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Batch NSFW service.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import asyncio
|
| 6 |
+
from typing import Callable, Awaitable, Dict, Any, Optional
|
| 7 |
+
|
| 8 |
+
from app.core.logger import logger
|
| 9 |
+
from app.core.config import get_config
|
| 10 |
+
from app.core.exceptions import UpstreamException
|
| 11 |
+
from app.services.reverse.accept_tos import AcceptTosReverse
|
| 12 |
+
from app.services.reverse.nsfw_mgmt import NsfwMgmtReverse
|
| 13 |
+
from app.services.reverse.set_birth import SetBirthReverse
|
| 14 |
+
from app.services.reverse.utils.session import ResettableSession
|
| 15 |
+
from app.core.batch import run_batch
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
_NSFW_SEMAPHORE = None
_NSFW_SEM_VALUE = None


def _get_nsfw_semaphore() -> asyncio.Semaphore:
    """Return the shared NSFW-concurrency semaphore.

    The semaphore is rebuilt lazily whenever the configured
    ``nsfw.concurrent`` limit changes (floored at 1).
    """
    global _NSFW_SEMAPHORE, _NSFW_SEM_VALUE
    limit = max(1, int(get_config("nsfw.concurrent")))
    if _NSFW_SEMAPHORE is None or limit != _NSFW_SEM_VALUE:
        _NSFW_SEM_VALUE = limit
        _NSFW_SEMAPHORE = asyncio.Semaphore(limit)
    return _NSFW_SEMAPHORE
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class NSFWService:
    """Service that enables NSFW mode for accounts in bulk."""

    @staticmethod
    async def batch(
        tokens: list[str],
        mgr,
        *,
        on_item: Optional[Callable[[str, Dict[str, Any]], Awaitable[None]]] = None,
        should_cancel: Optional[Callable[[], bool]] = None,
    ) -> Dict[str, Dict[str, Any]]:
        """Batch enable NSFW.

        For each token: accept the ToS, set a birth date, then toggle NSFW
        management. 401 auth failures are recorded on *mgr*; a successful
        toggle tags the token with "nsfw". Each step is throttled by the
        shared NSFW semaphore.
        """
        # Coerce to a sane positive int, consistent with the other batch
        # services (assets uses the same max(1, int(...)) guard).
        batch_size = max(1, int(get_config("nsfw.batch_size")))

        async def _enable(token: str):
            try:
                browser = get_config("proxy.browser")
                async with ResettableSession(impersonate=browser) as session:

                    async def _record_fail(err: UpstreamException, reason: str):
                        # Prefer the status embedded in error details; fall
                        # back to the exception's own status_code attribute.
                        status = None
                        if err.details and "status" in err.details:
                            status = err.details["status"]
                        else:
                            status = getattr(err, "status_code", None)
                        # Only auth failures (401) are persisted on the manager.
                        if status == 401:
                            await mgr.record_fail(token, status, reason)
                        return status or 0

                    try:
                        async with _get_nsfw_semaphore():
                            await AcceptTosReverse.request(session, token)
                    except UpstreamException as e:
                        status = await _record_fail(e, "tos_auth_failed")
                        return {
                            "success": False,
                            "http_status": status,
                            "error": f"Accept ToS failed: {str(e)}",
                        }

                    try:
                        async with _get_nsfw_semaphore():
                            await SetBirthReverse.request(session, token)
                    except UpstreamException as e:
                        status = await _record_fail(e, "set_birth_auth_failed")
                        return {
                            "success": False,
                            "http_status": status,
                            "error": f"Set birth date failed: {str(e)}",
                        }

                    try:
                        async with _get_nsfw_semaphore():
                            grpc_status = await NsfwMgmtReverse.request(session, token)
                            # NOTE(review): -1 is treated as success alongside 0
                            # — presumably "no status" from upstream; confirm.
                            success = grpc_status.code in (-1, 0)
                    except UpstreamException as e:
                        status = await _record_fail(e, "nsfw_mgmt_auth_failed")
                        return {
                            "success": False,
                            "http_status": status,
                            "error": f"NSFW enable failed: {str(e)}",
                        }
                    if success:
                        await mgr.add_tag(token, "nsfw")
                    return {
                        "success": success,
                        "http_status": 200,
                        "grpc_status": grpc_status.code,
                        "grpc_message": grpc_status.message or None,
                        "error": None,
                    }
            except Exception as e:
                logger.error(f"NSFW enable failed: {e}")
                return {"success": False, "http_status": 0, "error": str(e)[:100]}

        return await run_batch(
            tokens,
            _enable,
            batch_size=batch_size,
            on_item=on_item,
            should_cancel=should_cancel,
        )
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
__all__ = ["NSFWService"]
|
app/services/grok/batch_services/usage.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Batch usage service.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import asyncio
|
| 6 |
+
from typing import Callable, Awaitable, Dict, Any, Optional, List
|
| 7 |
+
|
| 8 |
+
from app.core.logger import logger
|
| 9 |
+
from app.core.config import get_config
|
| 10 |
+
from app.services.reverse.rate_limits import RateLimitsReverse
|
| 11 |
+
from app.services.reverse.utils.session import ResettableSession
|
| 12 |
+
from app.core.batch import run_batch
|
| 13 |
+
|
| 14 |
+
_USAGE_SEMAPHORE = None
_USAGE_SEM_VALUE = None


def _get_usage_semaphore() -> asyncio.Semaphore:
    """Return the shared usage-query semaphore.

    Rebuilt lazily whenever the configured ``usage.concurrent`` limit
    changes (floored at 1).
    """
    global _USAGE_SEMAPHORE, _USAGE_SEM_VALUE
    limit = max(1, int(get_config("usage.concurrent")))
    if _USAGE_SEMAPHORE is None or limit != _USAGE_SEM_VALUE:
        _USAGE_SEM_VALUE = limit
        _USAGE_SEMAPHORE = asyncio.Semaphore(limit)
    return _USAGE_SEMAPHORE
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class UsageService:
    """Rate-limit / usage query service."""

    async def get(self, token: str) -> Dict:
        """Fetch rate-limit information for *token*.

        Args:
            token: auth token.

        Returns:
            The upstream JSON payload, with ``remainingTokens`` backfilled
            from ``remainingQueries`` when the former is absent.

        Raises:
            UpstreamException: when the fetch fails and retries are exhausted
            (the last failure is already recorded upstream).
        """
        async with _get_usage_semaphore():
            browser = get_config("proxy.browser")
            # Only pass `impersonate` when a browser profile is configured.
            session_ctx = (
                ResettableSession(impersonate=browser)
                if browser
                else ResettableSession()
            )
            async with session_ctx as session:
                response = await RateLimitsReverse.request(session, token)
                data = response.json()
                remaining = data.get("remainingTokens")
                if remaining is None:
                    # Some payloads expose "remainingQueries" instead; mirror
                    # it so callers can rely on "remainingTokens".
                    remaining = data.get("remainingQueries")
                    if remaining is not None:
                        data["remainingTokens"] = remaining
                logger.info(
                    f"Usage sync success: remaining={remaining}, token={token[:10]}..."
                )
                return data

    @staticmethod
    async def batch(
        tokens: List[str],
        mgr,
        *,
        on_item: Optional[Callable[[str, Dict[str, Any]], Awaitable[None]]] = None,
        should_cancel: Optional[Callable[[], bool]] = None,
    ) -> Dict[str, Dict[str, Any]]:
        """Refresh usage for *tokens* in batches via ``mgr.sync_usage``."""
        # Coerce to a sane positive int, consistent with the other batch services.
        batch_size = max(1, int(get_config("usage.batch_size")))

        async def _refresh_one(t: str):
            return await mgr.sync_usage(t, consume_on_fail=False, is_usage=False)

        return await run_batch(
            tokens,
            _refresh_one,
            batch_size=batch_size,
            on_item=on_item,
            should_cancel=should_cancel,
        )
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
__all__ = ["UsageService"]
|
app/services/grok/defaults.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Grok 服务默认配置
|
| 3 |
+
|
| 4 |
+
此文件读取 config.defaults.toml,作为 Grok 服务的默认值来源。
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
import tomllib
|
| 9 |
+
|
| 10 |
+
from app.core.logger import logger
|
| 11 |
+
|
| 12 |
+
DEFAULTS_FILE = Path(__file__).resolve().parent.parent.parent.parent / "config.defaults.toml"

# Grok service defaults (loaded lazily from config.defaults.toml and cached).
GROK_DEFAULTS: dict = {}


def get_grok_defaults():
    """Load and cache the Grok default configuration.

    Returns the cached mapping when already populated; otherwise parses
    ``config.defaults.toml``. A missing file or parse error is logged as a
    warning and the (possibly empty) cache is returned unchanged.
    """
    global GROK_DEFAULTS
    if not GROK_DEFAULTS:
        if DEFAULTS_FILE.exists():
            try:
                with DEFAULTS_FILE.open("rb") as fh:
                    GROK_DEFAULTS = tomllib.load(fh)
            except Exception as exc:
                logger.warning(f"Failed to load defaults from {DEFAULTS_FILE}: {exc}")
        else:
            logger.warning(f"Defaults file not found: {DEFAULTS_FILE}")
    return GROK_DEFAULTS
return GROK_DEFAULTS
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
__all__ = ["GROK_DEFAULTS", "get_grok_defaults"]
|
app/services/grok/services/chat.py
ADDED
|
@@ -0,0 +1,1115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Grok Chat 服务
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import asyncio
|
| 6 |
+
import re
|
| 7 |
+
import uuid
|
| 8 |
+
from typing import Dict, List, Any, AsyncGenerator, AsyncIterable
|
| 9 |
+
|
| 10 |
+
import orjson
|
| 11 |
+
from curl_cffi.requests.errors import RequestsError
|
| 12 |
+
|
| 13 |
+
from app.core.logger import logger
|
| 14 |
+
from app.core.config import get_config
|
| 15 |
+
from app.core.exceptions import (
|
| 16 |
+
AppException,
|
| 17 |
+
ValidationException,
|
| 18 |
+
ErrorType,
|
| 19 |
+
UpstreamException,
|
| 20 |
+
StreamIdleTimeoutError,
|
| 21 |
+
)
|
| 22 |
+
from app.services.grok.services.model import ModelService
|
| 23 |
+
from app.services.grok.utils.upload import UploadService
|
| 24 |
+
from app.services.grok.utils import process as proc_base
|
| 25 |
+
from app.services.grok.utils.retry import pick_token, rate_limited, transient_upstream
|
| 26 |
+
from app.services.reverse.app_chat import AppChatReverse
|
| 27 |
+
from app.services.reverse.utils.session import ResettableSession
|
| 28 |
+
from app.services.grok.utils.stream import wrap_stream_with_usage
|
| 29 |
+
from app.services.grok.utils.tool_call import (
|
| 30 |
+
build_tool_prompt,
|
| 31 |
+
parse_tool_calls,
|
| 32 |
+
parse_tool_call_block,
|
| 33 |
+
format_tool_history,
|
| 34 |
+
)
|
| 35 |
+
from app.services.token import get_token_manager, EffortType
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
_CHAT_SEMAPHORE = None
|
| 39 |
+
_CHAT_SEM_VALUE = None
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def extract_tool_text(raw: str, rollout_id: str = "") -> str:
    """Convert a raw ``<xai:...>`` tool-call fragment into a short display string.

    Pulls the tool name and JSON args out of ``<xai:tool_name>`` /
    ``<xai:tool_args>`` tags (unwrapping CDATA), maps known tools to a
    bracketed label plus their most descriptive argument, and falls back to
    "name args", then to the input with all tags stripped. Returns "" for
    empty input.
    """
    if not raw:
        return ""
    name_match = re.search(
        r"<xai:tool_name>(.*?)</xai:tool_name>", raw, flags=re.DOTALL
    )
    args_match = re.search(
        r"<xai:tool_args>(.*?)</xai:tool_args>", raw, flags=re.DOTALL
    )

    # Unwrap optional CDATA sections around name/args.
    name = name_match.group(1) if name_match else ""
    if name:
        name = re.sub(r"<!\[CDATA\[(.*?)\]\]>", r"\1", name, flags=re.DOTALL).strip()

    args = args_match.group(1) if args_match else ""
    if args:
        args = re.sub(r"<!\[CDATA\[(.*?)\]\]>", r"\1", args, flags=re.DOTALL).strip()

    # Best-effort JSON decode of the args; non-JSON args keep raw text.
    payload = None
    if args:
        try:
            payload = orjson.loads(args)
        except orjson.JSONDecodeError:
            payload = None

    label = name
    text = args
    prefix = f"[{rollout_id}]" if rollout_id else ""

    # Known tools get a human-readable label and the most relevant argument.
    if name == "web_search":
        label = f"{prefix}[WebSearch]"
        if isinstance(payload, dict):
            text = payload.get("query") or payload.get("q") or ""
    elif name == "search_images":
        label = f"{prefix}[SearchImage]"
        if isinstance(payload, dict):
            text = (
                payload.get("image_description")
                or payload.get("description")
                or payload.get("query")
                or ""
            )
    elif name == "chatroom_send":
        label = f"{prefix}[AgentThink]"
        if isinstance(payload, dict):
            text = payload.get("message") or ""

    if label and text:
        return f"{label} {text}".strip()
    if label:
        return label
    if text:
        return text
    # Fallback: strip tags to keep any raw text.
    return re.sub(r"<[^>]+>", "", raw, flags=re.DOTALL).strip()
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def _get_chat_semaphore() -> asyncio.Semaphore:
    """Return the shared chat-concurrency semaphore.

    Rebuilt lazily whenever the configured ``chat.concurrent`` limit
    changes (floored at 1).
    """
    global _CHAT_SEMAPHORE, _CHAT_SEM_VALUE
    limit = max(1, int(get_config("chat.concurrent")))
    if limit != _CHAT_SEM_VALUE:
        _CHAT_SEM_VALUE = limit
        _CHAT_SEMAPHORE = asyncio.Semaphore(limit)
    return _CHAT_SEMAPHORE
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class MessageExtractor:
    """Extracts prompt text and attachments from OpenAI-style messages."""

    @staticmethod
    def extract(
        messages: List[Dict[str, Any]],
        tools: List[Dict[str, Any]] = None,
        tool_choice: Any = None,
        parallel_tool_calls: bool = True,
    ) -> tuple[str, List[str], List[str]]:
        """Extract content from OpenAI message format.

        Returns ``(text, file_attachments, image_attachments)`` where *text*
        is the conversation flattened into a single prompt (every message
        except the last user message is prefixed with its role label).

        Fix: the original duplicated the per-item handling for dict-shaped
        and list-shaped ``content``; a dict is now normalized to a one-item
        list and both shapes share a single loop.
        """
        # Pre-process: convert tool-related messages to text format.
        if tools:
            messages = format_tool_history(messages)

        texts = []
        file_attachments: List[str] = []
        image_attachments: List[str] = []
        extracted = []

        for msg in messages:
            role = msg.get("role", "") or "user"
            content = msg.get("content", "")
            parts = []

            if isinstance(content, str):
                if content.strip():
                    parts.append(content)
            else:
                # Normalize a single content dict to a one-item list so both
                # shapes go through the same item loop.
                if isinstance(content, dict):
                    content = [content]
                if isinstance(content, list):
                    for item in content:
                        if not isinstance(item, dict):
                            continue
                        item_type = item.get("type", "")
                        if item_type == "text":
                            if text := item.get("text", "").strip():
                                parts.append(text)
                        elif item_type == "image_url":
                            url = item.get("image_url", {}).get("url", "")
                            if url:
                                image_attachments.append(url)
                        elif item_type == "input_audio":
                            data = item.get("input_audio", {}).get("data", "")
                            if data:
                                file_attachments.append(data)
                        elif item_type == "file":
                            raw = item.get("file", {}).get("file_data", "")
                            if raw:
                                file_attachments.append(raw)

            # Preserve tool-call traces so multi-turn tool sessions keep
            # their context ordering for clients that drop them.
            tool_calls = msg.get("tool_calls")
            if role == "assistant" and not parts and isinstance(tool_calls, list):
                for call in tool_calls:
                    if not isinstance(call, dict):
                        continue
                    fn = call.get("function", {})
                    if not isinstance(fn, dict):
                        fn = {}
                    name = fn.get("name") or call.get("name") or "tool"
                    arguments = fn.get("arguments", "")
                    if isinstance(arguments, (dict, list)):
                        try:
                            arguments = orjson.dumps(arguments).decode()
                        except Exception:
                            arguments = str(arguments)
                    if not isinstance(arguments, str):
                        arguments = str(arguments)
                    arguments = arguments.strip()
                    parts.append(
                        f"[tool_call] {name} {arguments}".strip()
                    )

            if parts:
                role_label = role
                if role == "tool":
                    # Tool replies carry their tool name and call id in the label.
                    name = msg.get("name")
                    call_id = msg.get("tool_call_id")
                    if isinstance(name, str) and name.strip():
                        role_label = f"tool[{name.strip()}]"
                    if isinstance(call_id, str) and call_id.strip():
                        role_label = f"{role_label}#{call_id.strip()}"
                extracted.append({"role": role_label, "text": "\n".join(parts)})

        # Find the last user message: it is emitted without a role prefix.
        last_user_index = next(
            (
                i
                for i in range(len(extracted) - 1, -1, -1)
                if extracted[i]["role"] == "user"
            ),
            None,
        )

        for i, item in enumerate(extracted):
            role = item["role"] or "user"
            text = item["text"]
            texts.append(text if i == last_user_index else f"{role}: {text}")

        combined = "\n\n".join(texts)

        # If there are attachments but no text, inject a fallback prompt.
        if (not combined.strip()) and (file_attachments or image_attachments):
            combined = "Refer to the following content:"

        # Prepend tool system prompt if tools are provided.
        if tools:
            tool_prompt = build_tool_prompt(tools, tool_choice, parallel_tool_calls)
            if tool_prompt:
                combined = f"{tool_prompt}\n\n{combined}"

        return combined, file_attachments, image_attachments
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
class GrokChatService:
    """Low-level Grok API client: opens and wraps upstream chat streams."""

    async def chat(
        self,
        token: str,
        message: str,
        model: str,
        mode: str = None,
        stream: bool = None,
        file_attachments: List[str] = None,
        tool_overrides: Dict[str, Any] = None,
        model_config_override: Dict[str, Any] = None,
    ):
        """Send a chat request and return an async line generator.

        The shared chat semaphore is acquired before connecting and released
        when the returned generator finishes (or immediately on connect error).
        """
        if stream is None:
            stream = get_config("app.stream")

        logger.debug(
            f"Chat request: model={model}, mode={mode}, stream={stream}, attachments={len(file_attachments or [])}"
        )

        browser = get_config("proxy.browser")
        semaphore = _get_chat_semaphore()
        # Acquired manually (not `async with`): the hold must outlive this
        # coroutine — it is released inside the generator's `finally`.
        await semaphore.acquire()
        session = ResettableSession(impersonate=browser)
        try:
            stream_response = await AppChatReverse.request(
                session,
                token,
                message=message,
                model=model,
                mode=mode,
                file_attachments=file_attachments,
                tool_overrides=tool_overrides,
                model_config_override=model_config_override,
            )
            logger.info(f"Chat connected: model={model}, stream={stream}")
        except Exception:
            # Connect failed: best-effort session cleanup, free the slot, re-raise.
            try:
                await session.close()
            except Exception:
                pass
            semaphore.release()
            raise

        async def _stream():
            # NOTE(review): the semaphore is released here, but `session` is
            # never closed on the success path — presumably ResettableSession
            # cleans up elsewhere; confirm there is no connection leak.
            try:
                async for line in stream_response:
                    yield line
            finally:
                semaphore.release()

        return _stream()

    async def chat_openai(
        self,
        token: str,
        model: str,
        messages: List[Dict[str, Any]],
        stream: bool = None,
        reasoning_effort: str | None = None,
        temperature: float = 0.8,
        top_p: float = 0.95,
        tools: List[Dict[str, Any]] = None,
        tool_choice: Any = None,
        parallel_tool_calls: bool = True,
    ):
        """OpenAI-compatible entry point.

        Resolves *model* to its Grok model/mode, extracts text + attachments
        from OpenAI-format *messages*, uploads the attachments, then calls
        :meth:`chat`. Returns ``(response_stream, stream_flag, model)``.

        Raises:
            ValidationException: when *model* is unknown.
        """
        model_info = ModelService.get(model)
        if not model_info:
            raise ValidationException(f"Unknown model: {model}")

        grok_model = model_info.grok_model
        mode = model_info.model_mode
        # Extract prompt text and attachments from OpenAI-format messages.
        message, file_attachments, image_attachments = MessageExtractor.extract(
            messages, tools=tools, tool_choice=tool_choice, parallel_tool_calls=parallel_tool_calls
        )
        logger.debug(
            "Extracted message length=%s, files=%s, images=%s",
            len(message),
            len(file_attachments),
            len(image_attachments),
        )

        # Upload attachments; upstream only needs the resulting file ids.
        file_ids: List[str] = []
        image_ids: List[str] = []
        if file_attachments or image_attachments:
            upload_service = UploadService()
            try:
                for attach_data in file_attachments:
                    file_id, _ = await upload_service.upload_file(attach_data, token)
                    file_ids.append(file_id)
                    logger.debug(f"Attachment uploaded: type=file, file_id={file_id}")
                for attach_data in image_attachments:
                    file_id, _ = await upload_service.upload_file(attach_data, token)
                    image_ids.append(file_id)
                    logger.debug(f"Attachment uploaded: type=image, file_id={file_id}")
            finally:
                await upload_service.close()

        all_attachments = file_ids + image_ids
        stream = stream if stream is not None else get_config("app.stream")

        model_config_override = {
            "temperature": temperature,
            "topP": top_p,
        }
        if reasoning_effort is not None:
            model_config_override["reasoningEffort"] = reasoning_effort

        response = await self.chat(
            token,
            message,
            grok_model,
            mode,
            stream,
            file_attachments=all_attachments,
            tool_overrides=None,
            model_config_override=model_config_override,
        )

        return response, stream, model
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
class ChatService:
    """Chat business service.

    Entry point that routes an OpenAI-style chat-completions request to the
    Grok upstream, with a cross-token retry loop for rate-limit (429) and
    transient upstream failures.
    """

    @staticmethod
    async def completions(
        model: str,
        messages: List[Dict[str, Any]],
        stream: bool | None = None,
        reasoning_effort: str | None = None,
        temperature: float = 0.8,
        top_p: float = 0.95,
        tools: List[Dict[str, Any]] | None = None,
        tool_choice: Any = None,
        parallel_tool_calls: bool = True,
    ):
        """Chat Completions entry point.

        Args:
            model: Requested (OpenAI-style) model name.
            messages: OpenAI chat messages list.
            stream: Force streaming on/off; falls back to ``app.stream`` config
                when ``None``.
            reasoning_effort: When ``None``, thinking visibility comes from the
                ``app.thinking`` config; otherwise anything but ``"none"``
                shows <think> content.
            temperature: Sampling temperature forwarded upstream.
            top_p: Nucleus sampling value forwarded upstream.
            tools: OpenAI tool definitions, forwarded upstream.
            tool_choice: OpenAI tool_choice value.
            parallel_tool_calls: Forwarded upstream unchanged.

        Returns:
            A wrapped SSE async generator when streaming, otherwise the
            collected chat.completion dict.

        Raises:
            AppException: 429-style error when no token is available at all.
            UpstreamException: Non-retryable upstream failures (re-raised), or
                the last rate-limit error once every token has been tried.
        """
        # Acquire the token manager (refresh token pool if stale).
        token_mgr = await get_token_manager()
        await token_mgr.reload_if_stale()

        # Resolve parameters.
        if reasoning_effort is None:
            show_think = get_config("app.thinking")
        else:
            show_think = reasoning_effort != "none"
        is_stream = stream if stream is not None else get_config("app.stream")

        # Cross-token retry loop.
        tried_tokens = set()
        max_token_retries = int(get_config("retry.max_retry") or 3)
        last_error = None

        for attempt in range(max_token_retries):
            # Pick a token not yet tried for this request.
            token = await pick_token(token_mgr, model, tried_tokens)
            if not token:
                # No fresh token left: surface the last upstream error if any,
                # otherwise report pool exhaustion as a 429.
                if last_error:
                    raise last_error
                raise AppException(
                    message="No available tokens. Please try again later.",
                    error_type=ErrorType.RATE_LIMIT.value,
                    code="rate_limit_exceeded",
                    status_code=429,
                )

            tried_tokens.add(token)

            try:
                # Issue the upstream Grok request.
                service = GrokChatService()
                response, _, model_name = await service.chat_openai(
                    token,
                    model,
                    messages,
                    stream=is_stream,
                    reasoning_effort=reasoning_effort,
                    temperature=temperature,
                    top_p=top_p,
                    tools=tools,
                    tool_choice=tool_choice,
                    parallel_tool_calls=parallel_tool_calls,
                )

                # Handle the response.
                if is_stream:
                    logger.debug(f"Processing stream response: model={model}")
                    processor = StreamProcessor(model_name, token, show_think, tools=tools, tool_choice=tool_choice)
                    # Usage accounting for streams happens inside the wrapper.
                    return wrap_stream_with_usage(
                        processor.process(response), token_mgr, token, model
                    )

                # Non-streaming: collect the whole response first.
                logger.debug(f"Processing non-stream response: model={model}")
                result = await CollectProcessor(model_name, token, tools=tools, tool_choice=tool_choice).process(response)
                try:
                    # Record usage; effort tier derives from model cost.
                    model_info = ModelService.get(model)
                    effort = (
                        EffortType.HIGH
                        if (model_info and model_info.cost.value == "high")
                        else EffortType.LOW
                    )
                    await token_mgr.consume(token, effort)
                    logger.info(f"Chat completed: model={model}, effort={effort.value}")
                except Exception as e:
                    # Usage recording is best-effort; never fail the request.
                    logger.warning(f"Failed to record usage: {e}")
                return result

            except UpstreamException as e:
                last_error = e

                if rate_limited(e):
                    # Quota exhausted: mark the token as cooling and retry
                    # with the next token.
                    await token_mgr.mark_rate_limited(token)
                    logger.warning(
                        f"Token {token[:10]}... rate limited (429), "
                        f"trying next token (attempt {attempt + 1}/{max_token_retries})"
                    )
                    continue

                if transient_upstream(e):
                    # Only retry a transient error when another token exists.
                    # NOTE(review): get_token() is called without await here —
                    # confirm it is synchronous, otherwise this truthiness
                    # check would always pass on a coroutine object.
                    has_alternative_token = False
                    for pool_name in ModelService.pool_candidates_for_model(model):
                        if token_mgr.get_token(pool_name, exclude=tried_tokens):
                            has_alternative_token = True
                            break
                    if not has_alternative_token:
                        raise
                    logger.warning(
                        f"Transient upstream error for token {token[:10]}..., "
                        f"trying next token (attempt {attempt + 1}/{max_token_retries}): {e}"
                    )
                    continue

                # Non-429 error: do not rotate tokens, propagate immediately.
                raise

        # Every token hit 429: raise the last recorded error.
        if last_error:
            raise last_error
        raise AppException(
            message="No available tokens. Please try again later.",
            error_type=ErrorType.RATE_LIMIT.value,
            code="rate_limit_exceeded",
            status_code=429,
        )
class StreamProcessor(proc_base.BaseProcessor):
    """Stream response processor.

    Converts the Grok upstream NDJSON stream into OpenAI-compatible SSE
    ``chat.completion.chunk`` lines. Handles <think> tagging, streaming image
    progress, configured tag filtering, ``xai:tool_usage_card`` extraction and
    incremental ``<tool_call>`` parsing across chunk boundaries.
    """

    def __init__(self, model: str, token: str = "", show_think: bool | None = None, tools: List[Dict[str, Any]] | None = None, tool_choice: Any = None):
        """Initialize per-request stream state.

        Args:
            model: Model name echoed back in every SSE chunk.
            token: Upstream auth token (used when rendering images).
            show_think: Emit thinking tokens wrapped in <think> tags when truthy.
            tools: OpenAI tool definitions; presence enables tool-call parsing.
            tool_choice: OpenAI tool_choice; "none" disables tool-call parsing.
        """
        super().__init__(model, token)
        self.response_id: str | None = None
        self.fingerprint: str = ""
        self.rollout_id: str = ""
        # True while a <think> block is currently open in the output.
        self.think_opened: bool = False
        # True while image-generation progress is being reported as thinking.
        self.image_think_active: bool = False
        # The first chunk must carry the assistant role exactly once.
        self.role_sent: bool = False
        self.filter_tags = get_config("app.filter_tags")
        # Tool-usage-card extraction is enabled via the filter_tags config.
        self.tool_usage_enabled = (
            "xai:tool_usage_card" in (self.filter_tags or [])
        )
        # Buffering state for a tool-usage card split across chunks.
        self._tool_usage_opened = False
        self._tool_usage_buffer = ""

        self.show_think = bool(show_think)
        self.tools = tools
        self.tool_choice = tool_choice
        # <tool_call> scanning only runs when tools are supplied and allowed.
        self._tool_stream_enabled = bool(tools) and tool_choice != "none"
        # Tool-call scanner state machine: "text" outside, "tool" inside a tag.
        self._tool_state = "text"
        self._tool_buffer = ""
        # Partial tag text held back because it may continue in the next chunk.
        self._tool_partial = ""
        self._tool_calls_seen = False
        self._tool_call_index = 0

    def _with_tool_index(self, tool_call: Any) -> Any:
        """Assign a sequential ``index`` to a tool-call dict if missing.

        Non-dict values pass through unchanged; dicts are shallow-copied
        before mutation.
        """
        if not isinstance(tool_call, dict):
            return tool_call
        if tool_call.get("index") is None:
            tool_call = dict(tool_call)
            tool_call["index"] = self._tool_call_index
            self._tool_call_index += 1
        return tool_call

    def _filter_tool_card(self, token: str) -> str:
        """Replace ``<xai:tool_usage_card>`` blocks with extracted text lines.

        Handles cards that span multiple tokens by buffering from the opening
        tag until the closing tag arrives. Returns the token with card markup
        removed (and extracted lines inserted, each newline-terminated).
        """
        if not token or not self.tool_usage_enabled:
            return token

        output_parts: list[str] = []
        rest = token
        start_tag = "<xai:tool_usage_card"
        end_tag = "</xai:tool_usage_card>"

        while rest:
            if self._tool_usage_opened:
                # Currently inside a card: keep buffering until we see the end tag.
                end_idx = rest.find(end_tag)
                if end_idx == -1:
                    self._tool_usage_buffer += rest
                    return "".join(output_parts)
                end_pos = end_idx + len(end_tag)
                self._tool_usage_buffer += rest[:end_pos]
                line = extract_tool_text(self._tool_usage_buffer, self.rollout_id)
                if line:
                    # Keep extracted lines on their own line in the output.
                    if output_parts and not output_parts[-1].endswith("\n"):
                        output_parts[-1] += "\n"
                    output_parts.append(f"{line}\n")
                self._tool_usage_buffer = ""
                self._tool_usage_opened = False
                rest = rest[end_pos:]
                continue

            start_idx = rest.find(start_tag)
            if start_idx == -1:
                # No card in the remainder: emit as plain text.
                output_parts.append(rest)
                break

            if start_idx > 0:
                output_parts.append(rest[:start_idx])

            end_idx = rest.find(end_tag, start_idx)
            if end_idx == -1:
                # Card begins here but closes in a later token: start buffering.
                self._tool_usage_opened = True
                self._tool_usage_buffer = rest[start_idx:]
                break

            # Complete card inside this token.
            end_pos = end_idx + len(end_tag)
            raw_card = rest[start_idx:end_pos]
            line = extract_tool_text(raw_card, self.rollout_id)
            if line:
                if output_parts and not output_parts[-1].endswith("\n"):
                    output_parts[-1] += "\n"
                output_parts.append(f"{line}\n")
            rest = rest[end_pos:]

        return "".join(output_parts)

    def _filter_token(self, token: str) -> str:
        """Filter special tags in current token only.

        First runs tool-usage-card extraction, then drops the entire token if
        it contains any other configured filter tag (coarse per-token filter;
        tags split across tokens are not caught here).
        """
        if not token:
            return token

        if self.tool_usage_enabled:
            token = self._filter_tool_card(token)
            if not token:
                return ""

        if not self.filter_tags:
            return token

        for tag in self.filter_tags:
            if tag == "xai:tool_usage_card":
                continue
            if f"<{tag}" in token or f"</{tag}" in token:
                return ""

        return token

    def _suffix_prefix(self, text: str, tag: str) -> int:
        """Return the length of the longest suffix of ``text`` that is a
        proper prefix of ``tag`` (0 when none).

        Used to hold back text that might be the beginning of a tag split
        across chunk boundaries.
        """
        if not text or not tag:
            return 0
        max_keep = min(len(text), len(tag) - 1)
        for keep in range(max_keep, 0, -1):
            if text.endswith(tag[:keep]):
                return keep
        return 0

    def _handle_tool_stream(self, chunk: str) -> list[tuple[str, Any]]:
        """Scan ``chunk`` for ``<tool_call>…</tool_call>`` blocks.

        Returns a list of ("text", str) and ("tool", dict) events in order.
        State (including partially-seen tags) is carried across calls so tags
        split over multiple chunks are handled correctly.
        """
        events: list[tuple[str, Any]] = []
        if not chunk:
            return events

        start_tag = "<tool_call>"
        end_tag = "</tool_call>"
        # Prepend any partial tag text held back from the previous chunk.
        data = f"{self._tool_partial}{chunk}"
        self._tool_partial = ""

        while data:
            if self._tool_state == "text":
                start_idx = data.find(start_tag)
                if start_idx == -1:
                    # No opening tag; hold back a possible tag prefix at the end.
                    keep = self._suffix_prefix(data, start_tag)
                    emit = data[:-keep] if keep else data
                    if emit:
                        events.append(("text", emit))
                    self._tool_partial = data[-keep:] if keep else ""
                    break

                before = data[:start_idx]
                if before:
                    events.append(("text", before))
                data = data[start_idx + len(start_tag) :]
                self._tool_state = "tool"
                continue

            # Inside a tool-call block: accumulate until the closing tag.
            end_idx = data.find(end_tag)
            if end_idx == -1:
                keep = self._suffix_prefix(data, end_tag)
                append = data[:-keep] if keep else data
                if append:
                    self._tool_buffer += append
                self._tool_partial = data[-keep:] if keep else ""
                break

            self._tool_buffer += data[:end_idx]
            data = data[end_idx + len(end_tag) :]
            tool_call = parse_tool_call_block(self._tool_buffer, self.tools)
            if tool_call:
                events.append(("tool", self._with_tool_index(tool_call)))
                self._tool_calls_seen = True
            self._tool_buffer = ""
            self._tool_state = "text"

        return events

    def _flush_tool_stream(self) -> list[tuple[str, Any]]:
        """Flush scanner state at end of stream.

        An unterminated tool-call block is parsed best-effort; if it cannot be
        parsed, its raw text is re-emitted with the opening tag restored.
        """
        events: list[tuple[str, Any]] = []
        if self._tool_state == "text":
            if self._tool_partial:
                events.append(("text", self._tool_partial))
                self._tool_partial = ""
            return events

        raw = f"{self._tool_buffer}{self._tool_partial}"
        tool_call = parse_tool_call_block(raw, self.tools)
        if tool_call:
            events.append(("tool", self._with_tool_index(tool_call)))
            self._tool_calls_seen = True
        elif raw:
            events.append(("text", f"<tool_call>{raw}"))
        self._tool_buffer = ""
        self._tool_partial = ""
        self._tool_state = "text"
        return events

    def _sse(self, content: str = "", role: str = None, finish: str = None, tool_calls: list = None) -> str:
        """Build a single OpenAI-style SSE ``chat.completion.chunk`` line.

        Exactly one of role / tool_calls / content populates the delta; a call
        with none of them (e.g. the finish chunk) yields an empty delta.
        """
        delta = {}
        if role:
            delta["role"] = role
            delta["content"] = ""
        elif tool_calls is not None:
            delta["tool_calls"] = tool_calls
        elif content:
            delta["content"] = content

        chunk = {
            # Fall back to a random id until the upstream responseId arrives.
            "id": self.response_id or f"chatcmpl-{uuid.uuid4().hex[:24]}",
            "object": "chat.completion.chunk",
            "created": self.created,
            "model": self.model,
            "system_fingerprint": self.fingerprint,
            "choices": [
                {"index": 0, "delta": delta, "logprobs": None, "finish_reason": finish}
            ],
        }
        return f"data: {orjson.dumps(chunk).decode()}\n\n"

    async def process(self, response: AsyncIterable[bytes]) -> AsyncGenerator[str, None]:
        """Process stream response.

        Args:
            response: AsyncIterable[bytes], async iterable of bytes

        Returns:
            AsyncGenerator[str, None], async generator of strings

        Raises:
            UpstreamException: On idle timeout (504) or upstream request
                failure (502). Client cancellation is logged and swallowed;
                any other exception is logged and re-raised.
        """
        idle_timeout = get_config("chat.stream_timeout")

        try:
            async for line in proc_base._with_idle_timeout(
                response, idle_timeout, self.model
            ):
                line = proc_base._normalize_line(line)
                if not line:
                    continue
                try:
                    data = orjson.loads(line)
                except orjson.JSONDecodeError:
                    # Non-JSON keepalive/garbage lines are skipped silently.
                    continue

                resp = data.get("result", {}).get("response", {})
                is_thinking = bool(resp.get("isThinking"))
                # isThinking controls <think> tagging
                # when absent, treat as False

                # Capture fingerprint / ids as soon as the upstream sends them.
                if (llm := resp.get("llmInfo")) and not self.fingerprint:
                    self.fingerprint = llm.get("modelHash", "")
                if rid := resp.get("responseId"):
                    self.response_id = rid
                if rid := resp.get("rolloutId"):
                    self.rollout_id = str(rid)

                if not self.role_sent:
                    # First chunk: announce the assistant role once.
                    yield self._sse(role="assistant")
                    self.role_sent = True

                if img := resp.get("streamingImageGenerationResponse"):
                    # Image-generation progress is reported inside <think>.
                    if not self.show_think:
                        continue
                    self.image_think_active = True
                    if not self.think_opened:
                        yield self._sse("<think>\n")
                        self.think_opened = True
                    idx = img.get("imageIndex", 0) + 1
                    progress = img.get("progress", 0)
                    yield self._sse(
                        f"正在生成第{idx}张图片中,当前进度{progress}%\n"
                    )
                    continue

                if mr := resp.get("modelResponse"):
                    # Final model message: close any image-progress think block.
                    if self.image_think_active and self.think_opened:
                        yield self._sse("\n</think>\n")
                        self.think_opened = False
                        self.image_think_active = False
                    for url in proc_base._collect_images(mr):
                        # Second-to-last path segment is used as the image id.
                        parts = url.split("/")
                        img_id = parts[-2] if len(parts) >= 2 else "image"
                        dl_service = self._get_dl()
                        rendered = await dl_service.render_image(
                            url, self.token, img_id
                        )
                        yield self._sse(f"{rendered}\n")

                    if (
                        (meta := mr.get("metadata", {}))
                        .get("llm_info", {})
                        .get("modelHash")
                    ):
                        self.fingerprint = meta["llm_info"]["modelHash"]
                    continue

                if card := resp.get("cardAttachment"):
                    json_data = card.get("jsonData")
                    if isinstance(json_data, str) and json_data.strip():
                        try:
                            card_data = orjson.loads(json_data)
                        except orjson.JSONDecodeError:
                            card_data = None
                        if isinstance(card_data, dict):
                            image = card_data.get("image") or {}
                            original = image.get("original")
                            title = image.get("title") or ""
                            if original:
                                title_safe = title.replace("\n", " ").strip()
                                # NOTE(review): both branches emit only "\n";
                                # image markdown using title_safe/original looks
                                # lost in extraction — confirm against VCS.
                                if title_safe:
                                    yield self._sse(f"\n")
                                else:
                                    yield self._sse(f"\n")
                    continue

                if (token := resp.get("token")) is not None:
                    if not token:
                        continue
                    filtered = self._filter_token(token)
                    if not filtered:
                        continue
                    in_think = is_thinking or self.image_think_active
                    if in_think:
                        if not self.show_think:
                            continue
                        if not self.think_opened:
                            yield self._sse("<think>\n")
                            self.think_opened = True
                    else:
                        if self.think_opened:
                            yield self._sse("\n</think>\n")
                            self.think_opened = False

                    if in_think:
                        # Thinking tokens bypass tool-call scanning.
                        yield self._sse(filtered)
                        continue

                    if self._tool_stream_enabled:
                        for kind, payload in self._handle_tool_stream(filtered):
                            if kind == "text":
                                yield self._sse(payload)
                            elif kind == "tool":
                                yield self._sse(tool_calls=[payload])
                        continue

                    yield self._sse(filtered)

            # Stream ended: close a dangling think block.
            if self.think_opened:
                yield self._sse("</think>\n")

            if self._tool_stream_enabled:
                # Flush any unterminated tool-call state before finishing.
                for kind, payload in self._flush_tool_stream():
                    if kind == "text":
                        yield self._sse(payload)
                    elif kind == "tool":
                        yield self._sse(tool_calls=[payload])
                finish_reason = "tool_calls" if self._tool_calls_seen else "stop"
                yield self._sse(finish=finish_reason)
            else:
                yield self._sse(finish="stop")

            yield "data: [DONE]\n\n"
        except asyncio.CancelledError:
            # NOTE(review): CancelledError is swallowed here (not re-raised);
            # the sibling CollectProcessor re-raises — confirm intentional.
            logger.debug("Stream cancelled by client", extra={"model": self.model})
        except StreamIdleTimeoutError as e:
            raise UpstreamException(
                message=f"Stream idle timeout after {e.idle_seconds}s",
                status_code=504,
                details={
                    "error": str(e),
                    "type": "stream_idle_timeout",
                    "idle_seconds": e.idle_seconds,
                },
            )
        except RequestsError as e:
            if proc_base._is_http2_error(e):
                logger.warning(f"HTTP/2 stream error: {e}", extra={"model": self.model})
                raise UpstreamException(
                    message="Upstream connection closed unexpectedly",
                    status_code=502,
                    details={"error": str(e), "type": "http2_stream_error"},
                )
            logger.error(f"Stream request error: {e}", extra={"model": self.model})
            raise UpstreamException(
                message=f"Upstream request failed: {e}",
                status_code=502,
                details={"error": str(e)},
            )
        except Exception as e:
            logger.error(
                f"Stream processing error: {e}",
                extra={"model": self.model, "error_type": type(e).__name__},
            )
            raise
        finally:
            # Always release downstream resources (e.g. download service).
            await self.close()
class CollectProcessor(proc_base.BaseProcessor):
    """Non-stream response processor.

    Consumes the whole Grok upstream NDJSON stream and assembles a single
    OpenAI-compatible ``chat.completion`` dict, including tag filtering,
    card-attachment rendering, image rendering and tool-call extraction.
    """

    def __init__(self, model: str, token: str = "", tools: List[Dict[str, Any]] = None, tool_choice: Any = None):
        """Initialize collector state.

        Args:
            model: Model name echoed back in the final response.
            token: Upstream auth token (used when rendering images).
            tools: OpenAI tool definitions; presence enables tool-call parsing.
            tool_choice: OpenAI tool_choice; "none" disables tool-call parsing.
        """
        super().__init__(model, token)
        self.filter_tags = get_config("app.filter_tags")
        self.tools = tools
        self.tool_choice = tool_choice

    def _filter_content(self, content: str) -> str:
        """Filter special tags in content.

        Replaces ``<xai:tool_usage_card>`` blocks with their extracted text
        (when that tag is configured) and strips every other configured tag,
        including self-closing forms.
        """
        if not content or not self.filter_tags:
            return content

        result = content
        if "xai:tool_usage_card" in self.filter_tags:
            rollout_id = ""
            rollout_match = re.search(
                r"<rolloutId>(.*?)</rolloutId>", result, flags=re.DOTALL
            )
            if rollout_match:
                rollout_id = rollout_match.group(1).strip()

            def _replace_card(match: re.Match) -> str:
                # Extract once per match (the previous lambda invoked
                # extract_tool_text twice for every card).
                line = extract_tool_text(match.group(0), rollout_id)
                return f"{line}\n" if line else ""

            result = re.sub(
                r"<xai:tool_usage_card[^>]*>.*?</xai:tool_usage_card>",
                _replace_card,
                result,
                flags=re.DOTALL,
            )

        for tag in self.filter_tags:
            if tag == "xai:tool_usage_card":
                continue
            # Remove both paired and self-closing occurrences of the tag.
            pattern = rf"<{re.escape(tag)}[^>]*>.*?</{re.escape(tag)}>|<{re.escape(tag)}[^>]*/>"
            result = re.sub(pattern, "", result, flags=re.DOTALL)

        return result

    async def process(self, response: AsyncIterable[bytes]) -> dict[str, Any]:
        """Process and collect full response.

        Args:
            response: Async iterable of raw NDJSON lines from upstream.

        Returns:
            An OpenAI-style ``chat.completion`` dict (token usage is reported
            as zeros — the upstream does not supply counts here).

        Raises:
            UpstreamException: On idle timeout or upstream request failure.
            asyncio.CancelledError: Re-raised on client cancellation.
        """
        response_id = ""
        fingerprint = ""
        content = ""
        idle_timeout = get_config("chat.stream_timeout")

        try:
            async for line in proc_base._with_idle_timeout(
                response, idle_timeout, self.model
            ):
                line = proc_base._normalize_line(line)
                if not line:
                    continue
                try:
                    data = orjson.loads(line)
                except orjson.JSONDecodeError:
                    # Skip keepalive/garbage lines.
                    continue

                resp = data.get("result", {}).get("response", {})

                if (llm := resp.get("llmInfo")) and not fingerprint:
                    fingerprint = llm.get("modelHash", "")

                if mr := resp.get("modelResponse"):
                    # Final message payload: id, text, cards and images.
                    response_id = mr.get("responseId", "")
                    content = mr.get("message", "")

                    # Map card_id -> (title, original image url).
                    card_map: dict[str, tuple[str, str]] = {}
                    for raw in mr.get("cardAttachmentsJson") or []:
                        if not isinstance(raw, str) or not raw.strip():
                            continue
                        try:
                            card_data = orjson.loads(raw)
                        except orjson.JSONDecodeError:
                            continue
                        if not isinstance(card_data, dict):
                            continue
                        card_id = card_data.get("id")
                        image = card_data.get("image") or {}
                        original = image.get("original")
                        if not card_id or not original:
                            continue
                        title = image.get("title") or ""
                        card_map[card_id] = (title, original)

                    if content and card_map:
                        def _render_card(match: re.Match) -> str:
                            card_id = match.group(1)
                            item = card_map.get(card_id)
                            if not item:
                                return ""
                            title, original = item
                            title_safe = title.replace("\n", " ").strip() or "image"
                            # Ensure the rendered card starts on its own line.
                            prefix = ""
                            if match.start() > 0:
                                prev = content[match.start() - 1]
                                if prev not in ("\n", "\r"):
                                    prefix = "\n"
                            # NOTE(review): returns only the prefix; image
                            # markdown using title_safe/original looks lost in
                            # extraction — confirm against VCS.
                            return f"{prefix}"

                        content = re.sub(
                            r'<grok:render[^>]*card_id="([^"]+)"[^>]*>.*?</grok:render>',
                            _render_card,
                            content,
                            flags=re.DOTALL,
                        )

                    if urls := proc_base._collect_images(mr):
                        content += "\n"
                        for url in urls:
                            # Second-to-last path segment serves as image id.
                            parts = url.split("/")
                            img_id = parts[-2] if len(parts) >= 2 else "image"
                            dl_service = self._get_dl()
                            rendered = await dl_service.render_image(
                                url, self.token, img_id
                            )
                            content += f"{rendered}\n"

                    if (
                        (meta := mr.get("metadata", {}))
                        .get("llm_info", {})
                        .get("modelHash")
                    ):
                        fingerprint = meta["llm_info"]["modelHash"]

        except asyncio.CancelledError:
            logger.debug("Collect cancelled by client", extra={"model": self.model})
            raise
        except StreamIdleTimeoutError as e:
            logger.warning(f"Collect idle timeout: {e}", extra={"model": self.model})
            raise UpstreamException(
                message=f"Collect stream idle timeout after {e.idle_seconds}s",
                details={
                    "error": str(e),
                    "type": "stream_idle_timeout",
                    "idle_seconds": e.idle_seconds,
                    "status": 504,
                },
            )
        except RequestsError as e:
            if proc_base._is_http2_error(e):
                logger.warning(
                    f"HTTP/2 stream error in collect: {e}", extra={"model": self.model}
                )
                raise UpstreamException(
                    message="Upstream connection closed unexpectedly",
                    details={"error": str(e), "type": "http2_stream_error", "status": 502},
                )
            logger.error(f"Collect request error: {e}", extra={"model": self.model})
            raise UpstreamException(
                message=f"Upstream request failed: {e}",
                details={"error": str(e), "status": 502},
            )
        except Exception as e:
            logger.error(
                f"Collect processing error: {e}",
                extra={"model": self.model, "error_type": type(e).__name__},
            )
            raise
        finally:
            # Always release downstream resources (e.g. download service).
            await self.close()

        content = self._filter_content(content)

        # Parse for tool calls if tools were provided
        finish_reason = "stop"
        tool_calls_result = None
        if self.tools and self.tool_choice != "none":
            text_content, tool_calls_list = parse_tool_calls(content, self.tools)
            if tool_calls_list:
                tool_calls_result = tool_calls_list
                content = text_content  # May be None
                finish_reason = "tool_calls"

        message_obj = {
            "role": "assistant",
            "content": content,
            "refusal": None,
            "annotations": [],
        }
        if tool_calls_result:
            message_obj["tool_calls"] = tool_calls_result

        return {
            "id": response_id,
            "object": "chat.completion",
            "created": self.created,
            "model": self.model,
            "system_fingerprint": fingerprint,
            "choices": [
                {
                    "index": 0,
                    "message": message_obj,
                    "finish_reason": finish_reason,
                }
            ],
            # Upstream does not report token counts; zeros keep the schema.
            "usage": {
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0,
                "prompt_tokens_details": {
                    "cached_tokens": 0,
                    "text_tokens": 0,
                    "audio_tokens": 0,
                    "image_tokens": 0,
                },
                "completion_tokens_details": {
                    "text_tokens": 0,
                    "audio_tokens": 0,
                    "reasoning_tokens": 0,
                },
            },
        }
# Public API of this module; processors are internal implementation details.
__all__ = [
    "GrokChatService",
    "MessageExtractor",
    "ChatService",
]
|
app/services/grok/services/image.py
ADDED
|
@@ -0,0 +1,794 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Grok image services.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import asyncio
|
| 6 |
+
import base64
|
| 7 |
+
import math
|
| 8 |
+
import time
|
| 9 |
+
from dataclasses import dataclass
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Any, AsyncGenerator, AsyncIterable, Dict, List, Optional, Union
|
| 12 |
+
|
| 13 |
+
import orjson
|
| 14 |
+
|
| 15 |
+
from app.core.config import get_config
|
| 16 |
+
from app.core.logger import logger
|
| 17 |
+
from app.core.storage import DATA_DIR
|
| 18 |
+
from app.core.exceptions import AppException, ErrorType, UpstreamException
|
| 19 |
+
from app.services.grok.utils.process import BaseProcessor
|
| 20 |
+
from app.services.grok.utils.retry import pick_token, rate_limited
|
| 21 |
+
from app.services.grok.utils.response import make_response_id, make_chat_chunk, wrap_image_content
|
| 22 |
+
from app.services.grok.utils.stream import wrap_stream_with_usage
|
| 23 |
+
from app.services.token import EffortType
|
| 24 |
+
from app.services.reverse.ws_imagine import ImagineWebSocketReverse
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
image_service = ImagineWebSocketReverse()
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@dataclass
class ImageGenerationResult:
    """Result container returned by :class:`ImageGenerationService`."""

    # True when ``data`` is an SSE async generator; False when it is a
    # fully-collected list of image payloads.
    stream: bool
    # Streaming mode: async generator yielding SSE-formatted strings.
    # Non-streaming mode: list of image outputs (b64 strings or file URLs).
    data: Union[AsyncGenerator[str, None], List[str]]
    # Optional usage dict the caller should report instead of computed usage
    # (the WS image path does not produce token counts).
    usage_override: Optional[dict] = None
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class ImageGenerationService:
    """Image generation orchestration service.

    Wraps the WebSocket imagine backend (`image_service`) with:
    token selection / 429 retry across multiple tokens, streaming (SSE)
    and non-streaming collection modes, and recovery batches when the
    upstream blocks or drops final images.
    """

    async def generate(
        self,
        *,
        token_mgr: Any,
        token: str,
        model_info: Any,
        prompt: str,
        n: int,
        response_format: str,
        size: str,
        aspect_ratio: str,
        stream: bool,
        enable_nsfw: Optional[bool] = None,
        chat_format: bool = False,
    ) -> ImageGenerationResult:
        """Generate ``n`` images for ``prompt``, retrying across tokens on 429.

        Args:
            token_mgr: Token pool manager (provides pick/mark_rate_limited/consume).
            token: Preferred token to try first (only honored when no tag
                preference is in effect).
            model_info: Model descriptor; ``model_id`` and ``cost`` are read.
            prompt: Text prompt.
            n: Number of images requested.
            response_format: "b64_json"/"base64" or "url".
            size: Size string echoed into SSE payloads (stream mode only).
            aspect_ratio: Aspect ratio forwarded upstream.
            stream: When True return an SSE generator, else a collected list.
            enable_nsfw: Override for the "image.nsfw" config; resolved once
                here so routing (token tags) and upstream agree.
            chat_format: Emit OpenAI chat.completion.chunk SSE instead of the
                image_generation.* event format (stream mode only).

        Raises:
            AppException: 429-style error when no token is available at all.
            UpstreamException: Propagated upstream failures (after retries).
        """
        max_token_retries = int(get_config("retry.max_retry") or 3)
        tried_tokens: set[str] = set()
        last_error: Optional[Exception] = None

        # resolve nsfw once for routing and upstream
        if enable_nsfw is None:
            enable_nsfw = bool(get_config("image.nsfw"))
        # When NSFW is on, prefer tokens tagged "nsfw"; this also disables the
        # caller-preferred token (tag preference wins over stickiness).
        prefer_tags = {"nsfw"} if enable_nsfw else None

        if stream:

            async def _stream_retry() -> AsyncGenerator[str, None]:
                # Retry wrapper around the streaming path. Retrying is only
                # safe before the first chunk has been yielded to the client;
                # after that a 429 is re-raised as-is.
                nonlocal last_error
                for attempt in range(max_token_retries):
                    preferred = token if (attempt == 0 and not prefer_tags) else None
                    current_token = await pick_token(
                        token_mgr,
                        model_info.model_id,
                        tried_tokens,
                        preferred=preferred,
                        prefer_tags=prefer_tags,
                    )
                    if not current_token:
                        # Pool exhausted: surface the last upstream error if
                        # we have one, otherwise a generic 429.
                        if last_error:
                            raise last_error
                        raise AppException(
                            message="No available tokens. Please try again later.",
                            error_type=ErrorType.RATE_LIMIT.value,
                            code="rate_limit_exceeded",
                            status_code=429,
                        )

                    tried_tokens.add(current_token)
                    yielded = False
                    try:
                        result = await self._stream_ws(
                            token_mgr=token_mgr,
                            token=current_token,
                            model_info=model_info,
                            prompt=prompt,
                            n=n,
                            response_format=response_format,
                            size=size,
                            aspect_ratio=aspect_ratio,
                            enable_nsfw=enable_nsfw,
                            chat_format=chat_format,
                        )
                        async for chunk in result.data:
                            yielded = True
                            yield chunk
                        return
                    except UpstreamException as e:
                        last_error = e
                        if rate_limited(e):
                            # Cannot transparently retry once output reached
                            # the client.
                            if yielded:
                                raise
                            await token_mgr.mark_rate_limited(current_token)
                            logger.warning(
                                f"Token {current_token[:10]}... rate limited (429), "
                                f"trying next token (attempt {attempt + 1}/{max_token_retries})"
                            )
                            continue
                        raise

                if last_error:
                    raise last_error
                raise AppException(
                    message="No available tokens. Please try again later.",
                    error_type=ErrorType.RATE_LIMIT.value,
                    code="rate_limit_exceeded",
                    status_code=429,
                )

            return ImageGenerationResult(stream=True, data=_stream_retry())

        # Non-streaming path: same token-retry loop, but results are fully
        # collected before returning.
        for attempt in range(max_token_retries):
            preferred = token if (attempt == 0 and not prefer_tags) else None
            current_token = await pick_token(
                token_mgr,
                model_info.model_id,
                tried_tokens,
                preferred=preferred,
                prefer_tags=prefer_tags,
            )
            if not current_token:
                if last_error:
                    raise last_error
                raise AppException(
                    message="No available tokens. Please try again later.",
                    error_type=ErrorType.RATE_LIMIT.value,
                    code="rate_limit_exceeded",
                    status_code=429,
                )

            tried_tokens.add(current_token)
            try:
                return await self._collect_ws(
                    token_mgr=token_mgr,
                    token=current_token,
                    model_info=model_info,
                    tried_tokens=tried_tokens,
                    prompt=prompt,
                    n=n,
                    response_format=response_format,
                    aspect_ratio=aspect_ratio,
                    enable_nsfw=enable_nsfw,
                )
            except UpstreamException as e:
                last_error = e
                if rate_limited(e):
                    await token_mgr.mark_rate_limited(current_token)
                    logger.warning(
                        f"Token {current_token[:10]}... rate limited (429), "
                        f"trying next token (attempt {attempt + 1}/{max_token_retries})"
                    )
                    continue
                raise

        if last_error:
            raise last_error
        raise AppException(
            message="No available tokens. Please try again later.",
            error_type=ErrorType.RATE_LIMIT.value,
            code="rate_limit_exceeded",
            status_code=429,
        )

    async def _stream_ws(
        self,
        *,
        token_mgr: Any,
        token: str,
        model_info: Any,
        prompt: str,
        n: int,
        response_format: str,
        size: str,
        aspect_ratio: str,
        enable_nsfw: Optional[bool] = None,
        chat_format: bool = False,
    ) -> ImageGenerationResult:
        """Open one upstream WS stream and wrap it in the SSE processor.

        Usage accounting is attached via ``wrap_stream_with_usage`` so the
        token is consumed when the stream completes.
        """
        if enable_nsfw is None:
            enable_nsfw = bool(get_config("image.nsfw"))
        # Upstream retry budget: configured blocked-attempt count plus the
        # initial try, clamped to [1, 10].
        stream_retries = int(get_config("image.blocked_parallel_attempts") or 5) + 1
        stream_retries = max(1, min(stream_retries, 10))
        upstream = image_service.stream(
            token=token,
            prompt=prompt,
            aspect_ratio=aspect_ratio,
            n=n,
            enable_nsfw=enable_nsfw,
            max_retries=stream_retries,
        )
        processor = ImageWSStreamProcessor(
            model_info.model_id,
            token,
            n=n,
            response_format=response_format,
            size=size,
            chat_format=chat_format,
        )
        stream = wrap_stream_with_usage(
            processor.process(upstream),
            token_mgr,
            token,
            model_info.model_id,
        )
        return ImageGenerationResult(stream=True, data=stream)

    async def _collect_ws(
        self,
        *,
        token_mgr: Any,
        token: str,
        model_info: Any,
        tried_tokens: set[str],
        prompt: str,
        n: int,
        response_format: str,
        aspect_ratio: str,
        enable_nsfw: Optional[bool] = None,
    ) -> ImageGenerationResult:
        """Collect ``n`` final images via parallel WS batches.

        Each upstream call is expected to yield up to ``expected_per_call``
        images; if blocked/reviewed images leave us short, extra recovery
        batches run (optionally on distinct tokens).
        """
        if enable_nsfw is None:
            enable_nsfw = bool(get_config("image.nsfw"))
        all_images: List[str] = []
        seen = set()
        # Upstream typically returns up to 6 images per call — TODO confirm.
        expected_per_call = 6
        calls_needed = max(1, int(math.ceil(n / expected_per_call)))
        calls_needed = min(calls_needed, n)

        async def _fetch_batch(call_target: int, call_token: str):
            # One upstream call collecting up to ``call_target`` finals.
            stream_retries = int(get_config("image.blocked_parallel_attempts") or 5) + 1
            stream_retries = max(1, min(stream_retries, 10))
            upstream = image_service.stream(
                token=call_token,
                prompt=prompt,
                aspect_ratio=aspect_ratio,
                n=call_target,
                enable_nsfw=enable_nsfw,
                max_retries=stream_retries,
            )
            processor = ImageWSCollectProcessor(
                model_info.model_id,
                token,
                n=call_target,
                response_format=response_format,
            )
            return await processor.process(upstream)

        tasks = []
        for i in range(calls_needed):
            remaining = n - (i * expected_per_call)
            call_target = min(expected_per_call, remaining)
            tasks.append(_fetch_batch(call_target, token))

        # return_exceptions=True: a failed batch is logged and skipped so the
        # other batches can still contribute images.
        results = await asyncio.gather(*tasks, return_exceptions=True)
        for batch in results:
            if isinstance(batch, Exception):
                logger.warning(f"WS batch failed: {batch}")
                continue
            for img in batch:
                if img not in seen:
                    seen.add(img)
                    all_images.append(img)
                if len(all_images) >= n:
                    break
            if len(all_images) >= n:
                break

        # If upstream likely blocked/reviewed some images, run extra parallel attempts
        # and only keep valid finals selected by ws_imagine classification.
        if len(all_images) < n:
            remaining = n - len(all_images)
            extra_attempts = int(get_config("image.blocked_parallel_attempts") or 5)
            extra_attempts = max(0, min(extra_attempts, 10))
            parallel_enabled = bool(get_config("image.blocked_parallel_enabled", True))
            if extra_attempts > 0:
                logger.warning(
                    f"Image finals insufficient ({len(all_images)}/{n}), running "
                    f"{extra_attempts} recovery attempts for remaining={remaining}, "
                    f"parallel_enabled={parallel_enabled}"
                )
                extra_tasks = []
                if parallel_enabled:
                    # Spread recovery across tokens not yet tried this request.
                    recovery_tried = set(tried_tokens)
                    recovery_tokens: List[str] = []
                    for _ in range(extra_attempts):
                        recovery_token = await pick_token(
                            token_mgr,
                            model_info.model_id,
                            recovery_tried,
                        )
                        if not recovery_token:
                            break
                        recovery_tried.add(recovery_token)
                        recovery_tokens.append(recovery_token)

                    if recovery_tokens:
                        logger.info(
                            f"Recovery using {len(recovery_tokens)} distinct tokens"
                        )
                        for recovery_token in recovery_tokens:
                            extra_tasks.append(
                                _fetch_batch(min(expected_per_call, remaining), recovery_token)
                            )
                else:
                    # Parallel-token mode disabled: retry on the same token.
                    extra_tasks = [
                        _fetch_batch(min(expected_per_call, remaining), token)
                        for _ in range(extra_attempts)
                    ]

                if not extra_tasks:
                    logger.warning("No tokens available for recovery attempts")
                    extra_results = []
                else:
                    extra_results = await asyncio.gather(*extra_tasks, return_exceptions=True)
                for batch in extra_results:
                    if isinstance(batch, Exception):
                        logger.warning(f"WS recovery batch failed: {batch}")
                        continue
                    for img in batch:
                        if img not in seen:
                            seen.add(img)
                            all_images.append(img)
                        if len(all_images) >= n:
                            break
                    if len(all_images) >= n:
                        break
                logger.info(
                    f"Image recovery attempts completed: finals={len(all_images)}/{n}, "
                    f"attempts={extra_attempts}"
                )

        if len(all_images) < n:
            logger.error(
                f"Image generation failed after recovery attempts: finals={len(all_images)}/{n}, "
                f"blocked_parallel_attempts={int(get_config('image.blocked_parallel_attempts') or 5)}"
            )
            raise UpstreamException(
                "Image generation blocked or no valid final image",
                details={
                    "error_code": "blocked_no_final_image",
                    "final_images": len(all_images),
                    "requested": n,
                },
            )

        # Best-effort usage accounting; a failure here must not fail the
        # request that already produced images.
        try:
            await token_mgr.consume(token, self._get_effort(model_info))
        except Exception as e:
            logger.warning(f"Failed to consume token: {e}")

        selected = self._select_images(all_images, n)
        # WS path yields no token counts; report explicit zeros.
        usage_override = {
            "total_tokens": 0,
            "input_tokens": 0,
            "output_tokens": 0,
            "input_tokens_details": {"text_tokens": 0, "image_tokens": 0},
        }
        return ImageGenerationResult(
            stream=False, data=selected, usage_override=usage_override
        )

    @staticmethod
    def _get_effort(model_info: Any) -> EffortType:
        """Map the model's cost tier to a token-consumption effort level."""
        return (
            EffortType.HIGH
            if (model_info and model_info.cost.value == "high")
            else EffortType.LOW
        )

    @staticmethod
    def _select_images(images: List[str], n: int) -> List[str]:
        """Return exactly ``n`` entries, padding shortfalls with "error"."""
        if len(images) >= n:
            return images[:n]
        selected = images.copy()
        while len(selected) < n:
            selected.append("error")
        return selected
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
class ImageWSBaseProcessor(BaseProcessor):
    """WebSocket image processor base.

    Shared helpers for decoding base64 image blobs, persisting them under
    the tmp image directory, building fetchable file URLs, and selecting
    the best variant seen for a given image id.
    """

    def __init__(self, model: str, token: str = "", response_format: str = "b64_json"):
        """Initialize the processor.

        Args:
            model: Model id used for output metadata.
            token: Upstream token (forwarded to BaseProcessor).
            response_format: "url", "b64_json", or the alias "base64"
                (normalized to "b64_json").
        """
        # Accept "base64" as an alias for the canonical "b64_json".
        if response_format == "base64":
            response_format = "b64_json"
        super().__init__(model, token)
        self.response_format = response_format
        # SSE payload field name. Only "url" and "b64_json" are possible
        # here: the "base64" alias was normalized above, so the original
        # `elif response_format == "base64"` branch was unreachable and has
        # been removed.
        self.response_field = "url" if response_format == "url" else "b64_json"
        # Lazily-created directory for saved image files.
        self._image_dir: Optional[Path] = None

    def _ensure_image_dir(self) -> Path:
        """Create (once) and return the tmp image directory."""
        if self._image_dir is None:
            base_dir = DATA_DIR / "tmp" / "image"
            base_dir.mkdir(parents=True, exist_ok=True)
            self._image_dir = base_dir
        return self._image_dir

    def _strip_base64(self, blob: str) -> str:
        """Strip a `data:...;base64,` prefix, returning the raw b64 payload."""
        if not blob:
            return ""
        if "," in blob and "base64" in blob.split(",", 1)[0]:
            return blob.split(",", 1)[1]
        return blob

    def _guess_ext(self, blob: str) -> Optional[str]:
        """Guess the image extension from a data-URL header or magic prefix.

        Returns "png"/"jpg" or None when undetectable.
        """
        if not blob:
            return None
        header = ""
        data = blob
        if "," in blob and "base64" in blob.split(",", 1)[0]:
            header, data = blob.split(",", 1)
            header = header.lower()
        if "image/png" in header:
            return "png"
        if "image/jpeg" in header or "image/jpg" in header:
            return "jpg"
        # Base64 magic prefixes: PNG signature / JPEG SOI marker.
        if data.startswith("iVBORw0KGgo"):
            return "png"
        if data.startswith("/9j/"):
            return "jpg"
        return None

    def _filename(self, image_id: str, is_final: bool, ext: Optional[str] = None) -> str:
        """Build the on-disk filename, defaulting ext by final/partial stage."""
        if ext:
            ext = ext.lower()
            if ext == "jpeg":
                ext = "jpg"
        if not ext:
            # Finals are typically jpg, partials png — TODO confirm upstream.
            ext = "jpg" if is_final else "png"
        return f"{image_id}.{ext}"

    def _build_file_url(self, filename: str) -> str:
        """Return a fetchable URL for a saved image file.

        Absolute when `app.app_url` is configured, otherwise app-relative.
        """
        app_url = get_config("app.app_url")
        if app_url:
            return f"{app_url.rstrip('/')}/v1/files/image/{filename}"
        return f"/v1/files/image/{filename}"

    async def _save_blob(
        self, image_id: str, blob: str, is_final: bool, ext: Optional[str] = None
    ) -> str:
        """Decode ``blob`` to disk and return its file URL ("" when empty).

        The blocking file write runs in a thread to keep the event loop free.
        """
        data = self._strip_base64(blob)
        if not data:
            return ""
        image_dir = self._ensure_image_dir()
        ext = ext or self._guess_ext(blob)
        filename = self._filename(image_id, is_final, ext=ext)
        filepath = image_dir / filename

        def _write_file():
            with open(filepath, "wb") as f:
                f.write(base64.b64decode(data))

        await asyncio.to_thread(_write_file)
        return self._build_file_url(filename)

    def _pick_best(self, existing: Optional[Dict], incoming: Dict) -> Dict:
        """Choose the better of two variants of the same image.

        Finals beat non-finals; otherwise the larger blob wins; ties keep
        the existing entry.
        """
        if not existing:
            return incoming
        if incoming.get("is_final") and not existing.get("is_final"):
            return incoming
        if existing.get("is_final") and not incoming.get("is_final"):
            return existing
        if incoming.get("blob_size", 0) > existing.get("blob_size", 0):
            return incoming
        return existing

    async def _to_output(self, image_id: str, item: Dict) -> str:
        """Render ``item`` per the configured response format.

        Returns a file URL (url mode) or raw b64 payload; "" on failure.
        """
        try:
            if self.response_format == "url":
                return await self._save_blob(
                    image_id,
                    item.get("blob", ""),
                    item.get("is_final", False),
                    ext=item.get("ext"),
                )
            return self._strip_base64(item.get("blob", ""))
        except Exception as e:
            # Best-effort: a single bad image must not fail the whole batch.
            logger.warning(f"Image output failed: {e}")
            return ""
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
class ImageWSStreamProcessor(ImageWSBaseProcessor):
    """WebSocket image stream processor.

    Converts the upstream WS event iterable into SSE output, emitting
    partial-image events as they arrive and final images after the stream
    drains. Supports two wire formats: the image_generation.* event format
    and OpenAI chat.completion.chunk format (``chat_format=True``).
    """

    def __init__(
        self,
        model: str,
        token: str = "",
        n: int = 1,
        response_format: str = "b64_json",
        size: str = "1024x1024",
        chat_format: bool = False,
    ):
        super().__init__(model, token, response_format)
        # Requested image count; caps how many image ids get an index.
        self.n = n
        # Size string echoed into image_generation.* payloads.
        self.size = size
        # Emit OpenAI chat chunks instead of image_generation events.
        self.chat_format = chat_format
        # n == 1: the single image id we lock onto (first id seen).
        self._target_id: Optional[str] = None
        # image_id -> output index (first-come order, capped at n).
        self._index_map: Dict[str, int] = {}
        # image_id -> last partial_image_index emitted (0=preview, 1=medium).
        self._partial_map: Dict[str, int] = {}
        # Image ids whose first (initial) partial has been emitted.
        self._initial_sent: set[str] = set()
        # Lazily-generated chat response id (shared across all chunks).
        self._id_generated: bool = False
        self._response_id: str = ""

    def _assign_index(self, image_id: str) -> Optional[int]:
        """Return a stable output index for ``image_id``; None once full."""
        if image_id in self._index_map:
            return self._index_map[image_id]
        if len(self._index_map) >= self.n:
            return None
        self._index_map[image_id] = len(self._index_map)
        return self._index_map[image_id]

    def _sse(self, event: str, data: dict) -> str:
        """Serialize one SSE frame with an event name and JSON payload."""
        return f"event: {event}\ndata: {orjson.dumps(data).decode()}\n\n"

    async def process(self, response: AsyncIterable[dict]) -> AsyncGenerator[str, None]:
        """Consume upstream events and yield SSE strings.

        Raises:
            UpstreamException: on upstream rate-limit errors (so the caller
                can retry on another token); other upstream errors are
                reported in-band as an SSE "error" event.
        """
        # Best variant seen per image id (finals preferred, then blob size).
        images: Dict[str, Dict] = {}
        emitted_chat_chunk = False

        async for item in response:
            if item.get("type") == "error":
                message = item.get("error") or "Upstream error"
                code = item.get("error_code") or "upstream_error"
                status = item.get("status")
                # Rate limits are raised (retryable); others end the stream
                # with an in-band error event.
                if code == "rate_limit_exceeded" or status == 429:
                    raise UpstreamException(message, details=item)
                yield self._sse(
                    "error",
                    {
                        "error": {
                            "message": message,
                            "type": "server_error",
                            "code": code,
                        }
                    },
                )
                return
            if item.get("type") != "image":
                continue

            image_id = item.get("image_id")
            if not image_id:
                continue

            if self.n == 1:
                # Lock onto the first image id; others are tracked but not
                # streamed (index None).
                if self._target_id is None:
                    self._target_id = image_id
                index = 0 if image_id == self._target_id else None
            else:
                index = self._assign_index(image_id)

            images[image_id] = self._pick_best(images.get(image_id), item)

            if index is None:
                continue

            if item.get("stage") != "final":
                # Chat Completions image stream should only expose final results.
                if self.chat_format:
                    continue
                if image_id not in self._initial_sent:
                    # First partial for this image: establish the partial
                    # index baseline (medium starts at 1, else 0).
                    self._initial_sent.add(image_id)
                    stage = item.get("stage") or "preview"
                    if stage == "medium":
                        partial_index = 1
                        self._partial_map[image_id] = 1
                    else:
                        partial_index = 0
                        self._partial_map[image_id] = 0
                else:
                    # Subsequent partials: skip repeated previews; medium
                    # bumps the index to at least 1.
                    stage = item.get("stage") or "partial"
                    if stage == "preview":
                        continue
                    partial_index = self._partial_map.get(image_id, 0)
                    if stage == "medium":
                        partial_index = max(partial_index, 1)
                    self._partial_map[image_id] = partial_index

                if self.response_format == "url":
                    # Save under a stage-qualified id so partials don't
                    # overwrite each other or the final file.
                    partial_id = f"{image_id}-{stage}-{partial_index}"
                    partial_out = await self._save_blob(
                        partial_id,
                        item.get("blob", ""),
                        False,
                        ext=item.get("ext"),
                    )
                else:
                    partial_out = self._strip_base64(item.get("blob", ""))

                if self.chat_format and partial_out:
                    partial_out = wrap_image_content(partial_out, self.response_format)

                if not partial_out:
                    continue

                if self.chat_format:
                    # OpenAI ChatCompletion chunk format for partial
                    if not self._id_generated:
                        self._response_id = make_response_id()
                        self._id_generated = True
                    emitted_chat_chunk = True
                    yield self._sse(
                        "chat.completion.chunk",
                        make_chat_chunk(
                            self._response_id,
                            self.model,
                            partial_out,
                            index=index,
                        ),
                    )
                else:
                    # Original image_generation format
                    yield self._sse(
                        "image_generation.partial_image",
                        {
                            "type": "image_generation.partial_image",
                            self.response_field: partial_out,
                            "created_at": int(time.time()),
                            "size": self.size,
                            "index": index,
                            "partial_image_index": partial_index,
                            "image_id": image_id,
                            "stage": stage,
                        },
                    )

        # Stream drained: pick which images get a "completed" emission.
        if self.n == 1:
            target_item = images.get(self._target_id) if self._target_id else None
            if target_item and target_item.get("is_final", False):
                selected = [(self._target_id, target_item)]
            elif images:
                # Target never finalized: fall back to the best candidate
                # (finals first, then largest blob).
                selected = [
                    max(
                        images.items(),
                        key=lambda x: (
                            x[1].get("is_final", False),
                            x[1].get("blob_size", 0),
                        ),
                    )
                ]
            else:
                selected = []
        else:
            # Only indexed images that reached a final variant.
            selected = [
                (image_id, images[image_id])
                for image_id in self._index_map
                if image_id in images and images[image_id].get("is_final", False)
            ]

        for image_id, item in selected:
            if self.response_format == "url":
                final_image_id = image_id
                # Keep original imagine image name for imagine chat stream output.
                if self.model != "grok-imagine-1.0-fast":
                    final_image_id = f"{image_id}-final"
                output = await self._save_blob(
                    final_image_id,
                    item.get("blob", ""),
                    item.get("is_final", False),
                    ext=item.get("ext"),
                )
                if self.chat_format and output:
                    output = wrap_image_content(output, self.response_format)
            else:
                output = await self._to_output(image_id, item)
                if self.chat_format and output:
                    output = wrap_image_content(output, self.response_format)

            if not output:
                continue

            if self.n == 1:
                index = 0
            else:
                index = self._index_map.get(image_id, 0)

            if not self._id_generated:
                self._response_id = make_response_id()
                self._id_generated = True

            if self.chat_format:
                # OpenAI ChatCompletion chunk format
                emitted_chat_chunk = True
                yield self._sse(
                    "chat.completion.chunk",
                    make_chat_chunk(
                        self._response_id,
                        self.model,
                        output,
                        index=index,
                        is_final=True,
                    ),
                )
            else:
                # Original image_generation format
                yield self._sse(
                    "image_generation.completed",
                    {
                        "type": "image_generation.completed",
                        self.response_field: output,
                        "created_at": int(time.time()),
                        "size": self.size,
                        "index": index,
                        "image_id": image_id,
                        "stage": "final",
                        "usage": {
                            "total_tokens": 0,
                            "input_tokens": 0,
                            "output_tokens": 0,
                            "input_tokens_details": {"text_tokens": 0, "image_tokens": 0},
                        },
                    },
                )

        if self.chat_format:
            # Chat mode always terminates with a final chunk (empty when
            # nothing was produced) followed by the [DONE] sentinel.
            if not self._id_generated:
                self._response_id = make_response_id()
                self._id_generated = True
            if not emitted_chat_chunk:
                yield self._sse(
                    "chat.completion.chunk",
                    make_chat_chunk(
                        self._response_id,
                        self.model,
                        "",
                        index=0,
                        is_final=True,
                    ),
                )
            yield "data: [DONE]\n\n"
|
| 752 |
+
|
| 753 |
+
|
| 754 |
+
class ImageWSCollectProcessor(ImageWSBaseProcessor):
    """Non-streaming WebSocket image processor.

    Drains the upstream event iterable, keeps the best variant per image
    id, then renders the largest final images into output strings
    (b64 payloads or saved-file URLs, per the response format).
    """

    def __init__(
        self, model: str, token: str = "", n: int = 1, response_format: str = "b64_json"
    ):
        super().__init__(model, token, response_format)
        self.n = n

    async def process(self, response: AsyncIterable[dict]) -> List[str]:
        """Collect up to ``self.n`` final images from ``response``.

        Raises:
            UpstreamException: on any upstream "error" event.
        """
        best_by_id: Dict[str, Dict] = {}

        async for event in response:
            kind = event.get("type")
            if kind == "error":
                raise UpstreamException(
                    event.get("error") or "Upstream error", details=event
                )
            if kind != "image":
                continue
            img_id = event.get("image_id")
            if img_id:
                best_by_id[img_id] = self._pick_best(best_by_id.get(img_id), event)

        # Finals only, largest blobs first, truncated to the request size.
        finals = [entry for entry in best_by_id.values() if entry.get("is_final", False)]
        finals.sort(key=lambda entry: entry.get("blob_size", 0), reverse=True)
        if self.n:
            del finals[self.n :]

        outputs: List[str] = []
        for entry in finals:
            rendered = await self._to_output(entry.get("image_id", ""), entry)
            if rendered:
                outputs.append(rendered)
        return outputs
|
| 792 |
+
|
| 793 |
+
|
| 794 |
+
__all__ = ["ImageGenerationService"]
|
app/services/grok/services/image_edit.py
ADDED
|
@@ -0,0 +1,567 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Grok image edit service.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import asyncio
|
| 6 |
+
import os
|
| 7 |
+
import random
|
| 8 |
+
import re
|
| 9 |
+
import time
|
| 10 |
+
from dataclasses import dataclass
|
| 11 |
+
from typing import AsyncGenerator, AsyncIterable, List, Union, Any
|
| 12 |
+
|
| 13 |
+
import orjson
|
| 14 |
+
from curl_cffi.requests.errors import RequestsError
|
| 15 |
+
|
| 16 |
+
from app.core.config import get_config
|
| 17 |
+
from app.core.exceptions import (
|
| 18 |
+
AppException,
|
| 19 |
+
ErrorType,
|
| 20 |
+
UpstreamException,
|
| 21 |
+
StreamIdleTimeoutError,
|
| 22 |
+
)
|
| 23 |
+
from app.core.logger import logger
|
| 24 |
+
from app.services.grok.utils.process import (
|
| 25 |
+
BaseProcessor,
|
| 26 |
+
_with_idle_timeout,
|
| 27 |
+
_normalize_line,
|
| 28 |
+
_collect_images,
|
| 29 |
+
_is_http2_error,
|
| 30 |
+
)
|
| 31 |
+
from app.services.grok.utils.upload import UploadService
|
| 32 |
+
from app.services.grok.utils.retry import pick_token, rate_limited
|
| 33 |
+
from app.services.grok.utils.response import make_response_id, make_chat_chunk, wrap_image_content
|
| 34 |
+
from app.services.grok.services.chat import GrokChatService
|
| 35 |
+
from app.services.grok.services.video import VideoService
|
| 36 |
+
from app.services.grok.utils.stream import wrap_stream_with_usage
|
| 37 |
+
from app.services.token import EffortType
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@dataclass
class ImageEditResult:
    """Result container for an image edit request.

    When ``stream`` is True, ``data`` is an async generator yielding SSE
    strings; otherwise ``data`` is the collected list of image payloads
    (URLs or base64 strings, depending on the requested response format).
    """

    # True → data is an SSE stream; False → data is a finished list.
    stream: bool
    data: Union[AsyncGenerator[str, None], List[str]]
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class ImageEditService:
    """Image edit orchestration service.

    Uploads reference images, resolves a parent post id, then drives the
    Grok chat endpoint in image-edit mode, rotating to a fresh token when
    the current one is rate limited.
    """

    @staticmethod
    def _no_tokens_exc() -> AppException:
        """Single source of truth for the token-pool-exhausted error.

        Previously this exception literal was duplicated in two places
        inside ``edit``; factoring it out keeps message/code in sync.
        """
        return AppException(
            message="No available tokens. Please try again later.",
            error_type=ErrorType.RATE_LIMIT.value,
            code="rate_limit_exceeded",
            status_code=429,
        )

    async def edit(
        self,
        *,
        token_mgr: Any,
        token: str,
        model_info: Any,
        prompt: str,
        images: List[str],
        n: int,
        response_format: str,
        stream: bool,
        chat_format: bool = False,
    ) -> ImageEditResult:
        """Run an image edit request, retrying across tokens on 429s.

        Args:
            token_mgr: Token pool manager (``consume`` / ``mark_rate_limited``).
            token: Preferred token for the first attempt only.
            model_info: Resolved model descriptor (``model_id``, ``grok_model``,
                ``cost``) — project type, shape assumed from usage here.
            prompt: Edit instruction text.
            images: Reference images; only the most recent 3 are kept.
            n: Number of result images requested.
            response_format: "url", "base64" or "b64_json".
            stream: Emit an SSE stream when True, a collected list otherwise.
            chat_format: Wrap stream output as OpenAI chat completion chunks.

        Raises:
            AppException: When the token pool is exhausted.
            UpstreamException: On non-rate-limit upstream failures.
        """
        if len(images) > 3:
            logger.info(
                "Image edit received %d references; using the most recent 3",
                len(images),
            )
            images = images[-3:]

        max_token_retries = int(get_config("retry.max_retry") or 3)
        tried_tokens: set[str] = set()
        last_error: Exception | None = None

        for attempt in range(max_token_retries):
            # Only honor the caller-preferred token on the first attempt.
            preferred = token if attempt == 0 else None
            current_token = await pick_token(
                token_mgr, model_info.model_id, tried_tokens, preferred=preferred
            )
            if not current_token:
                if last_error:
                    raise last_error
                raise self._no_tokens_exc()

            tried_tokens.add(current_token)
            try:
                image_urls = await self._upload_images(images, current_token)
                parent_post_id = await self._get_parent_post_id(
                    current_token, image_urls
                )

                model_config_override = {
                    "modelMap": {
                        "imageEditModel": "imagine",
                        "imageEditModelConfig": {
                            "imageReferences": image_urls,
                        },
                    }
                }
                if parent_post_id:
                    model_config_override["modelMap"]["imageEditModelConfig"][
                        "parentPostId"
                    ] = parent_post_id

                tool_overrides = {"imageGen": True}

                if stream:
                    response = await GrokChatService().chat(
                        token=current_token,
                        message=prompt,
                        model=model_info.grok_model,
                        mode=None,
                        stream=True,
                        tool_overrides=tool_overrides,
                        model_config_override=model_config_override,
                    )
                    processor = ImageStreamProcessor(
                        model_info.model_id,
                        current_token,
                        n=n,
                        response_format=response_format,
                        chat_format=chat_format,
                    )
                    # Usage accounting for the stream path is handled by the
                    # wrapper once the stream finishes.
                    return ImageEditResult(
                        stream=True,
                        data=wrap_stream_with_usage(
                            processor.process(response),
                            token_mgr,
                            current_token,
                            model_info.model_id,
                        ),
                    )

                images_out = await self._collect_images(
                    token=current_token,
                    prompt=prompt,
                    model_info=model_info,
                    n=n,
                    response_format=response_format,
                    tool_overrides=tool_overrides,
                    model_config_override=model_config_override,
                )
                try:
                    # Non-stream path records usage directly; failures here
                    # are logged but never fail the request.
                    effort = (
                        EffortType.HIGH
                        if (model_info and model_info.cost.value == "high")
                        else EffortType.LOW
                    )
                    await token_mgr.consume(current_token, effort)
                    logger.debug(
                        f"Image edit completed, recorded usage (effort={effort.value})"
                    )
                except Exception as e:
                    logger.warning(f"Failed to record image edit usage: {e}")
                return ImageEditResult(stream=False, data=images_out)

            except UpstreamException as e:
                last_error = e
                if rate_limited(e):
                    await token_mgr.mark_rate_limited(current_token)
                    logger.warning(
                        f"Token {current_token[:10]}... rate limited (429), "
                        f"trying next token (attempt {attempt + 1}/{max_token_retries})"
                    )
                    continue
                # Non-rate-limit upstream errors are not retried.
                raise

        if last_error:
            raise last_error
        raise self._no_tokens_exc()

    async def _upload_images(self, images: List[str], token: str) -> List[str]:
        """Upload reference images and return absolute asset URLs.

        Relative file URIs returned by the upload service are rooted at the
        Grok assets host.

        Raises:
            AppException: When no image could be uploaded successfully.
        """
        image_urls: List[str] = []
        upload_service = UploadService()
        try:
            for image in images:
                _, file_uri = await upload_service.upload_file(image, token)
                if file_uri:
                    if file_uri.startswith("http"):
                        image_urls.append(file_uri)
                    else:
                        image_urls.append(
                            f"https://assets.grok.com/{file_uri.lstrip('/')}"
                        )
        finally:
            await upload_service.close()

        if not image_urls:
            raise AppException(
                message="Image upload failed",
                error_type=ErrorType.SERVER.value,
                code="upload_failed",
            )

        return image_urls

    async def _get_parent_post_id(self, token: str, image_urls: List[str]) -> str:
        """Resolve the parent post id for the first reference image.

        Prefers creating a post via VideoService; on failure falls back to
        extracting an id from known asset URL shapes. Returns "" when no id
        can be determined.
        """
        parent_post_id = None
        try:
            media_service = VideoService()
            parent_post_id = await media_service.create_image_post(token, image_urls[0])
            logger.debug(f"Parent post ID: {parent_post_id}")
        except Exception as e:
            logger.warning(f"Create image post failed: {e}")

        if parent_post_id:
            return parent_post_id

        for url in image_urls:
            match = re.search(r"/generated/([a-f0-9-]+)/", url)
            if match:
                parent_post_id = match.group(1)
                logger.debug(f"Parent post ID: {parent_post_id}")
                break
            match = re.search(r"/users/[^/]+/([a-f0-9-]+)/content", url)
            if match:
                parent_post_id = match.group(1)
                logger.debug(f"Parent post ID: {parent_post_id}")
                break

        return parent_post_id or ""

    async def _collect_images(
        self,
        *,
        token: str,
        prompt: str,
        model_info: Any,
        n: int,
        response_format: str,
        tool_overrides: dict,
        model_config_override: dict,
    ) -> List[str]:
        """Collect up to ``n`` images, fanning out concurrent upstream calls.

        ``calls_needed`` assumes each upstream call yields up to 2 images
        (hence ceil(n / 2)) — TODO confirm against the upstream contract.
        Shortfalls are padded with the literal "error" marker so callers
        always receive exactly ``n`` entries.
        """
        calls_needed = (n + 1) // 2

        async def _call_edit():
            # One upstream chat call, collected (non-SSE) into a list.
            response = await GrokChatService().chat(
                token=token,
                message=prompt,
                model=model_info.grok_model,
                mode=None,
                stream=True,
                tool_overrides=tool_overrides,
                model_config_override=model_config_override,
            )
            processor = ImageCollectProcessor(
                model_info.model_id, token, response_format=response_format
            )
            return await processor.process(response)

        last_error: Exception | None = None
        rate_limit_error: Exception | None = None

        if calls_needed == 1:
            all_images = await _call_edit()
        else:
            tasks = [_call_edit() for _ in range(calls_needed)]
            results = await asyncio.gather(*tasks, return_exceptions=True)

            all_images: List[str] = []
            for result in results:
                if isinstance(result, Exception):
                    logger.error(f"Concurrent call failed: {result}")
                    last_error = result
                    if rate_limited(result):
                        rate_limit_error = result
                elif isinstance(result, list):
                    all_images.extend(result)

        if not all_images:
            # Surface a rate-limit error preferentially so the caller's
            # retry loop can rotate tokens.
            if rate_limit_error:
                raise rate_limit_error
            if last_error:
                raise last_error
            raise UpstreamException(
                "Image edit returned no results", details={"error": "empty_result"}
            )

        if len(all_images) >= n:
            return all_images[:n]

        selected_images = all_images.copy()
        while len(selected_images) < n:
            selected_images.append("error")
        return selected_images
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
class ImageStreamProcessor(BaseProcessor):
    """HTTP image stream processor.

    Parses the upstream JSON-lines stream and re-emits it as SSE events,
    either in the image_generation.* shape or as OpenAI chat completion
    chunks when ``chat_format`` is set.
    """

    def __init__(
        self, model: str, token: str = "", n: int = 1, response_format: str = "b64_json", chat_format: bool = False
    ):
        super().__init__(model, token)
        self.partial_index = 0
        self.n = n
        # With n == 1 only image index 0 is forwarded; otherwise all indexes.
        self.target_index = 0 if n == 1 else None
        self.response_format = response_format
        self.chat_format = chat_format
        self._id_generated = False
        self._response_id = ""
        # SSE payload key mirrors the requested response format; anything
        # unrecognized falls back to the OpenAI-style "b64_json" key.
        if response_format == "url":
            self.response_field = "url"
        elif response_format == "base64":
            self.response_field = "base64"
        else:
            self.response_field = "b64_json"

    def _sse(self, event: str, data: dict) -> str:
        """Build a single SSE frame for *event* carrying *data* as JSON."""
        return f"event: {event}\ndata: {orjson.dumps(data).decode()}\n\n"

    async def process(
        self, response: AsyncIterable[bytes]
    ) -> AsyncGenerator[str, None]:
        """Process the upstream stream, yielding SSE strings.

        Progress events are forwarded as they arrive; final images are
        buffered in ``final_images`` and emitted after the stream ends.

        Raises:
            UpstreamException: On idle timeout (504) or request errors (502).
        """
        final_images = []
        emitted_chat_chunk = False
        idle_timeout = get_config("image.stream_timeout")

        try:
            async for line in _with_idle_timeout(response, idle_timeout, self.model):
                line = _normalize_line(line)
                if not line:
                    continue
                try:
                    data = orjson.loads(line)
                except orjson.JSONDecodeError:
                    # Non-JSON keepalive/noise lines are skipped silently.
                    continue

                resp = data.get("result", {}).get("response", {})

                # Image generation progress
                if img := resp.get("streamingImageGenerationResponse"):
                    image_index = img.get("imageIndex", 0)
                    progress = img.get("progress", 0)

                    if self.n == 1 and image_index != self.target_index:
                        continue

                    out_index = 0 if self.n == 1 else image_index

                    # Chat format has no partial-image event shape, so
                    # progress is only surfaced in image_generation mode.
                    if not self.chat_format:
                        yield self._sse(
                            "image_generation.partial_image",
                            {
                                "type": "image_generation.partial_image",
                                self.response_field: "",
                                "index": out_index,
                                "progress": progress,
                            },
                        )
                    continue

                # modelResponse
                if mr := resp.get("modelResponse"):
                    if urls := _collect_images(mr):
                        for url in urls:
                            if self.response_format == "url":
                                processed = await self.process_url(url, "image")
                                if processed:
                                    final_images.append(processed)
                                continue
                            try:
                                dl_service = self._get_dl()
                                base64_data = await dl_service.parse_b64(
                                    url, self.token, "image"
                                )
                                if base64_data:
                                    # Strip any "data:...;base64," prefix.
                                    if "," in base64_data:
                                        b64 = base64_data.split(",", 1)[1]
                                    else:
                                        b64 = base64_data
                                    final_images.append(b64)
                            except Exception as e:
                                # Best-effort: fall back to the raw URL when
                                # base64 conversion fails.
                                logger.warning(
                                    f"Failed to convert image to base64, falling back to URL: {e}"
                                )
                                processed = await self.process_url(url, "image")
                                if processed:
                                    final_images.append(processed)
                    continue

            # Stream finished: emit the buffered final images.
            for index, img_data in enumerate(final_images):
                if self.n == 1:
                    if index != self.target_index:
                        continue
                    out_index = 0
                else:
                    out_index = index

                # Wrap in markdown format for chat
                output = img_data
                if self.chat_format and output:
                    output = wrap_image_content(output, self.response_format)

                # Lazily create one response id shared by all chunks.
                if not self._id_generated:
                    self._response_id = make_response_id()
                    self._id_generated = True

                if self.chat_format:
                    # OpenAI ChatCompletion chunk format
                    emitted_chat_chunk = True
                    yield self._sse(
                        "chat.completion.chunk",
                        make_chat_chunk(
                            self._response_id,
                            self.model,
                            output,
                            index=out_index,
                            is_final=True,
                        ),
                    )
                else:
                    # Original image_generation format
                    yield self._sse(
                        "image_generation.completed",
                        {
                            "type": "image_generation.completed",
                            self.response_field: img_data,
                            "index": out_index,
                            # Token usage is not tracked for images; zeros
                            # keep the OpenAI-compatible shape.
                            "usage": {
                                "total_tokens": 0,
                                "input_tokens": 0,
                                "output_tokens": 0,
                                "input_tokens_details": {
                                    "text_tokens": 0,
                                    "image_tokens": 0,
                                },
                            },
                        },
                    )

            if self.chat_format:
                if not self._id_generated:
                    self._response_id = make_response_id()
                    self._id_generated = True
                # Guarantee at least one final chunk even with no images.
                if not emitted_chat_chunk:
                    yield self._sse(
                        "chat.completion.chunk",
                        make_chat_chunk(
                            self._response_id,
                            self.model,
                            "",
                            index=0,
                            is_final=True,
                        ),
                    )
            yield "data: [DONE]\n\n"
        except asyncio.CancelledError:
            # Client disconnect: stop quietly, cleanup happens in finally.
            logger.debug("Image stream cancelled by client")
        except StreamIdleTimeoutError as e:
            raise UpstreamException(
                message=f"Image stream idle timeout after {e.idle_seconds}s",
                status_code=504,
                details={
                    "error": str(e),
                    "type": "stream_idle_timeout",
                    "idle_seconds": e.idle_seconds,
                },
            )
        except RequestsError as e:
            if _is_http2_error(e):
                logger.warning(f"HTTP/2 stream error in image: {e}")
                raise UpstreamException(
                    message="Upstream connection closed unexpectedly",
                    status_code=502,
                    details={"error": str(e), "type": "http2_stream_error"},
                )
            logger.error(f"Image stream request error: {e}")
            raise UpstreamException(
                message=f"Upstream request failed: {e}",
                status_code=502,
                details={"error": str(e)},
            )
        except Exception as e:
            logger.error(
                f"Image stream processing error: {e}",
                extra={"error_type": type(e).__name__},
            )
            raise
        finally:
            await self.close()
|
| 492 |
+
|
| 493 |
+
|
| 494 |
+
class ImageCollectProcessor(BaseProcessor):
    """HTTP image non-stream processor.

    Consumes the upstream stream to completion and returns the final
    images as a plain list instead of re-emitting SSE events.
    """

    def __init__(self, model: str, token: str = "", response_format: str = "b64_json"):
        # "base64" is an alias for "b64_json" in the collect path.
        if response_format == "base64":
            response_format = "b64_json"
        super().__init__(model, token)
        self.response_format = response_format

    async def process(self, response: AsyncIterable[bytes]) -> List[str]:
        """Drain the stream and collect final images.

        Unlike the streaming processor, all errors here are logged and
        swallowed; callers get whatever was collected (possibly empty).
        """
        images = []
        idle_timeout = get_config("image.stream_timeout")

        try:
            async for line in _with_idle_timeout(response, idle_timeout, self.model):
                line = _normalize_line(line)
                if not line:
                    continue
                try:
                    data = orjson.loads(line)
                except orjson.JSONDecodeError:
                    # Skip non-JSON keepalive/noise lines.
                    continue

                resp = data.get("result", {}).get("response", {})

                if mr := resp.get("modelResponse"):
                    if urls := _collect_images(mr):
                        for url in urls:
                            if self.response_format == "url":
                                processed = await self.process_url(url, "image")
                                if processed:
                                    images.append(processed)
                                continue
                            try:
                                dl_service = self._get_dl()
                                base64_data = await dl_service.parse_b64(
                                    url, self.token, "image"
                                )
                                if base64_data:
                                    # Strip any "data:...;base64," prefix.
                                    if "," in base64_data:
                                        b64 = base64_data.split(",", 1)[1]
                                    else:
                                        b64 = base64_data
                                    images.append(b64)
                            except Exception as e:
                                # Best-effort fallback to the raw URL.
                                logger.warning(
                                    f"Failed to convert image to base64, falling back to URL: {e}"
                                )
                                processed = await self.process_url(url, "image")
                                if processed:
                                    images.append(processed)

        except asyncio.CancelledError:
            logger.debug("Image collect cancelled by client")
        except StreamIdleTimeoutError as e:
            logger.warning(f"Image collect idle timeout: {e}")
        except RequestsError as e:
            if _is_http2_error(e):
                logger.warning(f"HTTP/2 stream error in image collect: {e}")
            else:
                logger.error(f"Image collect request error: {e}")
        except Exception as e:
            logger.error(
                f"Image collect processing error: {e}",
                extra={"error_type": type(e).__name__},
            )
        finally:
            await self.close()

        return images
|
| 565 |
+
|
| 566 |
+
|
| 567 |
+
__all__ = ["ImageEditService", "ImageEditResult"]
|
app/services/grok/services/model.py
ADDED
|
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Grok 模型管理服务
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from enum import Enum
|
| 6 |
+
from typing import Optional, Tuple, List
|
| 7 |
+
from pydantic import BaseModel, Field
|
| 8 |
+
|
| 9 |
+
from app.core.exceptions import ValidationException
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class Tier(str, Enum):
    """Model tier; drives token-pool selection (see ModelService)."""

    BASIC = "basic"
    SUPER = "super"
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class Cost(str, Enum):
    """Billing effort level for a model."""

    LOW = "low"
    HIGH = "high"
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class ModelInfo(BaseModel):
    """Descriptor for one exposed model."""

    # Public model id exposed by this API.
    model_id: str
    # Underlying Grok model name sent upstream.
    grok_model: str
    # Upstream MODEL_MODE_* constant.
    model_mode: str
    tier: Tier = Field(default=Tier.BASIC)
    cost: Cost = Field(default=Cost.LOW)
    display_name: str
    description: str = ""
    # Capability flags: at most one is expected to be True per model.
    is_image: bool = False
    is_image_edit: bool = False
    is_video: bool = False
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class ModelService:
    """Static registry of supported models and lookup helpers."""

    # Registry of all exposed models; order is preserved in list().
    MODELS = [
        ModelInfo(
            model_id="grok-3",
            grok_model="grok-3",
            model_mode="MODEL_MODE_GROK_3",
            tier=Tier.BASIC,
            cost=Cost.LOW,
            display_name="GROK-3",
            is_image=False,
            is_image_edit=False,
            is_video=False,
        ),
        ModelInfo(
            model_id="grok-3-mini",
            grok_model="grok-3",
            model_mode="MODEL_MODE_GROK_3_MINI_THINKING",
            tier=Tier.BASIC,
            cost=Cost.LOW,
            display_name="GROK-3-MINI",
            is_image=False,
            is_image_edit=False,
            is_video=False,
        ),
        ModelInfo(
            model_id="grok-3-thinking",
            grok_model="grok-3",
            model_mode="MODEL_MODE_GROK_3_THINKING",
            tier=Tier.BASIC,
            cost=Cost.LOW,
            display_name="GROK-3-THINKING",
            is_image=False,
            is_image_edit=False,
            is_video=False,
        ),
        ModelInfo(
            model_id="grok-4",
            grok_model="grok-4",
            model_mode="MODEL_MODE_GROK_4",
            tier=Tier.BASIC,
            cost=Cost.LOW,
            display_name="GROK-4",
            is_image=False,
            is_image_edit=False,
            is_video=False,
        ),
        ModelInfo(
            model_id="grok-4-mini",
            grok_model="grok-4-mini",
            model_mode="MODEL_MODE_GROK_4_MINI_THINKING",
            tier=Tier.BASIC,
            cost=Cost.LOW,
            display_name="GROK-4-MINI",
            is_image=False,
            is_image_edit=False,
            is_video=False,
        ),
        ModelInfo(
            model_id="grok-4-thinking",
            grok_model="grok-4",
            model_mode="MODEL_MODE_GROK_4_THINKING",
            tier=Tier.BASIC,
            cost=Cost.LOW,
            display_name="GROK-4-THINKING",
            is_image=False,
            is_image_edit=False,
            is_video=False,
        ),
        # Only heavy uses the SUPER tier (and therefore the ssoSuper pool).
        ModelInfo(
            model_id="grok-4-heavy",
            grok_model="grok-4",
            model_mode="MODEL_MODE_HEAVY",
            tier=Tier.SUPER,
            cost=Cost.HIGH,
            display_name="GROK-4-HEAVY",
            is_image=False,
            is_image_edit=False,
            is_video=False,
        ),
        ModelInfo(
            model_id="grok-4.1-mini",
            grok_model="grok-4-1-thinking-1129",
            model_mode="MODEL_MODE_GROK_4_1_MINI_THINKING",
            tier=Tier.BASIC,
            cost=Cost.LOW,
            display_name="GROK-4.1-MINI",
            is_image=False,
            is_image_edit=False,
            is_video=False,
        ),
        ModelInfo(
            model_id="grok-4.1-fast",
            grok_model="grok-4-1-thinking-1129",
            model_mode="MODEL_MODE_FAST",
            tier=Tier.BASIC,
            cost=Cost.LOW,
            display_name="GROK-4.1-FAST",
            is_image=False,
            is_image_edit=False,
            is_video=False,
        ),
        ModelInfo(
            model_id="grok-4.1-expert",
            grok_model="grok-4-1-thinking-1129",
            model_mode="MODEL_MODE_EXPERT",
            tier=Tier.BASIC,
            cost=Cost.HIGH,
            display_name="GROK-4.1-EXPERT",
            is_image=False,
            is_image_edit=False,
            is_video=False,
        ),
        ModelInfo(
            model_id="grok-4.1-thinking",
            grok_model="grok-4-1-thinking-1129",
            model_mode="MODEL_MODE_GROK_4_1_THINKING",
            tier=Tier.BASIC,
            cost=Cost.HIGH,
            display_name="GROK-4.1-THINKING",
            is_image=False,
            is_image_edit=False,
            is_video=False,
        ),
        ModelInfo(
            model_id="grok-4.20-beta",
            grok_model="grok-420",
            model_mode="MODEL_MODE_GROK_420",
            tier=Tier.BASIC,
            cost=Cost.LOW,
            display_name="GROK-4.20-BETA",
            is_image=False,
            is_image_edit=False,
            is_video=False,
        ),
        ModelInfo(
            model_id="grok-imagine-1.0-fast",
            grok_model="grok-3",
            model_mode="MODEL_MODE_FAST",
            tier=Tier.BASIC,
            cost=Cost.HIGH,
            display_name="Grok Image Fast",
            description="Imagine waterfall image generation model for chat completions",
            is_image=True,
            is_image_edit=False,
            is_video=False,
        ),
        ModelInfo(
            model_id="grok-imagine-1.0",
            grok_model="grok-3",
            model_mode="MODEL_MODE_FAST",
            tier=Tier.BASIC,
            cost=Cost.HIGH,
            display_name="Grok Image",
            description="Image generation model",
            is_image=True,
            is_image_edit=False,
            is_video=False,
        ),
        ModelInfo(
            model_id="grok-imagine-1.0-edit",
            grok_model="imagine-image-edit",
            model_mode="MODEL_MODE_FAST",
            tier=Tier.BASIC,
            cost=Cost.HIGH,
            display_name="Grok Image Edit",
            description="Image edit model",
            is_image=False,
            is_image_edit=True,
            is_video=False,
        ),
        ModelInfo(
            model_id="grok-imagine-1.0-video",
            grok_model="grok-3",
            model_mode="MODEL_MODE_FAST",
            tier=Tier.BASIC,
            cost=Cost.HIGH,
            display_name="Grok Video",
            description="Video generation model",
            is_image=False,
            is_image_edit=False,
            is_video=True,
        ),
    ]

    # Index by model_id for O(1) lookup.
    _map = {m.model_id: m for m in MODELS}

    @classmethod
    def get(cls, model_id: str) -> Optional[ModelInfo]:
        """Return the model descriptor, or None if unknown."""
        return cls._map.get(model_id)

    @classmethod
    def list(cls) -> list[ModelInfo]:
        """Return all registered models."""
        return list(cls._map.values())

    @classmethod
    def valid(cls, model_id: str) -> bool:
        """Return True if model_id is registered."""
        return model_id in cls._map

    @classmethod
    def to_grok(cls, model_id: str) -> Tuple[str, str]:
        """Map a public model id to (grok_model, model_mode).

        Raises:
            ValidationException: If model_id is not registered.
        """
        model = cls.get(model_id)
        if not model:
            raise ValidationException(f"Invalid model ID: {model_id}")
        return model.grok_model, model.model_mode

    @classmethod
    def pool_for_model(cls, model_id: str) -> str:
        """Select the token pool for a model (SUPER tier → ssoSuper)."""
        model = cls.get(model_id)
        if model and model.tier == Tier.SUPER:
            return "ssoSuper"
        return "ssoBasic"

    @classmethod
    def pool_candidates_for_model(cls, model_id: str) -> List[str]:
        """Return candidate token pools in priority order."""
        model = cls.get(model_id)
        if model and model.tier == Tier.SUPER:
            return ["ssoSuper"]
        # Basic models prefer the basic pool but may fall back to super.
        return ["ssoBasic", "ssoSuper"]
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
__all__ = ["ModelService"]
|
app/services/grok/services/responses.py
ADDED
|
@@ -0,0 +1,824 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Responses API bridge service (OpenAI-compatible).
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import time
|
| 6 |
+
import uuid
|
| 7 |
+
from typing import Any, AsyncGenerator, Dict, List, Optional
|
| 8 |
+
|
| 9 |
+
import orjson
|
| 10 |
+
|
| 11 |
+
from app.services.grok.services.chat import ChatService
|
| 12 |
+
from app.services.grok.utils import process as proc_base
|
| 13 |
+
|
| 14 |
+
|
# Input item "type" values that carry a tool's result back to the model.
# _coerce_input_to_messages turns these into role="tool" chat messages.
_TOOL_OUTPUT_TYPES = {
    "tool_output",
    "function_call_output",
    "tool_call_output",
    "input_tool_output",
}

# Built-in Responses-API tool types this bridge emulates by rewriting them
# into plain function tools (see _normalize_tools_for_chat).
_BUILTIN_TOOL_TYPES = {
    "web_search",
    "web_search_2025_08_26",
    "file_search",
    "code_interpreter",
}
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _now_ts() -> int:
|
| 31 |
+
return int(time.time())
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _new_response_id() -> str:
|
| 35 |
+
return f"resp_{uuid.uuid4().hex[:24]}"
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _new_message_id() -> str:
|
| 39 |
+
return f"msg_{uuid.uuid4().hex[:24]}"
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _new_tool_call_id() -> str:
|
| 43 |
+
return f"call_{uuid.uuid4().hex[:24]}"
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _new_function_call_id() -> str:
|
| 47 |
+
return f"fc_{uuid.uuid4().hex[:24]}"
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _normalize_tool_choice(tool_choice: Any) -> Any:
|
| 51 |
+
if isinstance(tool_choice, dict):
|
| 52 |
+
t_type = tool_choice.get("type")
|
| 53 |
+
if t_type and t_type != "function":
|
| 54 |
+
return {"type": "function", "function": {"name": t_type}}
|
| 55 |
+
return tool_choice
|
| 56 |
+
|
| 57 |
+
|
def _normalize_tools_for_chat(tools: Optional[List[Dict[str, Any]]]) -> Optional[List[Dict[str, Any]]]:
    """Convert Responses-API tool declarations into Chat-Completions tools.

    Function tools pass through untouched. The builtin tools we emulate
    (``web_search*``, ``file_search``, ``code_interpreter``) are rewritten
    as function tools taking a single required string argument. Non-dict
    entries and unknown tool types are dropped.

    Returns ``None`` when nothing usable remains so callers can omit the
    field entirely.
    """
    if not tools:
        return None
    normalized: List[Dict[str, Any]] = []
    for tool in tools:
        if not isinstance(tool, dict):
            continue
        tool_type = tool.get("type")
        if tool_type == "function":
            normalized.append(tool)
            continue
        if tool_type not in _BUILTIN_TOOL_TYPES:
            continue
        # One emulation spec per builtin family (description + argument name)
        # instead of three copies of the same 15-line literal.
        if tool_type.startswith("web_search"):
            description = "Search the web for information and return results."
            arg_name = "query"
        elif tool_type == "file_search":
            description = "Search provided files for relevant information."
            arg_name = "query"
        else:  # code_interpreter (only remaining member of _BUILTIN_TOOL_TYPES)
            description = "Execute code to solve tasks and return results."
            arg_name = "code"
        normalized.append(
            {
                "type": "function",
                "function": {
                    "name": tool_type,
                    "description": description,
                    "parameters": {
                        "type": "object",
                        "properties": {arg_name: {"type": "string"}},
                        "required": [arg_name],
                    },
                },
            }
        )
    return normalized or None
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def _content_item_from_input(item: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
| 119 |
+
if not isinstance(item, dict):
|
| 120 |
+
return None
|
| 121 |
+
item_type = item.get("type")
|
| 122 |
+
|
| 123 |
+
if item_type in {"input_text", "text", "output_text"}:
|
| 124 |
+
text = item.get("text") or item.get("content") or ""
|
| 125 |
+
return {"type": "text", "text": text}
|
| 126 |
+
|
| 127 |
+
if item_type in {"input_image", "image", "image_url", "output_image"}:
|
| 128 |
+
image_url = item.get("image_url")
|
| 129 |
+
url = ""
|
| 130 |
+
detail = None
|
| 131 |
+
if isinstance(image_url, dict):
|
| 132 |
+
url = image_url.get("url") or ""
|
| 133 |
+
detail = image_url.get("detail")
|
| 134 |
+
elif isinstance(image_url, str):
|
| 135 |
+
url = image_url
|
| 136 |
+
else:
|
| 137 |
+
url = item.get("url") or item.get("image") or ""
|
| 138 |
+
|
| 139 |
+
if not url:
|
| 140 |
+
return None
|
| 141 |
+
image_payload = {"url": url}
|
| 142 |
+
if detail:
|
| 143 |
+
image_payload["detail"] = detail
|
| 144 |
+
return {"type": "image_url", "image_url": image_payload}
|
| 145 |
+
|
| 146 |
+
if item_type in {"input_file", "file"}:
|
| 147 |
+
file_data = item.get("file_data")
|
| 148 |
+
file_id = item.get("file_id")
|
| 149 |
+
if not file_data and isinstance(item.get("file"), dict):
|
| 150 |
+
file_data = item["file"].get("file_data")
|
| 151 |
+
file_id = item["file"].get("file_id")
|
| 152 |
+
file_payload: Dict[str, Any] = {}
|
| 153 |
+
if file_data:
|
| 154 |
+
file_payload["file_data"] = file_data
|
| 155 |
+
if file_id:
|
| 156 |
+
file_payload["file_id"] = file_id
|
| 157 |
+
if not file_payload:
|
| 158 |
+
return None
|
| 159 |
+
return {"type": "file", "file": file_payload}
|
| 160 |
+
|
| 161 |
+
if item_type in {"input_audio", "audio"}:
|
| 162 |
+
audio = item.get("audio") or {}
|
| 163 |
+
data = audio.get("data") or item.get("data")
|
| 164 |
+
if not data:
|
| 165 |
+
return None
|
| 166 |
+
return {"type": "input_audio", "input_audio": {"data": data}}
|
| 167 |
+
|
| 168 |
+
return None
|
| 169 |
+
|
| 170 |
+
|
def _message_from_item(item: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Extract a chat message from a message-shaped input item.

    Accepts either an explicit ``{"type": "message", ...}`` item or a plain
    dict carrying both "role" and "content" keys; returns ``None`` for
    anything else. Missing role defaults to "user".
    """
    if not isinstance(item, dict):
        return None
    explicit = item.get("type") == "message"
    if not explicit and not ("role" in item and "content" in item):
        return None
    return {
        "role": item.get("role") or "user",
        "content": _coerce_content(item.get("content")),
    }
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def _coerce_content(content: Any) -> Any:
|
| 187 |
+
if content is None:
|
| 188 |
+
return ""
|
| 189 |
+
if isinstance(content, str):
|
| 190 |
+
return content
|
| 191 |
+
if isinstance(content, dict):
|
| 192 |
+
content = [content]
|
| 193 |
+
if isinstance(content, list):
|
| 194 |
+
blocks: List[Dict[str, Any]] = []
|
| 195 |
+
for item in content:
|
| 196 |
+
if isinstance(item, dict) and item.get("type") in {"input_text", "output_text"}:
|
| 197 |
+
blocks.append({"type": "text", "text": item.get("text", "")})
|
| 198 |
+
continue
|
| 199 |
+
block = _content_item_from_input(item) if isinstance(item, dict) else None
|
| 200 |
+
if block:
|
| 201 |
+
blocks.append(block)
|
| 202 |
+
return blocks if blocks else ""
|
| 203 |
+
return str(content)
|
| 204 |
+
|
| 205 |
+
|
def _coerce_input_to_messages(input_value: Any) -> List[Dict[str, Any]]:
    """Turn the Responses-API ``input`` value into a chat ``messages`` list.

    Accepts a bare string (one user message), a single message/content dict,
    or a list mixing messages, tool outputs, content parts, and strings.
    Loose content parts are buffered and flushed as a single user message
    whenever a full message or tool output interrupts the run — so the
    relative order of the original items is preserved.
    """
    if input_value is None:
        return []
    if isinstance(input_value, str):
        return [{"role": "user", "content": input_value}]

    if isinstance(input_value, dict):
        # A single dict: try message shape first, then a lone content part.
        msg = _message_from_item(input_value)
        if msg:
            return [msg]
        content_item = _content_item_from_input(input_value)
        if content_item:
            return [{"role": "user", "content": [content_item]}]
        return []

    if not isinstance(input_value, list):
        # Last resort for scalars: stringify into one user message.
        return [{"role": "user", "content": str(input_value)}]

    messages: List[Dict[str, Any]] = []
    pending_blocks: List[Dict[str, Any]] = []

    def _flush_pending():
        # Emit buffered loose content parts as one user message.
        nonlocal pending_blocks
        if pending_blocks:
            messages.append({"role": "user", "content": pending_blocks})
            pending_blocks = []

    for item in input_value:
        if isinstance(item, dict):
            msg = _message_from_item(item)
            if msg:
                # A full message closes any run of loose parts first.
                _flush_pending()
                messages.append(msg)
                continue

            item_type = item.get("type")
            if item_type in _TOOL_OUTPUT_TYPES:
                # Tool results become role="tool" messages; synthesize a
                # call id when the item carries none.
                _flush_pending()
                call_id = (
                    item.get("call_id")
                    or item.get("tool_call_id")
                    or item.get("id")
                    or _new_tool_call_id()
                )
                output = item.get("output") or item.get("content") or ""
                messages.append({"role": "tool", "tool_call_id": call_id, "content": output})
                continue

            block = _content_item_from_input(item)
            if block:
                pending_blocks.append(block)
                continue

        if isinstance(item, str):
            # Bare strings in the list are treated as loose text parts.
            pending_blocks.append({"type": "text", "text": item})

    _flush_pending()
    return messages
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def _build_output_message(
|
| 267 |
+
text: str,
|
| 268 |
+
*,
|
| 269 |
+
message_id: Optional[str] = None,
|
| 270 |
+
status: str = "completed",
|
| 271 |
+
) -> Dict[str, Any]:
|
| 272 |
+
message_id = message_id or _new_message_id()
|
| 273 |
+
return {
|
| 274 |
+
"id": message_id,
|
| 275 |
+
"type": "message",
|
| 276 |
+
"role": "assistant",
|
| 277 |
+
"status": status,
|
| 278 |
+
"content": [
|
| 279 |
+
{
|
| 280 |
+
"type": "output_text",
|
| 281 |
+
"text": text,
|
| 282 |
+
"annotations": [],
|
| 283 |
+
}
|
| 284 |
+
],
|
| 285 |
+
}
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def _build_output_tool_call(
|
| 289 |
+
tool_call: Dict[str, Any],
|
| 290 |
+
*,
|
| 291 |
+
item_id: Optional[str] = None,
|
| 292 |
+
status: str = "completed",
|
| 293 |
+
) -> Dict[str, Any]:
|
| 294 |
+
fn = tool_call.get("function") or {}
|
| 295 |
+
call_id = tool_call.get("id") or _new_tool_call_id()
|
| 296 |
+
item_id = item_id or _new_function_call_id()
|
| 297 |
+
return {
|
| 298 |
+
"id": item_id,
|
| 299 |
+
"type": "function_call",
|
| 300 |
+
"status": status,
|
| 301 |
+
"call_id": call_id,
|
| 302 |
+
"name": fn.get("name"),
|
| 303 |
+
"arguments": fn.get("arguments"),
|
| 304 |
+
}
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
def _build_response_object(
|
| 308 |
+
*,
|
| 309 |
+
model: str,
|
| 310 |
+
output_text: Optional[str] = None,
|
| 311 |
+
tool_calls: Optional[List[Dict[str, Any]]] = None,
|
| 312 |
+
response_id: Optional[str] = None,
|
| 313 |
+
usage: Optional[Dict[str, Any]] = None,
|
| 314 |
+
created_at: Optional[int] = None,
|
| 315 |
+
completed_at: Optional[int] = None,
|
| 316 |
+
status: str = "completed",
|
| 317 |
+
instructions: Optional[str] = None,
|
| 318 |
+
max_output_tokens: Optional[int] = None,
|
| 319 |
+
parallel_tool_calls: Optional[bool] = None,
|
| 320 |
+
previous_response_id: Optional[str] = None,
|
| 321 |
+
reasoning_effort: Optional[str] = None,
|
| 322 |
+
store: Optional[bool] = None,
|
| 323 |
+
temperature: Optional[float] = None,
|
| 324 |
+
tool_choice: Optional[Any] = None,
|
| 325 |
+
tools: Optional[List[Dict[str, Any]]] = None,
|
| 326 |
+
top_p: Optional[float] = None,
|
| 327 |
+
truncation: Optional[str] = None,
|
| 328 |
+
user: Optional[str] = None,
|
| 329 |
+
metadata: Optional[Dict[str, Any]] = None,
|
| 330 |
+
) -> Dict[str, Any]:
|
| 331 |
+
response_id = response_id or _new_response_id()
|
| 332 |
+
created_at = created_at or _now_ts()
|
| 333 |
+
if status == "completed" and completed_at is None:
|
| 334 |
+
completed_at = _now_ts()
|
| 335 |
+
|
| 336 |
+
output: List[Dict[str, Any]] = []
|
| 337 |
+
if output_text is not None:
|
| 338 |
+
output.append(_build_output_message(output_text))
|
| 339 |
+
|
| 340 |
+
if tool_calls:
|
| 341 |
+
for call in tool_calls:
|
| 342 |
+
output.append(_build_output_tool_call(call))
|
| 343 |
+
|
| 344 |
+
return {
|
| 345 |
+
"id": response_id,
|
| 346 |
+
"object": "response",
|
| 347 |
+
"created_at": created_at,
|
| 348 |
+
"completed_at": completed_at,
|
| 349 |
+
"status": status,
|
| 350 |
+
"error": None,
|
| 351 |
+
"incomplete_details": None,
|
| 352 |
+
"instructions": instructions,
|
| 353 |
+
"max_output_tokens": max_output_tokens,
|
| 354 |
+
"model": model,
|
| 355 |
+
"output": output,
|
| 356 |
+
"parallel_tool_calls": True if parallel_tool_calls is None else parallel_tool_calls,
|
| 357 |
+
"previous_response_id": previous_response_id,
|
| 358 |
+
"reasoning": {"effort": reasoning_effort, "summary": None},
|
| 359 |
+
"store": True if store is None else store,
|
| 360 |
+
"temperature": 1.0 if temperature is None else temperature,
|
| 361 |
+
"text": {"format": {"type": "text"}},
|
| 362 |
+
"tool_choice": tool_choice or "auto",
|
| 363 |
+
"tools": tools or [],
|
| 364 |
+
"top_p": 1.0 if top_p is None else top_p,
|
| 365 |
+
"truncation": truncation or "disabled",
|
| 366 |
+
"usage": usage,
|
| 367 |
+
"user": user,
|
| 368 |
+
"metadata": metadata or {},
|
| 369 |
+
}
|
| 370 |
+
|
| 371 |
+
|
class ResponseStreamAdapter:
    """Re-emit a Chat-Completions delta stream as Responses-API SSE events.

    The chat stream is consumed elsewhere (ResponsesService); this class only
    holds the per-response state (accumulated text, tool-call fragments,
    output indices) and formats each Responses event as an SSE frame
    (``event: <type>\\ndata: <json>\\n\\n``).

    Expected event order, driven by the caller: created / in_progress, then
    interleaved text deltas and tool-call deltas, then the *_done events,
    then response.completed.
    """

    def __init__(
        self,
        *,
        model: str,
        response_id: str,
        created_at: int,
        instructions: Optional[str],
        max_output_tokens: Optional[int],
        parallel_tool_calls: Optional[bool],
        previous_response_id: Optional[str],
        reasoning_effort: Optional[str],
        store: Optional[bool],
        temperature: Optional[float],
        tool_choice: Optional[Any],
        tools: Optional[List[Dict[str, Any]]],
        top_p: Optional[float],
        truncation: Optional[str],
        user: Optional[str],
        metadata: Optional[Dict[str, Any]],
    ):
        # Request echo fields, replayed verbatim into every response payload.
        self.model = model
        self.response_id = response_id
        self.created_at = created_at
        self.instructions = instructions
        self.max_output_tokens = max_output_tokens
        self.parallel_tool_calls = parallel_tool_calls
        self.previous_response_id = previous_response_id
        self.reasoning_effort = reasoning_effort
        self.store = store
        self.temperature = temperature
        self.tool_choice = tool_choice
        self.tools = tools
        self.top_p = top_p
        self.truncation = truncation
        self.user = user
        self.metadata = metadata

        # Streaming state.
        self.output_text_parts: List[str] = []          # accumulated text deltas
        self.tool_calls_by_index: Dict[int, Dict[str, Any]] = {}  # chat-shaped calls for the final payload
        self.tool_items: Dict[int, Dict[str, Any]] = {}  # per-call SSE item bookkeeping
        self.next_output_index = 0                       # next free output[] slot
        self.content_index = 0                           # single text part -> always 0
        self.message_id = _new_message_id()
        self.message_started = False
        self.message_output_index: Optional[int] = None

    def _event(self, event_type: str, payload: Dict[str, Any]) -> str:
        """Format one SSE frame: event line plus JSON data line."""
        return f"event: {event_type}\ndata: {orjson.dumps(payload).decode()}\n\n"

    def _response_payload(self, *, status: str, output_text: Optional[str], usage: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Build the full response envelope embedded in lifecycle events.

        Tool calls are included only in the final "completed" snapshot,
        ordered by their stream index.
        """
        tool_calls = None
        if status == "completed" and self.tool_calls_by_index:
            tool_calls = [
                self.tool_calls_by_index[idx]
                for idx in sorted(self.tool_calls_by_index.keys())
            ]
        return _build_response_object(
            model=self.model,
            output_text=output_text,
            tool_calls=tool_calls,
            response_id=self.response_id,
            usage=usage,
            created_at=self.created_at,
            status=status,
            instructions=self.instructions,
            max_output_tokens=self.max_output_tokens,
            parallel_tool_calls=self.parallel_tool_calls,
            previous_response_id=self.previous_response_id,
            reasoning_effort=self.reasoning_effort,
            store=self.store,
            temperature=self.temperature,
            tool_choice=self.tool_choice,
            tools=self.tools,
            top_p=self.top_p,
            truncation=self.truncation,
            user=self.user,
            metadata=self.metadata,
        )

    def _alloc_output_index(self) -> int:
        """Hand out the next output[] position (message and tool items share one sequence)."""
        idx = self.next_output_index
        self.next_output_index += 1
        return idx

    def created_event(self) -> str:
        """``response.created`` with an in_progress snapshot."""
        payload = {
            "type": "response.created",
            "response": self._response_payload(status="in_progress", output_text=None, usage=None),
        }
        return self._event("response.created", payload)

    def in_progress_event(self) -> str:
        """``response.in_progress`` with an in_progress snapshot."""
        payload = {
            "type": "response.in_progress",
            "response": self._response_payload(status="in_progress", output_text=None, usage=None),
        }
        return self._event("response.in_progress", payload)

    def ensure_message_started(self) -> List[str]:
        """Open the assistant message item lazily, on the first text delta.

        Returns the output_item.added + content_part.added events the first
        time; an empty list afterwards (idempotent).
        """
        if self.message_started:
            return []
        self.message_started = True
        self.message_output_index = self._alloc_output_index()
        item = _build_output_message("", message_id=self.message_id, status="in_progress")
        # The part itself is announced by content_part.added below.
        item["content"] = []
        events = [
            self._event(
                "response.output_item.added",
                {
                    "type": "response.output_item.added",
                    "response_id": self.response_id,
                    "output_index": self.message_output_index,
                    "item": item,
                },
            ),
            self._event(
                "response.content_part.added",
                {
                    "type": "response.content_part.added",
                    "response_id": self.response_id,
                    "item_id": self.message_id,
                    "output_index": self.message_output_index,
                    "content_index": self.content_index,
                    "part": {"type": "output_text", "text": "", "annotations": []},
                },
            ),
        ]
        return events

    def output_delta_event(self, delta: str) -> str:
        """``response.output_text.delta`` for one text fragment.

        Callers must have invoked ensure_message_started() first so that
        message_output_index is set.
        """
        return self._event(
            "response.output_text.delta",
            {
                "type": "response.output_text.delta",
                "response_id": self.response_id,
                "item_id": self.message_id,
                "output_index": self.message_output_index,
                "content_index": self.content_index,
                "delta": delta,
            },
        )

    def output_done_events(self, text: str) -> List[str]:
        """Close the text message: text.done, part.done, item.done (in order).

        Returns [] when no message item was ever opened.
        """
        if self.message_output_index is None:
            return []
        return [
            self._event(
                "response.output_text.done",
                {
                    "type": "response.output_text.done",
                    "response_id": self.response_id,
                    "item_id": self.message_id,
                    "output_index": self.message_output_index,
                    "content_index": self.content_index,
                    "text": text,
                },
            ),
            self._event(
                "response.content_part.done",
                {
                    "type": "response.content_part.done",
                    "response_id": self.response_id,
                    "item_id": self.message_id,
                    "output_index": self.message_output_index,
                    "content_index": self.content_index,
                    "part": {"type": "output_text", "text": text, "annotations": []},
                },
            ),
            self._event(
                "response.output_item.done",
                {
                    "type": "response.output_item.done",
                    "response_id": self.response_id,
                    "output_index": self.message_output_index,
                    "item": _build_output_message(
                        text, message_id=self.message_id, status="completed"
                    ),
                },
            ),
        ]

    def ensure_tool_item(self, tool_index: int, call_id: str, name: Optional[str]) -> List[str]:
        """Open the function_call item for *tool_index* lazily.

        On first sight returns the output_item.added event; later calls only
        backfill a name that arrived in a subsequent delta, returning [].
        """
        if tool_index in self.tool_items:
            item = self.tool_items[tool_index]
            if name and not item.get("name"):
                item["name"] = name
            return []
        output_index = self._alloc_output_index()
        item_id = _new_function_call_id()
        self.tool_items[tool_index] = {
            "item_id": item_id,
            "output_index": output_index,
            "call_id": call_id,
            "name": name,
            "arguments": "",
        }
        tool_item = _build_output_tool_call(
            {"id": call_id, "function": {"name": name, "arguments": ""}},
            item_id=item_id,
            status="in_progress",
        )
        return [
            self._event(
                "response.output_item.added",
                {
                    "type": "response.output_item.added",
                    "response_id": self.response_id,
                    "output_index": output_index,
                    "item": tool_item,
                },
            )
        ]

    def tool_arguments_delta_event(self, tool_index: int, delta: str) -> Optional[str]:
        """Append an arguments fragment and emit function_call_arguments.delta.

        Returns None for an empty delta or an unknown tool_index.
        """
        if not delta:
            return None
        item = self.tool_items.get(tool_index)
        if not item:
            return None
        item["arguments"] += delta
        return self._event(
            "response.function_call_arguments.delta",
            {
                "type": "response.function_call_arguments.delta",
                "response_id": self.response_id,
                "item_id": item["item_id"],
                "output_index": item["output_index"],
                "delta": delta,
            },
        )

    def tool_arguments_done_events(self) -> List[str]:
        """Close every tool item: arguments.done + output_item.done per call.

        Items are emitted in output_index order (the order they were opened).
        """
        events: List[str] = []
        for tool_index, item in sorted(
            self.tool_items.items(), key=lambda kv: kv[1]["output_index"]
        ):
            events.append(
                self._event(
                    "response.function_call_arguments.done",
                    {
                        "type": "response.function_call_arguments.done",
                        "response_id": self.response_id,
                        "item_id": item["item_id"],
                        "output_index": item["output_index"],
                        "arguments": item["arguments"],
                    },
                )
            )
            tool_item = _build_output_tool_call(
                {
                    "id": item["call_id"],
                    "function": {"name": item.get("name"), "arguments": item["arguments"]},
                },
                item_id=item["item_id"],
                status="completed",
            )
            events.append(
                self._event(
                    "response.output_item.done",
                    {
                        "type": "response.output_item.done",
                        "response_id": self.response_id,
                        "output_index": item["output_index"],
                        "item": tool_item,
                    },
                )
            )
        return events

    def record_tool_call(self, tool_index: int, call_id: str, name: Optional[str], arguments_delta: str) -> None:
        """Fold one chat tool-call delta into the final-payload accumulator.

        The first delta for an index creates the chat-shaped call; later
        deltas backfill the name (first one wins) and extend the arguments.
        """
        tool_call = self.tool_calls_by_index.get(tool_index)
        if not tool_call:
            tool_call = {
                "id": call_id or _new_tool_call_id(),
                "type": "function",
                "function": {"name": name, "arguments": ""},
            }
            self.tool_calls_by_index[tool_index] = tool_call
        if name and not tool_call["function"].get("name"):
            tool_call["function"]["name"] = name
        if arguments_delta:
            tool_call["function"]["arguments"] += arguments_delta

    def completed_event(self, usage: Optional[Dict[str, Any]] = None) -> str:
        """``response.completed`` with the final snapshot.

        Text is joined from the accumulated deltas only when a message item
        was actually opened; usage falls back to an all-zero stub.
        """
        response = self._response_payload(
            status="completed",
            output_text="".join(self.output_text_parts) if self.message_started else None,
            usage=usage
            or {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0},
        )
        payload = {"type": "response.completed", "response": response}
        return self._event("response.completed", payload)
| 665 |
+
|
| 666 |
+
|
class ResponsesService:
    """Bridge the OpenAI Responses API onto the Chat-Completions service."""

    @staticmethod
    async def create(
        *,
        model: str,
        input_value: Any,
        instructions: Optional[str] = None,
        stream: bool = False,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        tool_choice: Any = None,
        parallel_tool_calls: Optional[bool] = None,
        reasoning_effort: Optional[str] = None,
        max_output_tokens: Optional[int] = None,
        metadata: Optional[Dict[str, Any]] = None,
        user: Optional[str] = None,
        store: Optional[bool] = None,
        previous_response_id: Optional[str] = None,
        truncation: Optional[str] = None,
    ) -> Any:
        """Create a response.

        Translates the Responses-API request into chat form, delegates to
        ChatService.completions, and converts the result back: a response
        envelope dict when ``stream`` is False, otherwise an async generator
        of Responses-API SSE frames.

        Raises ValueError when no messages could be derived from
        *input_value* or when ChatService returns a result of the wrong
        shape for the requested mode.

        NOTE(review): max_output_tokens/metadata/user/store/
        previous_response_id/truncation are echoed back in the envelope but
        NOT forwarded to ChatService — presumably unsupported upstream;
        confirm against ChatService.completions.
        """
        messages = _coerce_input_to_messages(input_value)
        if instructions:
            # Responses-API "instructions" maps to a leading system message.
            messages = [{"role": "system", "content": instructions}] + messages

        if not messages:
            raise ValueError("input is required")

        normalized_tools = _normalize_tools_for_chat(tools)
        normalized_tool_choice = _normalize_tool_choice(tool_choice)

        # Only forward knobs the caller actually set.
        chat_kwargs: Dict[str, Any] = {
            "model": model,
            "messages": messages,
            "stream": stream,
        }
        if temperature is not None:
            chat_kwargs["temperature"] = temperature
        if top_p is not None:
            chat_kwargs["top_p"] = top_p
        if normalized_tools is not None:
            chat_kwargs["tools"] = normalized_tools
        if normalized_tool_choice is not None:
            chat_kwargs["tool_choice"] = normalized_tool_choice
        if parallel_tool_calls is not None:
            chat_kwargs["parallel_tool_calls"] = parallel_tool_calls
        if reasoning_effort is not None:
            chat_kwargs["reasoning_effort"] = reasoning_effort

        result = await ChatService.completions(**chat_kwargs)

        if not stream:
            # Non-stream: expect a chat.completion dict; repackage its first
            # choice as a completed Responses envelope.
            if not isinstance(result, dict):
                raise ValueError("Unexpected stream response for non-stream request")
            choice = (result.get("choices") or [{}])[0]
            message = choice.get("message") or {}
            content = message.get("content") or ""
            tool_calls = message.get("tool_calls")
            return _build_response_object(
                model=model,
                output_text=content,
                tool_calls=tool_calls,
                usage=result.get("usage")
                or {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0},
                status="completed",
                instructions=instructions,
                max_output_tokens=max_output_tokens,
                parallel_tool_calls=parallel_tool_calls,
                previous_response_id=previous_response_id,
                reasoning_effort=reasoning_effort,
                store=store,
                temperature=temperature,
                tool_choice=tool_choice,
                tools=tools,
                top_p=top_p,
                truncation=truncation,
                user=user,
                metadata=metadata,
            )

        # Stream: expect an async iterator of chat-completion SSE chunks.
        if not hasattr(result, "__aiter__"):
            raise ValueError("Unexpected non-stream response for stream request")

        created_at = _now_ts()
        response_id = _new_response_id()
        adapter = ResponseStreamAdapter(
            model=model,
            response_id=response_id,
            created_at=created_at,
            instructions=instructions,
            max_output_tokens=max_output_tokens,
            parallel_tool_calls=parallel_tool_calls,
            previous_response_id=previous_response_id,
            reasoning_effort=reasoning_effort,
            store=store,
            temperature=temperature,
            tool_choice=tool_choice,
            tools=tools,
            top_p=top_p,
            truncation=truncation,
            user=user,
            metadata=metadata,
        )

        async def _stream() -> AsyncGenerator[str, None]:
            # Lifecycle prologue.
            yield adapter.created_event()
            yield adapter.in_progress_event()
            async for chunk in result:
                # _normalize_line presumably strips SSE framing down to the
                # JSON payload — verify in app.services.grok.utils.process.
                line = proc_base._normalize_line(chunk)
                if not line:
                    continue
                try:
                    data = orjson.loads(line)
                except orjson.JSONDecodeError:
                    # Non-JSON keep-alives / sentinels are skipped.
                    continue

                if data.get("object") == "chat.completion.chunk":
                    delta = (data.get("choices") or [{}])[0].get("delta") or {}
                    if "content" in delta and delta["content"]:
                        # Lazily open the message item on the first text delta.
                        for event in adapter.ensure_message_started():
                            yield event
                        adapter.output_text_parts.append(delta["content"])
                        yield adapter.output_delta_event(delta["content"])
                    tool_calls = delta.get("tool_calls")
                    if isinstance(tool_calls, list):
                        for tool in tool_calls:
                            if not isinstance(tool, dict):
                                continue
                            tool_index = tool.get("index", 0)
                            call_id = tool.get("id") or _new_tool_call_id()
                            fn = tool.get("function") or {}
                            name = fn.get("name")
                            args_delta = fn.get("arguments") or ""
                            # Accumulate for the final payload, then emit the
                            # item-open (first time) and the arguments delta.
                            adapter.record_tool_call(
                                tool_index, call_id, name, args_delta
                            )
                            for event in adapter.ensure_tool_item(
                                tool_index, call_id, name
                            ):
                                yield event
                            delta_event = adapter.tool_arguments_delta_event(
                                tool_index, args_delta
                            )
                            if delta_event:
                                yield delta_event

            # Epilogue: close the message item (if any), the tool items, and
            # finish with response.completed.
            full_text = "".join(adapter.output_text_parts)
            if full_text and adapter.message_started:
                for event in adapter.output_done_events(full_text):
                    yield event
            for event in adapter.tool_arguments_done_events():
                yield event
            yield adapter.completed_event()

        return _stream()
| 822 |
+
|
| 823 |
+
|
| 824 |
+
__all__ = ["ResponsesService"]
|
app/services/grok/services/video.py
ADDED
|
@@ -0,0 +1,688 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Grok video generation service.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import asyncio
|
| 6 |
+
import uuid
|
| 7 |
+
import re
|
| 8 |
+
from typing import Any, AsyncGenerator, AsyncIterable, Optional
|
| 9 |
+
|
| 10 |
+
import orjson
|
| 11 |
+
from curl_cffi.requests.errors import RequestsError
|
| 12 |
+
|
| 13 |
+
from app.core.logger import logger
|
| 14 |
+
from app.core.config import get_config
|
| 15 |
+
from app.core.exceptions import (
|
| 16 |
+
UpstreamException,
|
| 17 |
+
AppException,
|
| 18 |
+
ValidationException,
|
| 19 |
+
ErrorType,
|
| 20 |
+
StreamIdleTimeoutError,
|
| 21 |
+
)
|
| 22 |
+
from app.services.grok.services.model import ModelService
|
| 23 |
+
from app.services.token import get_token_manager, EffortType
|
| 24 |
+
from app.services.grok.utils.stream import wrap_stream_with_usage
|
| 25 |
+
from app.services.grok.utils.process import (
|
| 26 |
+
BaseProcessor,
|
| 27 |
+
_with_idle_timeout,
|
| 28 |
+
_normalize_line,
|
| 29 |
+
_is_http2_error,
|
| 30 |
+
)
|
| 31 |
+
from app.services.grok.utils.retry import rate_limited
|
| 32 |
+
from app.services.reverse.app_chat import AppChatReverse
|
| 33 |
+
from app.services.reverse.media_post import MediaPostReverse
|
| 34 |
+
from app.services.reverse.video_upscale import VideoUpscaleReverse
|
| 35 |
+
from app.services.reverse.utils.session import ResettableSession
|
| 36 |
+
from app.services.token.manager import BASIC_POOL_NAME
|
| 37 |
+
|
| 38 |
+
# Lazily-built semaphore bounding concurrent reverse-API calls for the video
# service, together with the concurrency value it was created with
# (0 means "not created yet").
_VIDEO_SEMAPHORE = None
_VIDEO_SEM_VALUE = 0


def _get_video_semaphore() -> asyncio.Semaphore:
    """Return the shared semaphore limiting reverse-API concurrency (video service).

    The semaphore is (re)created whenever the configured ``video.concurrent``
    value differs from the one used to build the current instance, so config
    changes take effect on the next call.
    """
    global _VIDEO_SEMAPHORE, _VIDEO_SEM_VALUE
    configured = max(1, int(get_config("video.concurrent")))
    if configured != _VIDEO_SEM_VALUE:
        _VIDEO_SEM_VALUE = configured
        _VIDEO_SEMAPHORE = asyncio.Semaphore(configured)
    return _VIDEO_SEMAPHORE
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def _new_session() -> ResettableSession:
    """Create a fresh HTTP session, impersonating the configured browser if one is set."""
    impersonation = get_config("proxy.browser")
    return (
        ResettableSession(impersonate=impersonation)
        if impersonation
        else ResettableSession()
    )
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class VideoService:
    """Video generation service.

    Wraps the reverse-engineered Grok media-post and app-chat endpoints to
    generate videos from a text prompt (optionally seeded by a reference
    image) and exposes an OpenAI-compatible ``completions`` entrypoint.
    """

    # Maps the public ``preset`` name to the upstream ``--mode`` flag appended
    # to the prompt. Unknown presets fall back to ``--mode=custom``.
    _MODE_FLAGS = {
        "fun": "--mode=extremely-crazy",
        "normal": "--mode=normal",
        "spicy": "--mode=extremely-spicy-or-crazy",
    }

    def __init__(self):
        # Kept for interface compatibility; per-request sessions carry their
        # own timeouts.
        self.timeout = None

    async def create_post(
        self,
        token: str,
        prompt: str,
        media_type: str = "MEDIA_POST_TYPE_VIDEO",
        media_url: Optional[str] = None,
    ) -> str:
        """Create a media post upstream and return its post ID.

        Args:
            token: Grok SSO token used for the upstream call.
            prompt: Generation prompt (sent only for video posts).
            media_type: Upstream post type constant.
            media_url: Source media URL; required for image posts.

        Raises:
            ValidationException: an image post was requested without ``media_url``.
            UpstreamException: the upstream call failed or returned no post ID.
        """
        try:
            if media_type == "MEDIA_POST_TYPE_IMAGE" and not media_url:
                raise ValidationException("media_url is required for image posts")

            # Video posts carry the prompt; image posts carry only the media URL.
            prompt_value = prompt if media_type == "MEDIA_POST_TYPE_VIDEO" else ""
            media_value = media_url or ""

            async with _new_session() as session:
                async with _get_video_semaphore():
                    response = await MediaPostReverse.request(
                        session,
                        token,
                        media_type,
                        media_value,
                        prompt=prompt_value,
                    )

            post_id = response.json().get("post", {}).get("id", "")
            if not post_id:
                raise UpstreamException("No post ID in response")

            logger.info(f"Media post created: {post_id} (type={media_type})")
            return post_id

        except AppException:
            raise
        except Exception as e:
            logger.error(f"Create post error: {e}")
            raise UpstreamException(f"Create post error: {str(e)}")

    async def create_image_post(self, token: str, image_url: str) -> str:
        """Create an image post (used as the seed for image-to-video) and return its ID."""
        return await self.create_post(
            token, prompt="", media_type="MEDIA_POST_TYPE_IMAGE", media_url=image_url
        )

    @staticmethod
    def _video_gen_config(
        aspect_ratio: str, post_id: str, resolution_name: str, video_length: int
    ) -> dict:
        """Build the ``model_config_override`` payload for a video generation request."""
        return {
            "modelMap": {
                "videoGenModelConfig": {
                    "aspectRatio": aspect_ratio,
                    "parentPostId": post_id,
                    "resolutionName": resolution_name,
                    "videoLength": video_length,
                }
            }
        }

    def _video_stream(
        self,
        token: str,
        message: str,
        model_config_override: dict,
        post_id: str,
    ) -> AsyncGenerator[bytes, None]:
        """Shared streaming body for text- and image-seeded generation.

        Opens a session, issues the app-chat request with the videoGen tool
        enabled, and yields the raw upstream stream lines. The session is
        always closed when the generator finishes (success, error, or
        client-side cancellation) — the previous implementation closed it
        only on the error path and leaked it on normal completion.
        """

        async def _stream():
            session = _new_session()
            try:
                async with _get_video_semaphore():
                    stream_response = await AppChatReverse.request(
                        session,
                        token,
                        message=message,
                        model="grok-3",
                        tool_overrides={"videoGen": True},
                        model_config_override=model_config_override,
                    )
                    logger.info(f"Video generation started: post_id={post_id}")
                    async for line in stream_response:
                        yield line
            except Exception as e:
                logger.error(f"Video generation error: {e}")
                if isinstance(e, AppException):
                    raise
                raise UpstreamException(f"Video generation error: {str(e)}")
            finally:
                # Best-effort release; never mask the in-flight exception.
                try:
                    await session.close()
                except Exception:
                    pass

        return _stream()

    async def generate(
        self,
        token: str,
        prompt: str,
        aspect_ratio: str = "3:2",
        video_length: int = 6,
        resolution_name: str = "480p",
        preset: str = "normal",
    ) -> AsyncGenerator[bytes, None]:
        """Generate a video from a text prompt; returns the raw upstream stream."""
        logger.info(
            f"Video generation: prompt='{prompt[:50]}...', ratio={aspect_ratio}, length={video_length}s, preset={preset}"
        )
        post_id = await self.create_post(token, prompt)
        message = f"{prompt} {self._MODE_FLAGS.get(preset, '--mode=custom')}"
        model_config_override = self._video_gen_config(
            aspect_ratio, post_id, resolution_name, video_length
        )
        return self._video_stream(token, message, model_config_override, post_id)

    async def generate_from_image(
        self,
        token: str,
        prompt: str,
        image_url: str,
        aspect_ratio: str = "3:2",
        video_length: int = 6,
        resolution: str = "480p",
        preset: str = "normal",
    ) -> AsyncGenerator[bytes, None]:
        """Generate a video seeded from a reference image; returns the raw upstream stream."""
        logger.info(
            f"Image to video: prompt='{prompt[:50]}...', image={image_url[:80]}"
        )
        post_id = await self.create_image_post(token, image_url)
        message = f"{prompt} {self._MODE_FLAGS.get(preset, '--mode=custom')}"
        model_config_override = self._video_gen_config(
            aspect_ratio, post_id, resolution, video_length
        )
        return self._video_stream(token, message, model_config_override, post_id)

    @staticmethod
    def _no_tokens_error() -> AppException:
        """Build the 429 error raised when no token can serve the request."""
        return AppException(
            message="No available tokens. Please try again later.",
            error_type=ErrorType.RATE_LIMIT.value,
            code="rate_limit_exceeded",
            status_code=429,
        )

    @staticmethod
    async def completions(
        model: str,
        messages: list,
        stream: bool = None,
        reasoning_effort: str | None = None,
        aspect_ratio: str = "3:2",
        video_length: int = 6,
        resolution: str = "480p",
        preset: str = "normal",
    ):
        """OpenAI-compatible video generation entrypoint.

        Selects a token via the token manager, uploads an optional reference
        image, starts generation, and returns either a wrapped SSE stream or
        a collected chat-completion dict. Retries with a fresh token when the
        upstream reports rate limiting.
        """
        # Get token via intelligent routing.
        token_mgr = await get_token_manager()
        await token_mgr.reload_if_stale()

        max_token_retries = int(get_config("retry.max_retry"))
        last_error: Exception | None = None

        # reasoning_effort == "none" suppresses <think> output; otherwise fall
        # back to the global app.thinking setting.
        if reasoning_effort is None:
            show_think = get_config("app.thinking")
        else:
            show_think = reasoning_effort != "none"
        is_stream = stream if stream is not None else get_config("app.stream")

        # Extract content. Imported lazily to avoid a circular import with chat.
        from app.services.grok.services.chat import MessageExtractor
        from app.services.grok.utils.upload import UploadService

        # file_attachments are not used by video generation.
        prompt, file_attachments, image_attachments = MessageExtractor.extract(messages)

        for attempt in range(max_token_retries):
            # Select token based on video requirements and pool candidates.
            pool_candidates = ModelService.pool_candidates_for_model(model)
            token_info = token_mgr.get_token_for_video(
                resolution=resolution,
                video_length=video_length,
                pool_candidates=pool_candidates,
            )

            if not token_info:
                if last_error:
                    raise last_error
                raise VideoService._no_tokens_error()

            # Extract the bare token string from TokenInfo.
            token = token_info.token
            if token.startswith("sso="):
                token = token[4:]
            pool_name = token_mgr.get_pool_name_for_token(token)
            # Basic-pool tokens cannot generate 720p natively; upscale afterwards.
            should_upscale = resolution == "720p" and pool_name == BASIC_POOL_NAME

            try:
                # Upload an optional reference image (only the first is used).
                image_url = None
                if image_attachments:
                    upload_service = UploadService()
                    try:
                        if len(image_attachments) > 1:
                            logger.info(
                                "Video generation supports a single reference image; using the first one."
                            )
                        attach_data = image_attachments[0]
                        _, file_uri = await upload_service.upload_file(
                            attach_data, token
                        )
                        image_url = f"https://assets.grok.com/{file_uri}"
                        logger.info(f"Image uploaded for video: {image_url}")
                    finally:
                        await upload_service.close()

                # Start generation (image-seeded when a reference was uploaded).
                service = VideoService()
                if image_url:
                    response = await service.generate_from_image(
                        token,
                        prompt,
                        image_url,
                        aspect_ratio,
                        video_length,
                        resolution,
                        preset,
                    )
                else:
                    response = await service.generate(
                        token,
                        prompt,
                        aspect_ratio,
                        video_length,
                        resolution,
                        preset,
                    )

                # Process the response: stream SSE or collect a final dict.
                if is_stream:
                    processor = VideoStreamProcessor(
                        model,
                        token,
                        show_think,
                        upscale_on_finish=should_upscale,
                    )
                    return wrap_stream_with_usage(
                        processor.process(response), token_mgr, token, model
                    )

                result = await VideoCollectProcessor(
                    model, token, upscale_on_finish=should_upscale
                ).process(response)
                try:
                    # Record usage; the non-stream path has no usage wrapper.
                    model_info = ModelService.get(model)
                    effort = (
                        EffortType.HIGH
                        if (model_info and model_info.cost.value == "high")
                        else EffortType.LOW
                    )
                    await token_mgr.consume(token, effort)
                    logger.debug(
                        f"Video completed, recorded usage (effort={effort.value})"
                    )
                except Exception as e:
                    # Usage accounting is best-effort; never fail the request.
                    logger.warning(f"Failed to record video usage: {e}")
                return result

            except UpstreamException as e:
                last_error = e
                if rate_limited(e):
                    # Park this token and retry with another one.
                    await token_mgr.mark_rate_limited(token)
                    logger.warning(
                        f"Token {token[:10]}... rate limited (429), "
                        f"trying next token (attempt {attempt + 1}/{max_token_retries})"
                    )
                    continue
                raise

        # All retries exhausted.
        if last_error:
            raise last_error
        raise VideoService._no_tokens_error()
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
class VideoStreamProcessor(BaseProcessor):
    """Video stream response processor.

    Translates the upstream Grok video-generation stream into
    OpenAI-style ``chat.completion.chunk`` SSE events, emitting progress
    text (optionally wrapped in ``<think>`` tags) and the rendered video
    once generation reaches 100%.
    """

    def __init__(
        self,
        model: str,
        token: str = "",
        show_think: bool = None,
        upscale_on_finish: bool = False,
    ):
        super().__init__(model, token)
        # Upstream responseId, echoed as the chunk id once known.
        self.response_id: Optional[str] = None
        # Whether a <think> block is currently open in the output.
        self.think_opened: bool = False
        # Whether the initial role=assistant delta has been emitted.
        self.role_sent: bool = False

        self.show_think = bool(show_think)
        # Set by the caller for 720p requests served from the basic pool,
        # which need a client-triggered upscale after generation.
        self.upscale_on_finish = bool(upscale_on_finish)

    @staticmethod
    def _extract_video_id(video_url: str) -> str:
        """Pull the UUID-like video id out of a generated-video URL.

        Tries the ``.../generated/<id>/...`` form first, then the
        ``.../<id>/generated_video...`` form; returns "" when neither matches.
        """
        if not video_url:
            return ""
        match = re.search(r"/generated/([0-9a-fA-F-]{32,36})/", video_url)
        if match:
            return match.group(1)
        match = re.search(r"/([0-9a-fA-F-]{32,36})/generated_video", video_url)
        if match:
            return match.group(1)
        return ""

    async def _upscale_video_url(self, video_url: str) -> str:
        """Request an HD upscale for *video_url*; return the HD URL on success.

        Best-effort: on any failure (missing id, upstream error, unexpected
        payload) the original URL is returned unchanged.
        """
        if not video_url or not self.upscale_on_finish:
            return video_url
        video_id = self._extract_video_id(video_url)
        if not video_id:
            logger.warning("Video upscale skipped: unable to extract video id")
            return video_url
        try:
            async with _new_session() as session:
                response = await VideoUpscaleReverse.request(
                    session, self.token, video_id
                )
                payload = response.json() if response is not None else {}
                hd_url = payload.get("hdMediaUrl") if isinstance(payload, dict) else None
                if hd_url:
                    logger.info(f"Video upscale completed: {hd_url}")
                    return hd_url
        except Exception as e:
            logger.warning(f"Video upscale failed: {e}")
        return video_url

    def _sse(self, content: str = "", role: str = None, finish: str = None) -> str:
        """Build SSE response.

        Builds one ``chat.completion.chunk`` SSE line. A role-only call emits
        the initial role delta with empty content; a finish-only call emits an
        empty delta carrying the finish_reason.
        """
        delta = {}
        if role:
            delta["role"] = role
            delta["content"] = ""
        elif content:
            delta["content"] = content

        chunk = {
            # Fall back to a fresh id until the upstream responseId is seen.
            "id": self.response_id or f"chatcmpl-{uuid.uuid4().hex[:24]}",
            "object": "chat.completion.chunk",
            "created": self.created,
            "model": self.model,
            "choices": [
                {"index": 0, "delta": delta, "logprobs": None, "finish_reason": finish}
            ],
        }
        return f"data: {orjson.dumps(chunk).decode()}\n\n"

    async def process(
        self, response: AsyncIterable[bytes]
    ) -> AsyncGenerator[str, None]:
        """Process video stream response.

        Consumes the raw upstream stream (guarded by an idle timeout) and
        yields OpenAI-style SSE strings. Non-JSON lines are skipped; think
        tokens and progress updates are wrapped in <think> tags when
        show_think is enabled; at progress 100 the video is (optionally)
        upscaled, rendered, and emitted, followed by the stop chunk and
        the [DONE] sentinel.
        """
        idle_timeout = get_config("video.stream_timeout")

        try:
            async for line in _with_idle_timeout(response, idle_timeout, self.model):
                line = _normalize_line(line)
                if not line:
                    continue
                try:
                    data = orjson.loads(line)
                except orjson.JSONDecodeError:
                    # Skip keep-alives / non-JSON noise.
                    continue

                resp = data.get("result", {}).get("response", {})
                is_thinking = bool(resp.get("isThinking"))

                if rid := resp.get("responseId"):
                    self.response_id = rid

                # Emit the role delta exactly once, before any content.
                if not self.role_sent:
                    yield self._sse(role="assistant")
                    self.role_sent = True

                if token := resp.get("token"):
                    if is_thinking:
                        if not self.show_think:
                            continue
                        if not self.think_opened:
                            yield self._sse("<think>\n")
                            self.think_opened = True
                    else:
                        # Close any open think block before normal content.
                        if self.think_opened:
                            yield self._sse("\n</think>\n")
                            self.think_opened = False
                    yield self._sse(token)
                    continue

                if video_resp := resp.get("streamingVideoGenerationResponse"):
                    progress = video_resp.get("progress", 0)

                    if is_thinking:
                        if not self.show_think:
                            continue
                        if not self.think_opened:
                            yield self._sse("<think>\n")
                            self.think_opened = True
                    else:
                        if self.think_opened:
                            yield self._sse("\n</think>\n")
                            self.think_opened = False
                    # Progress text is only surfaced when thinking is shown.
                    if self.show_think:
                        yield self._sse(f"正在生成视频中,当前进度{progress}%\n")

                    if progress == 100:
                        video_url = video_resp.get("videoUrl", "")
                        thumbnail_url = video_resp.get("thumbnailImageUrl", "")

                        # Ensure the think block is closed before the result.
                        if self.think_opened:
                            yield self._sse("\n</think>\n")
                            self.think_opened = False

                        if video_url:
                            if self.upscale_on_finish:
                                yield self._sse("正在对视频进行超分辨率\n")
                                video_url = await self._upscale_video_url(video_url)
                            dl_service = self._get_dl()
                            rendered = await dl_service.render_video(
                                video_url, self.token, thumbnail_url
                            )
                            yield self._sse(rendered)

                            logger.info(f"Video generated: {video_url}")
                    continue

            # Stream ended: close a dangling think block, then finish.
            if self.think_opened:
                yield self._sse("</think>\n")
            yield self._sse(finish="stop")
            yield "data: [DONE]\n\n"
        except asyncio.CancelledError:
            # Client disconnected; nothing to emit.
            logger.debug(
                "Video stream cancelled by client", extra={"model": self.model}
            )
        except StreamIdleTimeoutError as e:
            raise UpstreamException(
                message=f"Video stream idle timeout after {e.idle_seconds}s",
                status_code=504,
                details={
                    "error": str(e),
                    "type": "stream_idle_timeout",
                    "idle_seconds": e.idle_seconds,
                },
            )
        except RequestsError as e:
            if _is_http2_error(e):
                logger.warning(
                    f"HTTP/2 stream error in video: {e}", extra={"model": self.model}
                )
                raise UpstreamException(
                    message="Upstream connection closed unexpectedly",
                    status_code=502,
                    details={"error": str(e), "type": "http2_stream_error"},
                )
            logger.error(
                f"Video stream request error: {e}", extra={"model": self.model}
            )
            raise UpstreamException(
                message=f"Upstream request failed: {e}",
                status_code=502,
                details={"error": str(e)},
            )
        except Exception as e:
            # Swallowed deliberately: the SSE stream simply ends for the client.
            logger.error(
                f"Video stream processing error: {e}",
                extra={"model": self.model, "error_type": type(e).__name__},
            )
        finally:
            await self.close()
|
| 567 |
+
|
| 568 |
+
|
| 569 |
+
class VideoCollectProcessor(BaseProcessor):
    """Video non-stream response processor.

    Consumes the upstream video-generation stream to completion and
    returns a single OpenAI-style ``chat.completion`` dict whose content
    is the rendered final video (empty on failure).
    """

    def __init__(self, model: str, token: str = "", upscale_on_finish: bool = False):
        super().__init__(model, token)
        # Set by the caller for 720p requests served from the basic pool,
        # which need a client-triggered upscale after generation.
        self.upscale_on_finish = bool(upscale_on_finish)

    @staticmethod
    def _extract_video_id(video_url: str) -> str:
        """Pull the UUID-like video id out of a generated-video URL.

        Tries the ``.../generated/<id>/...`` form first, then the
        ``.../<id>/generated_video...`` form; returns "" when neither matches.
        """
        if not video_url:
            return ""
        match = re.search(r"/generated/([0-9a-fA-F-]{32,36})/", video_url)
        if match:
            return match.group(1)
        match = re.search(r"/([0-9a-fA-F-]{32,36})/generated_video", video_url)
        if match:
            return match.group(1)
        return ""

    async def _upscale_video_url(self, video_url: str) -> str:
        """Request an HD upscale for *video_url*; return the HD URL on success.

        Best-effort: on any failure (missing id, upstream error, unexpected
        payload) the original URL is returned unchanged.
        """
        if not video_url or not self.upscale_on_finish:
            return video_url
        video_id = self._extract_video_id(video_url)
        if not video_id:
            logger.warning("Video upscale skipped: unable to extract video id")
            return video_url
        try:
            async with _new_session() as session:
                response = await VideoUpscaleReverse.request(
                    session, self.token, video_id
                )
                payload = response.json() if response is not None else {}
                hd_url = payload.get("hdMediaUrl") if isinstance(payload, dict) else None
                if hd_url:
                    logger.info(f"Video upscale completed: {hd_url}")
                    return hd_url
        except Exception as e:
            logger.warning(f"Video upscale failed: {e}")
        return video_url

    async def process(self, response: AsyncIterable[bytes]) -> dict[str, Any]:
        """Process and collect video response.

        Drains the stream (guarded by an idle timeout), keeping only the
        progress==100 event: its video is (optionally) upscaled and rendered
        into ``content``. All errors are logged and swallowed, so a failed
        generation yields a completion with empty content.
        """
        response_id = ""
        content = ""
        idle_timeout = get_config("video.stream_timeout")

        try:
            async for line in _with_idle_timeout(response, idle_timeout, self.model):
                line = _normalize_line(line)
                if not line:
                    continue
                try:
                    data = orjson.loads(line)
                except orjson.JSONDecodeError:
                    # Skip keep-alives / non-JSON noise.
                    continue

                resp = data.get("result", {}).get("response", {})

                if video_resp := resp.get("streamingVideoGenerationResponse"):
                    if video_resp.get("progress") == 100:
                        response_id = resp.get("responseId", "")
                        video_url = video_resp.get("videoUrl", "")
                        thumbnail_url = video_resp.get("thumbnailImageUrl", "")

                        if video_url:
                            if self.upscale_on_finish:
                                video_url = await self._upscale_video_url(video_url)
                            dl_service = self._get_dl()
                            content = await dl_service.render_video(
                                video_url, self.token, thumbnail_url
                            )
                            logger.info(f"Video generated: {video_url}")

        except asyncio.CancelledError:
            logger.debug(
                "Video collect cancelled by client", extra={"model": self.model}
            )
        except StreamIdleTimeoutError as e:
            logger.warning(
                f"Video collect idle timeout: {e}", extra={"model": self.model}
            )
        except RequestsError as e:
            if _is_http2_error(e):
                logger.warning(
                    f"HTTP/2 stream error in video collect: {e}",
                    extra={"model": self.model},
                )
            else:
                logger.error(
                    f"Video collect request error: {e}", extra={"model": self.model}
                )
        except Exception as e:
            logger.error(
                f"Video collect processing error: {e}",
                extra={"model": self.model, "error_type": type(e).__name__},
            )
        finally:
            await self.close()

        # Always return a well-formed completion, even after an error.
        return {
            "id": response_id,
            "object": "chat.completion",
            "created": self.created,
            "model": self.model,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": content,
                        "refusal": None,
                    },
                    "finish_reason": "stop",
                }
            ],
            "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
        }
|
| 686 |
+
|
| 687 |
+
|
| 688 |
+
__all__ = ["VideoService"]
|
app/services/grok/services/voice.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Grok Voice Mode Service
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from typing import Any, Dict
|
| 6 |
+
|
| 7 |
+
from app.core.config import get_config
|
| 8 |
+
from app.services.reverse.ws_livekit import LivekitTokenReverse
|
| 9 |
+
from app.services.reverse.utils.session import ResettableSession
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class VoiceService:
    """Voice Mode Service (LiveKit)"""

    async def get_token(
        self,
        token: str,
        voice: str = "ara",
        personality: str = "assistant",
        speed: float = 1.0,
    ) -> Dict[str, Any]:
        """Fetch a LiveKit voice-session token payload from upstream.

        Opens a short-lived session impersonating the configured browser
        and returns the parsed JSON response.
        """
        impersonation = get_config("proxy.browser")
        async with ResettableSession(impersonate=impersonation) as http_session:
            upstream = await LivekitTokenReverse.request(
                http_session,
                token=token,
                voice=voice,
                personality=personality,
                speed=speed,
            )
            return upstream.json()
|
app/services/grok/utils/cache.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Local cache utilities.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from typing import Any, Dict
|
| 6 |
+
|
| 7 |
+
from app.core.storage import DATA_DIR
|
| 8 |
+
|
| 9 |
+
# Recognized cache file extensions, compared against lower-cased suffixes.
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"}
VIDEO_EXTS = {".mp4", ".mov", ".m4v", ".webm", ".avi", ".mkv"}


class CacheService:
    """Local media cache service.

    Manages the on-disk cache under ``DATA_DIR/tmp`` with one subdirectory
    per media type ("image" and "video"): statistics, paginated listing,
    single-file deletion, and full clearing.
    """

    def __init__(self):
        base_dir = DATA_DIR / "tmp"
        self.image_dir = base_dir / "image"
        self.video_dir = base_dir / "video"
        # Create both cache directories eagerly so later calls can assume them.
        self.image_dir.mkdir(parents=True, exist_ok=True)
        self.video_dir.mkdir(parents=True, exist_ok=True)

    def _cache_dir(self, media_type: str):
        """Return the cache dir for *media_type* ("image" → images, else videos)."""
        return self.image_dir if media_type == "image" else self.video_dir

    def _allowed_exts(self, media_type: str):
        """Return the set of file extensions recognized for *media_type*."""
        return IMAGE_EXTS if media_type == "image" else VIDEO_EXTS

    def _media_files(self, media_type: str):
        """Return regular files in the cache dir whose suffix matches *media_type*."""
        cache_dir = self._cache_dir(media_type)
        if not cache_dir.exists():
            return []
        allowed = self._allowed_exts(media_type)
        return [
            f for f in cache_dir.glob("*") if f.is_file() and f.suffix.lower() in allowed
        ]

    def get_stats(self, media_type: str = "image") -> Dict[str, Any]:
        """Return {"count", "size_mb"} for recognized media files of *media_type*."""
        files = self._media_files(media_type)
        total_size = sum(f.stat().st_size for f in files)
        return {"count": len(files), "size_mb": round(total_size / 1024 / 1024, 2)}

    def list_files(
        self, media_type: str = "image", page: int = 1, page_size: int = 1000
    ) -> Dict[str, Any]:
        """List cached files of *media_type*, newest first, paginated.

        Returns a dict with "total", "page", "page_size" and "items"; each
        item carries name, size_bytes, mtime_ms, and a "view_url" pointing
        at the /v1/files route.
        """
        items = []
        for f in self._media_files(media_type):
            try:
                stat = f.stat()
            except Exception:
                # File vanished between glob and stat; skip it.
                continue
            items.append(
                {
                    "name": f.name,
                    "size_bytes": stat.st_size,
                    "mtime_ms": int(stat.st_mtime * 1000),
                }
            )

        items.sort(key=lambda x: x["mtime_ms"], reverse=True)

        total = len(items)
        start = max(0, (page - 1) * page_size)
        paged = items[start : start + page_size]

        for item in paged:
            item["view_url"] = f"/v1/files/{media_type}/{item['name']}"

        return {"total": total, "page": page, "page_size": page_size, "items": paged}

    def delete_file(self, media_type: str, name: str) -> Dict[str, Any]:
        """Delete a single cached file by name.

        Path separators in *name* are flattened to "-" (matching how the
        downloader names cache files), which also keeps lookups inside the
        cache directory.
        """
        cache_dir = self._cache_dir(media_type)
        file_path = cache_dir / name.replace("/", "-")

        if file_path.exists():
            try:
                file_path.unlink()
                return {"deleted": True}
            except Exception:
                pass
        return {"deleted": False}

    def clear(self, media_type: str = "image") -> Dict[str, Any]:
        """Delete every regular file (any extension) in the media cache dir.

        Returns the count and combined size of the files actually removed.
        Files that fail to unlink are excluded from both figures (previously
        the pre-clear total size was reported even when deletions failed).
        """
        cache_dir = self._cache_dir(media_type)
        if not cache_dir.exists():
            return {"count": 0, "size_mb": 0.0}

        count = 0
        deleted_size = 0
        for f in cache_dir.glob("*"):
            if not f.is_file():
                continue
            try:
                size = f.stat().st_size
                f.unlink()
            except Exception:
                continue
            count += 1
            deleted_size += size

        return {"count": count, "size_mb": round(deleted_size / 1024 / 1024, 2)}
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
__all__ = ["CacheService"]
|
app/services/grok/utils/download.py
ADDED
|
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Download service.
|
| 3 |
+
|
| 4 |
+
Download service for assets.grok.com.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import asyncio
|
| 8 |
+
import base64
|
| 9 |
+
import hashlib
|
| 10 |
+
import os
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from typing import List, Optional, Tuple
|
| 13 |
+
from urllib.parse import urlparse
|
| 14 |
+
|
| 15 |
+
import aiofiles
|
| 16 |
+
|
| 17 |
+
from app.core.logger import logger
|
| 18 |
+
from app.core.storage import DATA_DIR
|
| 19 |
+
from app.core.config import get_config
|
| 20 |
+
from app.core.exceptions import AppException
|
| 21 |
+
from app.services.reverse.assets_download import AssetsDownloadReverse
|
| 22 |
+
from app.services.reverse.utils.session import ResettableSession
|
| 23 |
+
from app.services.grok.utils.locks import _get_download_semaphore, _file_lock
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class DownloadService:
|
| 27 |
+
"""Assets download service."""
|
| 28 |
+
|
| 29 |
+
def __init__(self):
|
| 30 |
+
self._session: Optional[ResettableSession] = None
|
| 31 |
+
base_dir = DATA_DIR / "tmp"
|
| 32 |
+
self.image_dir = base_dir / "image"
|
| 33 |
+
self.video_dir = base_dir / "video"
|
| 34 |
+
self.image_dir.mkdir(parents=True, exist_ok=True)
|
| 35 |
+
self.video_dir.mkdir(parents=True, exist_ok=True)
|
| 36 |
+
self._cleanup_running = False
|
| 37 |
+
|
| 38 |
+
async def create(self) -> ResettableSession:
|
| 39 |
+
"""Create or reuse a session."""
|
| 40 |
+
if self._session is None:
|
| 41 |
+
browser = get_config("proxy.browser")
|
| 42 |
+
if browser:
|
| 43 |
+
self._session = ResettableSession(impersonate=browser)
|
| 44 |
+
else:
|
| 45 |
+
self._session = ResettableSession()
|
| 46 |
+
return self._session
|
| 47 |
+
|
| 48 |
+
async def close(self):
|
| 49 |
+
"""Close the session."""
|
| 50 |
+
if self._session:
|
| 51 |
+
await self._session.close()
|
| 52 |
+
self._session = None
|
| 53 |
+
|
| 54 |
+
async def resolve_url(
|
| 55 |
+
self, path_or_url: str, token: str, media_type: str = "image"
|
| 56 |
+
) -> str:
|
| 57 |
+
asset_url = path_or_url
|
| 58 |
+
path = path_or_url
|
| 59 |
+
if path_or_url.startswith("http"):
|
| 60 |
+
parsed = urlparse(path_or_url)
|
| 61 |
+
path = parsed.path or ""
|
| 62 |
+
asset_url = path_or_url
|
| 63 |
+
else:
|
| 64 |
+
if not path_or_url.startswith("/"):
|
| 65 |
+
path_or_url = f"/{path_or_url}"
|
| 66 |
+
path = path_or_url
|
| 67 |
+
asset_url = f"https://assets.grok.com{path_or_url}"
|
| 68 |
+
|
| 69 |
+
app_url = get_config("app.app_url")
|
| 70 |
+
if app_url:
|
| 71 |
+
await self.download_file(asset_url, token, media_type)
|
| 72 |
+
return f"{app_url.rstrip('/')}/v1/files/{media_type}{path}"
|
| 73 |
+
return asset_url
|
| 74 |
+
|
| 75 |
+
async def render_image(
|
| 76 |
+
self, url: str, token: str, image_id: str = "image"
|
| 77 |
+
) -> str:
|
| 78 |
+
fmt = get_config("app.image_format")
|
| 79 |
+
fmt = fmt.lower() if isinstance(fmt, str) else "url"
|
| 80 |
+
if fmt not in ("base64", "url", "markdown"):
|
| 81 |
+
fmt = "url"
|
| 82 |
+
try:
|
| 83 |
+
if fmt == "base64":
|
| 84 |
+
data_uri = await self.parse_b64(url, token, "image")
|
| 85 |
+
return f""
|
| 86 |
+
final_url = await self.resolve_url(url, token, "image")
|
| 87 |
+
return f""
|
| 88 |
+
except Exception as e:
|
| 89 |
+
logger.warning(f"Image render failed, fallback to URL: {e}")
|
| 90 |
+
final_url = await self.resolve_url(url, token, "image")
|
| 91 |
+
return f""
|
| 92 |
+
|
| 93 |
+
async def render_video(
|
| 94 |
+
self, video_url: str, token: str, thumbnail_url: str = ""
|
| 95 |
+
) -> str:
|
| 96 |
+
fmt = get_config("app.video_format")
|
| 97 |
+
fmt = fmt.lower() if isinstance(fmt, str) else "url"
|
| 98 |
+
if fmt not in ("url", "markdown", "html"):
|
| 99 |
+
fmt = "url"
|
| 100 |
+
final_video_url = await self.resolve_url(video_url, token, "video")
|
| 101 |
+
final_thumb_url = ""
|
| 102 |
+
if thumbnail_url:
|
| 103 |
+
final_thumb_url = await self.resolve_url(thumbnail_url, token, "image")
|
| 104 |
+
if fmt == "url":
|
| 105 |
+
return f"{final_video_url}\n"
|
| 106 |
+
if fmt == "markdown":
|
| 107 |
+
return f"[video]({final_video_url})"
|
| 108 |
+
import html
|
| 109 |
+
|
| 110 |
+
safe_video_url = html.escape(final_video_url)
|
| 111 |
+
safe_thumbnail_url = html.escape(final_thumb_url)
|
| 112 |
+
poster_attr = f' poster="{safe_thumbnail_url}"' if safe_thumbnail_url else ""
|
| 113 |
+
return f'''<video id="video" controls="" preload="none"{poster_attr}>
|
| 114 |
+
<source id="mp4" src="{safe_video_url}" type="video/mp4">
|
| 115 |
+
</video>'''
|
| 116 |
+
|
| 117 |
+
async def parse_b64(self, file_path: str, token: str, media_type: str = "image") -> str:
|
| 118 |
+
"""Download and return data URI."""
|
| 119 |
+
try:
|
| 120 |
+
if not isinstance(file_path, str) or not file_path.strip():
|
| 121 |
+
raise AppException("Invalid file path", code="invalid_file_path")
|
| 122 |
+
if file_path.startswith("data:"):
|
| 123 |
+
raise AppException("Invalid file path", code="invalid_file_path")
|
| 124 |
+
file_path = self._normalize_path(file_path)
|
| 125 |
+
lock_name = f"dl_b64_{hashlib.sha1(file_path.encode()).hexdigest()[:16]}"
|
| 126 |
+
lock_timeout = max(1, int(get_config("asset.download_timeout")))
|
| 127 |
+
async with _get_download_semaphore():
|
| 128 |
+
async with _file_lock(lock_name, timeout=lock_timeout):
|
| 129 |
+
session = await self.create()
|
| 130 |
+
response = await AssetsDownloadReverse.request(
|
| 131 |
+
session, token, file_path
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
if hasattr(response, "aiter_content"):
|
| 135 |
+
data = bytearray()
|
| 136 |
+
async for chunk in response.aiter_content():
|
| 137 |
+
if chunk:
|
| 138 |
+
data.extend(chunk)
|
| 139 |
+
raw = bytes(data)
|
| 140 |
+
else:
|
| 141 |
+
raw = response.content
|
| 142 |
+
|
| 143 |
+
content_type = response.headers.get(
|
| 144 |
+
"content-type", "application/octet-stream"
|
| 145 |
+
).split(";")[0]
|
| 146 |
+
data_uri = f"data:{content_type};base64,{base64.b64encode(raw).decode()}"
|
| 147 |
+
|
| 148 |
+
return data_uri
|
| 149 |
+
except Exception as e:
|
| 150 |
+
logger.error(f"Failed to convert {file_path} to base64: {e}")
|
| 151 |
+
raise
|
| 152 |
+
|
| 153 |
+
def _normalize_path(self, file_path: str) -> str:
|
| 154 |
+
"""Normalize URL or path to assets path for download."""
|
| 155 |
+
if not isinstance(file_path, str) or not file_path.strip():
|
| 156 |
+
raise AppException("Invalid file path", code="invalid_file_path")
|
| 157 |
+
|
| 158 |
+
value = file_path.strip()
|
| 159 |
+
if value.startswith("data:"):
|
| 160 |
+
raise AppException("Invalid file path", code="invalid_file_path")
|
| 161 |
+
|
| 162 |
+
parsed = urlparse(value)
|
| 163 |
+
if parsed.scheme or parsed.netloc:
|
| 164 |
+
if not (
|
| 165 |
+
parsed.scheme and parsed.netloc and parsed.scheme in ["http", "https"]
|
| 166 |
+
):
|
| 167 |
+
raise AppException("Invalid file path", code="invalid_file_path")
|
| 168 |
+
path = parsed.path or ""
|
| 169 |
+
if parsed.query:
|
| 170 |
+
path = f"{path}?{parsed.query}"
|
| 171 |
+
else:
|
| 172 |
+
path = value
|
| 173 |
+
|
| 174 |
+
if not path:
|
| 175 |
+
raise AppException("Invalid file path", code="invalid_file_path")
|
| 176 |
+
if not path.startswith("/"):
|
| 177 |
+
path = f"/{path}"
|
| 178 |
+
|
| 179 |
+
return path
|
| 180 |
+
|
| 181 |
+
async def download_file(self, file_path: str, token: str, media_type: str = "image") -> Tuple[Optional[Path], str]:
|
| 182 |
+
"""Download asset to local cache.
|
| 183 |
+
|
| 184 |
+
Args:
|
| 185 |
+
file_path: str, the path of the file to download.
|
| 186 |
+
token: str, the SSO token.
|
| 187 |
+
media_type: str, the media type of the file.
|
| 188 |
+
|
| 189 |
+
Returns:
|
| 190 |
+
Tuple[Optional[Path], str]: The path of the downloaded file and the MIME type.
|
| 191 |
+
"""
|
| 192 |
+
async with _get_download_semaphore():
|
| 193 |
+
file_path = self._normalize_path(file_path)
|
| 194 |
+
cache_dir = self.image_dir if media_type == "image" else self.video_dir
|
| 195 |
+
filename = file_path.lstrip("/").replace("/", "-")
|
| 196 |
+
cache_path = cache_dir / filename
|
| 197 |
+
|
| 198 |
+
lock_name = (
|
| 199 |
+
f"dl_{media_type}_{hashlib.sha1(str(cache_path).encode()).hexdigest()[:16]}"
|
| 200 |
+
)
|
| 201 |
+
lock_timeout = max(1, int(get_config("asset.download_timeout")))
|
| 202 |
+
async with _file_lock(lock_name, timeout=lock_timeout):
|
| 203 |
+
session = await self.create()
|
| 204 |
+
response = await AssetsDownloadReverse.request(session, token, file_path)
|
| 205 |
+
|
| 206 |
+
tmp_path = cache_path.with_suffix(cache_path.suffix + ".tmp")
|
| 207 |
+
try:
|
| 208 |
+
async with aiofiles.open(tmp_path, "wb") as f:
|
| 209 |
+
if hasattr(response, "aiter_content"):
|
| 210 |
+
async for chunk in response.aiter_content():
|
| 211 |
+
if chunk:
|
| 212 |
+
await f.write(chunk)
|
| 213 |
+
else:
|
| 214 |
+
await f.write(response.content)
|
| 215 |
+
os.replace(tmp_path, cache_path)
|
| 216 |
+
finally:
|
| 217 |
+
if tmp_path.exists() and not cache_path.exists():
|
| 218 |
+
try:
|
| 219 |
+
tmp_path.unlink()
|
| 220 |
+
except Exception:
|
| 221 |
+
pass
|
| 222 |
+
|
| 223 |
+
mime = response.headers.get(
|
| 224 |
+
"content-type", "application/octet-stream"
|
| 225 |
+
).split(";")[0]
|
| 226 |
+
logger.info(f"Downloaded: {file_path}")
|
| 227 |
+
|
| 228 |
+
asyncio.create_task(self._check_limit())
|
| 229 |
+
|
| 230 |
+
return cache_path, mime
|
| 231 |
+
|
| 232 |
+
async def _check_limit(self):
|
| 233 |
+
"""Check cache limit and cleanup.
|
| 234 |
+
|
| 235 |
+
Args:
|
| 236 |
+
self: DownloadService, the download service instance.
|
| 237 |
+
|
| 238 |
+
Returns:
|
| 239 |
+
None
|
| 240 |
+
"""
|
| 241 |
+
if self._cleanup_running or not get_config("cache.enable_auto_clean"):
|
| 242 |
+
return
|
| 243 |
+
|
| 244 |
+
self._cleanup_running = True
|
| 245 |
+
try:
|
| 246 |
+
try:
|
| 247 |
+
async with _file_lock("cache_cleanup", timeout=5):
|
| 248 |
+
limit_mb = get_config("cache.limit_mb")
|
| 249 |
+
total_size = 0
|
| 250 |
+
all_files: List[Tuple[Path, float, int]] = []
|
| 251 |
+
|
| 252 |
+
for d in [self.image_dir, self.video_dir]:
|
| 253 |
+
if d.exists():
|
| 254 |
+
for f in d.glob("*"):
|
| 255 |
+
if f.is_file():
|
| 256 |
+
try:
|
| 257 |
+
stat = f.stat()
|
| 258 |
+
total_size += stat.st_size
|
| 259 |
+
all_files.append(
|
| 260 |
+
(f, stat.st_mtime, stat.st_size)
|
| 261 |
+
)
|
| 262 |
+
except Exception:
|
| 263 |
+
pass
|
| 264 |
+
current_mb = total_size / 1024 / 1024
|
| 265 |
+
|
| 266 |
+
if current_mb <= limit_mb:
|
| 267 |
+
return
|
| 268 |
+
|
| 269 |
+
logger.info(
|
| 270 |
+
f"Cache limit exceeded ({current_mb:.2f}MB > {limit_mb}MB), cleaning..."
|
| 271 |
+
)
|
| 272 |
+
all_files.sort(key=lambda x: x[1])
|
| 273 |
+
|
| 274 |
+
deleted_count = 0
|
| 275 |
+
deleted_size = 0
|
| 276 |
+
target_mb = limit_mb * 0.8
|
| 277 |
+
|
| 278 |
+
for f, _, size in all_files:
|
| 279 |
+
try:
|
| 280 |
+
f.unlink()
|
| 281 |
+
deleted_count += 1
|
| 282 |
+
deleted_size += size
|
| 283 |
+
total_size -= size
|
| 284 |
+
if (total_size / 1024 / 1024) <= target_mb:
|
| 285 |
+
break
|
| 286 |
+
except Exception:
|
| 287 |
+
pass
|
| 288 |
+
|
| 289 |
+
logger.info(
|
| 290 |
+
f"Cache cleanup: {deleted_count} files ({deleted_size / 1024 / 1024:.2f}MB)"
|
| 291 |
+
)
|
| 292 |
+
except Exception as e:
|
| 293 |
+
logger.warning(f"Cache cleanup failed: {e}")
|
| 294 |
+
finally:
|
| 295 |
+
self._cleanup_running = False
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
__all__ = ["DownloadService"]
|
app/services/grok/utils/locks.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Shared locking helpers for assets operations.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import asyncio
|
| 6 |
+
import time
|
| 7 |
+
from contextlib import asynccontextmanager
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
from app.core.config import get_config
|
| 11 |
+
from app.core.storage import DATA_DIR
|
| 12 |
+
|
| 13 |
+
try:
|
| 14 |
+
import fcntl
|
| 15 |
+
except ImportError:
|
| 16 |
+
fcntl = None
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
LOCK_DIR = DATA_DIR / ".locks"
|
| 20 |
+
|
| 21 |
+
_UPLOAD_SEMAPHORE = None
|
| 22 |
+
_UPLOAD_SEM_VALUE = None
|
| 23 |
+
_DOWNLOAD_SEMAPHORE = None
|
| 24 |
+
_DOWNLOAD_SEM_VALUE = None
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _get_upload_semaphore() -> asyncio.Semaphore:
    """Return the process-wide semaphore bounding concurrent uploads.

    The limit comes from ``asset.upload_concurrent`` (clamped to >= 1);
    the semaphore is rebuilt whenever that configured value changes.
    """
    global _UPLOAD_SEMAPHORE, _UPLOAD_SEM_VALUE

    limit = max(1, int(get_config("asset.upload_concurrent")))
    if _UPLOAD_SEMAPHORE is None or _UPLOAD_SEM_VALUE != limit:
        _UPLOAD_SEM_VALUE = limit
        _UPLOAD_SEMAPHORE = asyncio.Semaphore(limit)
    return _UPLOAD_SEMAPHORE
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _get_download_semaphore() -> asyncio.Semaphore:
    """Return the process-wide semaphore bounding concurrent downloads.

    The limit comes from ``asset.download_concurrent`` (clamped to >= 1);
    the semaphore is rebuilt whenever that configured value changes.
    """
    global _DOWNLOAD_SEMAPHORE, _DOWNLOAD_SEM_VALUE

    limit = max(1, int(get_config("asset.download_concurrent")))
    if _DOWNLOAD_SEMAPHORE is None or _DOWNLOAD_SEM_VALUE != limit:
        _DOWNLOAD_SEM_VALUE = limit
        _DOWNLOAD_SEMAPHORE = asyncio.Semaphore(limit)
    return _DOWNLOAD_SEMAPHORE
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@asynccontextmanager
async def _file_lock(name: str, timeout: int = 10):
    """File lock guard.

    Cross-process advisory lock backed by ``flock`` on a file under
    ``LOCK_DIR``. Polls non-blockingly every 50 ms so the event loop is
    never blocked, and raises ``TimeoutError`` if the lock cannot be
    acquired within *timeout* seconds. On platforms without ``fcntl``
    (e.g. Windows) this degrades to a no-op guard.
    """
    if fcntl is None:
        # No fcntl available: yield without any locking (best-effort no-op).
        yield
        return

    LOCK_DIR.mkdir(parents=True, exist_ok=True)
    lock_path = Path(LOCK_DIR) / f"{name}.lock"
    fd = None
    locked = False
    start = time.monotonic()

    try:
        # "a+" creates the lock file if missing without truncating it.
        fd = open(lock_path, "a+")
        while True:
            try:
                # Non-blocking attempt so we can sleep cooperatively below.
                fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
                locked = True
                break
            except BlockingIOError:
                if time.monotonic() - start >= timeout:
                    break
                await asyncio.sleep(0.05)
        if not locked:
            raise TimeoutError(f"Failed to acquire lock: {name}")
        yield
    finally:
        # Release in reverse order: unlock first (ignoring errors), then close.
        if fd:
            if locked:
                try:
                    fcntl.flock(fd, fcntl.LOCK_UN)
                except Exception:
                    pass
            fd.close()
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
__all__ = ["_get_upload_semaphore", "_get_download_semaphore", "_file_lock"]
|
app/services/grok/utils/process.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
响应处理器基类和通用工具
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import asyncio
|
| 6 |
+
import time
|
| 7 |
+
from typing import Any, AsyncGenerator, Optional, AsyncIterable, List, TypeVar
|
| 8 |
+
|
| 9 |
+
from app.core.config import get_config
|
| 10 |
+
from app.core.logger import logger
|
| 11 |
+
from app.core.exceptions import StreamIdleTimeoutError
|
| 12 |
+
from app.services.grok.utils.download import DownloadService
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
T = TypeVar("T")
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _is_http2_error(e: Exception) -> bool:
|
| 19 |
+
"""检查是否为 HTTP/2 流错误"""
|
| 20 |
+
err_str = str(e).lower()
|
| 21 |
+
return "http/2" in err_str or "curl: (92)" in err_str or "stream" in err_str
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _normalize_line(line: Any) -> Optional[str]:
|
| 25 |
+
"""规范化流式响应行,兼容 SSE data 前缀与空行"""
|
| 26 |
+
if line is None:
|
| 27 |
+
return None
|
| 28 |
+
if isinstance(line, (bytes, bytearray)):
|
| 29 |
+
text = line.decode("utf-8", errors="ignore")
|
| 30 |
+
else:
|
| 31 |
+
text = str(line)
|
| 32 |
+
text = text.strip()
|
| 33 |
+
if not text:
|
| 34 |
+
return None
|
| 35 |
+
if text.startswith("data:"):
|
| 36 |
+
text = text[5:].strip()
|
| 37 |
+
if text == "[DONE]":
|
| 38 |
+
return None
|
| 39 |
+
return text
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _collect_images(obj: Any) -> List[str]:
|
| 43 |
+
"""递归收集响应中的图片 URL"""
|
| 44 |
+
urls: List[str] = []
|
| 45 |
+
seen = set()
|
| 46 |
+
|
| 47 |
+
def add(url: str):
|
| 48 |
+
if not url or url in seen:
|
| 49 |
+
return
|
| 50 |
+
seen.add(url)
|
| 51 |
+
urls.append(url)
|
| 52 |
+
|
| 53 |
+
def walk(value: Any):
|
| 54 |
+
if isinstance(value, dict):
|
| 55 |
+
for key, item in value.items():
|
| 56 |
+
if key in {"generatedImageUrls", "imageUrls", "imageURLs"}:
|
| 57 |
+
if isinstance(item, list):
|
| 58 |
+
for url in item:
|
| 59 |
+
if isinstance(url, str):
|
| 60 |
+
add(url)
|
| 61 |
+
elif isinstance(item, str):
|
| 62 |
+
add(item)
|
| 63 |
+
continue
|
| 64 |
+
walk(item)
|
| 65 |
+
elif isinstance(value, list):
|
| 66 |
+
for item in value:
|
| 67 |
+
walk(item)
|
| 68 |
+
|
| 69 |
+
walk(obj)
|
| 70 |
+
return urls
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
async def _with_idle_timeout(
    iterable: AsyncIterable[T], idle_timeout: float, model: str = ""
) -> AsyncGenerator[T, None]:
    """
    Wrap an async iterator with idle-timeout detection.

    Args:
        iterable: The source async iterator.
        idle_timeout: Max seconds between consecutive items; 0 (or less)
            disables the timeout entirely.
        model: Model name, used only for logging.

    Raises:
        StreamIdleTimeoutError: When no item arrives within *idle_timeout*.
    """
    if idle_timeout <= 0:
        # Timeout disabled: pass items straight through.
        async for item in iterable:
            yield item
        return

    iterator = iterable.__aiter__()

    async def _maybe_aclose(it):
        # Best-effort close of the underlying generator; ignore close errors.
        aclose = getattr(it, "aclose", None)
        if not aclose:
            return
        try:
            await aclose()
        except Exception:
            pass

    while True:
        try:
            # Each item must arrive within idle_timeout seconds of the last.
            item = await asyncio.wait_for(iterator.__anext__(), timeout=idle_timeout)
            yield item
        except asyncio.TimeoutError:
            logger.warning(
                f"Stream idle timeout after {idle_timeout}s",
                extra={"model": model, "idle_timeout": idle_timeout},
            )
            # Release upstream resources before surfacing the timeout.
            await _maybe_aclose(iterator)
            raise StreamIdleTimeoutError(idle_timeout)
        except asyncio.CancelledError:
            # Consumer was cancelled: close upstream, then propagate.
            await _maybe_aclose(iterator)
            raise
        except StopAsyncIteration:
            break
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class BaseProcessor:
    """Base class for response processors.

    Holds the model/token context plus creation timestamp, and lazily
    creates a single shared DownloadService used to resolve asset URLs.
    """

    def __init__(self, model: str, token: str = ""):
        self.model = model
        self.token = token
        self.created = int(time.time())
        self.app_url = get_config("app.app_url")
        self._dl_service: Optional["DownloadService"] = None

    def _get_dl(self) -> "DownloadService":
        """Return the download service, creating it on first use (reused after)."""
        if self._dl_service is None:
            self._dl_service = DownloadService()
        return self._dl_service

    async def close(self):
        """Release the download service resources, if any were created."""
        service, self._dl_service = self._dl_service, None
        if service:
            await service.close()

    async def process_url(self, path: str, media_type: str = "image") -> str:
        """Resolve an asset *path* into a final URL via the download service."""
        return await self._get_dl().resolve_url(path, self.token, media_type)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
__all__ = [
|
| 147 |
+
"BaseProcessor",
|
| 148 |
+
"_with_idle_timeout",
|
| 149 |
+
"_normalize_line",
|
| 150 |
+
"_collect_images",
|
| 151 |
+
"_is_http2_error",
|
| 152 |
+
]
|
app/services/grok/utils/response.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Response formatting utilities for OpenAI-compatible API responses.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import time
|
| 7 |
+
import uuid
|
| 8 |
+
from typing import Optional
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def make_response_id() -> str:
    """Generate a unique chat-completion response ID.

    Combines a millisecond timestamp with 4 random bytes (hex), prefixed
    with "chatcmpl-" to match OpenAI's ID shape.
    """
    millis = int(time.time() * 1000)
    suffix = os.urandom(4).hex()
    return f"chatcmpl-{millis}{suffix}"
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def make_chat_chunk(
    response_id: str,
    model: str,
    content: str,
    index: int = 0,
    role: str = "assistant",
    is_final: bool = False,
) -> dict:
    """
    Build one OpenAI-compatible streaming chat-completion chunk.

    Args:
        response_id: ID shared by every chunk of the response
        model: Model name echoed back to the client
        content: Delta text carried by this chunk
        index: Choice index
        role: Delta role (normally "assistant")
        is_final: When True, adds finish_reason="stop" and a zeroed usage block

    Returns:
        Chat completion chunk dict
    """
    choice: dict = {
        "index": index,
        "delta": {"role": role, "content": content},
        **({"finish_reason": "stop"} if is_final else {}),
    }

    chunk: dict = {
        "id": response_id,
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": model,
        "choices": [choice],
        **(
            {
                "usage": {
                    "total_tokens": 0,
                    "input_tokens": 0,
                    "output_tokens": 0,
                    "input_tokens_details": {"text_tokens": 0, "image_tokens": 0},
                }
            }
            if is_final
            else {}
        ),
    }
    return chunk
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def make_chat_response(
    model: str,
    content: str,
    response_id: Optional[str] = None,
    index: int = 0,
    usage: Optional[dict] = None,
) -> dict:
    """
    Build an OpenAI-compatible non-streaming chat completion response.

    Args:
        model: Model name
        content: Assistant message content
        response_id: Response ID; a random one is generated when None
        index: Choice index
        usage: Custom usage dict; defaults to an all-zero usage block

    Returns:
        Chat completion response dict
    """
    if response_id is None:
        response_id = f"chatcmpl-{uuid.uuid4().hex[:8]}"

    effective_usage = (
        usage
        if usage is not None
        else {
            "total_tokens": 0,
            "input_tokens": 0,
            "output_tokens": 0,
            "input_tokens_details": {"text_tokens": 0, "image_tokens": 0},
        }
    )

    message = {"role": "assistant", "content": content, "refusal": None}
    return {
        "id": response_id,
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model,
        "choices": [
            {"index": index, "message": message, "finish_reason": "stop"}
        ],
        "usage": effective_usage,
    }
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def wrap_image_content(content: str, response_format: str = "url") -> str:
    """
    Wrap image content in markdown format for chat interface.

    As rendered, both branches returned an empty f-string and dropped the
    content entirely, contradicting the documented contract; the markdown
    image syntax is restored here.

    Args:
        content: Image URL or base64 data
        response_format: "url" or "b64_json"/"base64"

    Returns:
        Markdown-wrapped image content (empty input is returned unchanged)
    """
    if not content:
        return content

    if response_format == "url":
        return f"![image]({content})"

    # Base64 payload: already-formed data URIs are embedded as-is.
    # NOTE(review): assumes bare base64 payloads are PNG — confirm upstream format.
    if content.startswith("data:"):
        return f"![image]({content})"
    return f"![image](data:image/png;base64,{content})"
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
# Names exported via ``from ... import *``.
__all__ = [
    "make_response_id",
    "make_chat_chunk",
    "make_chat_response",
    "wrap_image_content",
]
|
app/services/grok/utils/retry.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Retry helpers for token switching.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from typing import Optional, Set
|
| 6 |
+
|
| 7 |
+
from app.core.exceptions import UpstreamException
|
| 8 |
+
from app.services.grok.services.model import ModelService
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
async def pick_token(
    token_mgr,
    model_id: str,
    tried: Set[str],
    preferred: Optional[str] = None,
    prefer_tags: Optional[Set[str]] = None,
) -> Optional[str]:
    """Select a token usable for *model_id*, skipping already-tried ones.

    A caller-preferred token wins when it has not been tried yet. Otherwise
    the model's candidate pools are scanned in order. If nothing is found on
    the very first attempt (nothing tried yet), cooling tokens are refreshed
    once and the pools are scanned again.
    """
    if preferred and preferred not in tried:
        return preferred

    token = None
    for pool in ModelService.pool_candidates_for_model(model_id):
        token = token_mgr.get_token(pool, exclude=tried, prefer_tags=prefer_tags)
        if token:
            break

    # One-shot recovery: only worth refreshing when this is the first attempt.
    if not token and not tried:
        refresh = await token_mgr.refresh_cooling_tokens()
        if refresh.get("recovered", 0) > 0:
            for pool in ModelService.pool_candidates_for_model(model_id):
                token = token_mgr.get_token(pool, prefer_tags=prefer_tags)
                if token:
                    break

    return token
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def rate_limited(error: Exception) -> bool:
    """Return True when *error* is an upstream rate-limit (HTTP 429) error."""
    if not isinstance(error, UpstreamException):
        return False
    details = error.details or {}
    return details.get("status") == 429 or details.get("error_code") == "rate_limit_exceeded"
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def transient_upstream(error: Exception) -> bool:
    """Whether error is likely transient and safe to retry with another token."""
    if not isinstance(error, UpstreamException):
        return False
    details = error.details or {}
    # Gateway / timeout-class statuses are retryable by definition.
    if details.get("status") in (408, 500, 502, 503, 504):
        return True
    # Fall back to scanning the error text for network-flake signatures.
    message = str(details.get("error") or error).lower()
    for marker in ("timed out", "timeout", "connection reset", "temporarily unavailable", "http2"):
        if marker in message:
            return True
    return False
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
# Names exported via ``from ... import *``.
__all__ = ["pick_token", "rate_limited", "transient_upstream"]
|
app/services/grok/utils/stream.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
流式响应通用工具
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from typing import AsyncGenerator
|
| 6 |
+
|
| 7 |
+
from app.core.logger import logger
|
| 8 |
+
from app.services.grok.services.model import ModelService
|
| 9 |
+
from app.services.token import EffortType
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
async def wrap_stream_with_usage(
    stream: AsyncGenerator, token_mgr, token: str, model: str
) -> AsyncGenerator:
    """
    Wrap a streaming response and record token usage once it completes.

    Usage is recorded only when the upstream generator was fully drained;
    an early close or an exception mid-stream skips the accounting.

    Args:
        stream: The original AsyncGenerator to proxy.
        token_mgr: TokenManager instance.
        token: Token string whose usage is recorded.
        model: Model name (drives the effort level looked up via ModelService).
    """
    completed = False
    try:
        async for piece in stream:
            yield piece
        completed = True
    finally:
        if completed:
            try:
                info = ModelService.get(model)
                # High-cost models are billed at HIGH effort; everything else LOW.
                effort = (
                    EffortType.HIGH
                    if (info and info.cost.value == "high")
                    else EffortType.LOW
                )
                await token_mgr.consume(token, effort)
                logger.debug(
                    f"Stream completed, recorded usage for token {token[:10]}... (effort={effort.value})"
                )
            except Exception as e:
                # Accounting failures must never break the already-delivered stream.
                logger.warning(f"Failed to record stream usage: {e}")
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# Names exported via ``from ... import *``.
__all__ = ["wrap_stream_with_usage"]
|
app/services/grok/utils/tool_call.py
ADDED
|
@@ -0,0 +1,319 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tool call utilities for OpenAI-compatible function calling.
|
| 3 |
+
|
| 4 |
+
Provides prompt-based emulation of tool calls by injecting tool definitions
|
| 5 |
+
into the system prompt and parsing structured responses.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import json
|
| 9 |
+
import re
|
| 10 |
+
import uuid
|
| 11 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def build_tool_prompt(
    tools: List[Dict[str, Any]],
    tool_choice: Optional[Any] = None,
    parallel_tool_calls: bool = True,
) -> str:
    """Generate a system prompt block describing available tools.

    Args:
        tools: List of OpenAI-format tool definitions.
        tool_choice: "auto", "required", "none", or {"type":"function","function":{"name":"..."}}.
        parallel_tool_calls: Whether multiple tool calls are allowed.

    Returns:
        System prompt string to prepend to the conversation.
    """
    # No tools, or the caller explicitly disabled them: emit nothing.
    if not tools or tool_choice == "none":
        return ""

    prompt: List[str] = [
        "# Available Tools",
        "",
        "You have access to the following tools. To call a tool, output a <tool_call> block with a JSON object containing \"name\" and \"arguments\".",
        "",
        "Format:",
        "<tool_call>",
        '{"name": "function_name", "arguments": {"param": "value"}}',
        "</tool_call>",
        "",
    ]

    if parallel_tool_calls:
        prompt.extend([
            "You may make multiple tool calls in a single response by using multiple <tool_call> blocks.",
            "",
        ])

    # One section per function-type tool; other tool types are skipped.
    prompt.extend(["## Tool Definitions", ""])
    for spec in tools:
        if spec.get("type") != "function":
            continue
        fn = spec.get("function", {})
        fn_name = fn.get("name", "")
        fn_desc = fn.get("description", "")
        fn_params = fn.get("parameters", {})

        prompt.append(f"### {fn_name}")
        if fn_desc:
            prompt.append(f"{fn_desc}")
        if fn_params:
            prompt.append(f"Parameters: {json.dumps(fn_params, ensure_ascii=False)}")
        prompt.append("")

    # Translate the tool_choice directive into an instruction line.
    if tool_choice == "required":
        prompt.append("IMPORTANT: You MUST call at least one tool in your response. Do not respond with only text.")
    elif isinstance(tool_choice, dict):
        forced = tool_choice.get("function", {}).get("name", "")
        if forced:
            prompt.append(f"IMPORTANT: You MUST call the tool \"{forced}\" in your response.")
    else:
        # "auto" or default
        prompt.append("Decide whether to call a tool based on the user's request. If you don't need a tool, respond normally with text only.")

    prompt.append("")
    prompt.append("When you call a tool, you may include text before or after the <tool_call> blocks, but the tool call blocks must be valid JSON.")

    return "\n".join(prompt)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
# Matches one <tool_call>...</tool_call> block. DOTALL lets the JSON payload
# span multiple lines; the non-greedy body keeps adjacent blocks separate.
_TOOL_CALL_RE = re.compile(
    r"<tool_call>\s*(.*?)\s*</tool_call>",
    re.DOTALL,
)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def _strip_code_fences(text: str) -> str:
|
| 95 |
+
if not text:
|
| 96 |
+
return text
|
| 97 |
+
cleaned = text.strip()
|
| 98 |
+
if cleaned.startswith("```"):
|
| 99 |
+
cleaned = re.sub(r"^```[a-zA-Z0-9_-]*\s*", "", cleaned)
|
| 100 |
+
cleaned = re.sub(r"\s*```$", "", cleaned)
|
| 101 |
+
return cleaned.strip()
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def _extract_json_object(text: str) -> str:
|
| 105 |
+
if not text:
|
| 106 |
+
return text
|
| 107 |
+
start = text.find("{")
|
| 108 |
+
if start == -1:
|
| 109 |
+
return text
|
| 110 |
+
end = text.rfind("}")
|
| 111 |
+
if end == -1:
|
| 112 |
+
return text[start:]
|
| 113 |
+
if end < start:
|
| 114 |
+
return text
|
| 115 |
+
return text[start : end + 1]
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def _remove_trailing_commas(text: str) -> str:
|
| 119 |
+
if not text:
|
| 120 |
+
return text
|
| 121 |
+
return re.sub(r",\s*([}\]])", r"\1", text)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def _balance_braces(text: str) -> str:
|
| 125 |
+
if not text:
|
| 126 |
+
return text
|
| 127 |
+
open_count = 0
|
| 128 |
+
close_count = 0
|
| 129 |
+
in_string = False
|
| 130 |
+
escape = False
|
| 131 |
+
for ch in text:
|
| 132 |
+
if escape:
|
| 133 |
+
escape = False
|
| 134 |
+
continue
|
| 135 |
+
if ch == "\\" and in_string:
|
| 136 |
+
escape = True
|
| 137 |
+
continue
|
| 138 |
+
if ch == '"':
|
| 139 |
+
in_string = not in_string
|
| 140 |
+
continue
|
| 141 |
+
if in_string:
|
| 142 |
+
continue
|
| 143 |
+
if ch == "{":
|
| 144 |
+
open_count += 1
|
| 145 |
+
elif ch == "}":
|
| 146 |
+
close_count += 1
|
| 147 |
+
if open_count > close_count:
|
| 148 |
+
text = text + ("}" * (open_count - close_count))
|
| 149 |
+
return text
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def _repair_json(text: str) -> Optional[Any]:
    """Best-effort cleanup of almost-JSON text; returns the parsed value or None."""
    if not text:
        return None
    candidate = _strip_code_fences(text)
    candidate = _extract_json_object(candidate)
    # Flatten every newline flavor to spaces before the structural fixes.
    candidate = candidate.replace("\r\n", "\n").replace("\r", "\n").replace("\n", " ")
    candidate = _balance_braces(_remove_trailing_commas(candidate))
    try:
        return json.loads(candidate)
    except json.JSONDecodeError:
        return None
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def parse_tool_call_block(
    raw_json: str,
    tools: Optional[List[Dict[str, Any]]] = None,
) -> Optional[Dict[str, Any]]:
    """Parse one <tool_call> JSON payload into an OpenAI tool-call dict.

    Returns None when the payload is empty, unparseable even after repair,
    missing a name, or names a tool absent from *tools* (when given).
    """
    if not raw_json:
        return None

    try:
        payload = json.loads(raw_json)
    except json.JSONDecodeError:
        # Strict parse failed; fall back to the lenient repair pass.
        payload = _repair_json(raw_json)
    if not isinstance(payload, dict):
        return None

    fn_name = payload.get("name")
    if not fn_name:
        return None
    fn_args = payload.get("arguments", {})

    # Validate against the declared tool names when any are known.
    if tools:
        known = {t.get("function", {}).get("name") for t in tools}
        known.discard(None)
        if known and fn_name not in known:
            return None

    # Arguments are always serialized to a JSON string in the OpenAI format.
    args_text = fn_args if isinstance(fn_args, str) else json.dumps(fn_args, ensure_ascii=False)

    return {
        "id": f"call_{uuid.uuid4().hex[:24]}",
        "type": "function",
        "function": {"name": fn_name, "arguments": args_text},
    }
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def parse_tool_calls(
    content: str,
    tools: Optional[List[Dict[str, Any]]] = None,
) -> Tuple[Optional[str], Optional[List[Dict[str, Any]]]]:
    """Parse tool call blocks from model output.

    Detects ``<tool_call>...</tool_call>`` blocks, parses JSON from each block,
    and returns OpenAI-format tool call objects.

    Args:
        content: Raw model output text.
        tools: Optional list of tool definitions for name validation.

    Returns:
        Tuple of (text_content, tool_calls_list).
        - text_content: text outside <tool_call> blocks (None if empty).
        - tool_calls_list: list of OpenAI tool call dicts, or None if no calls found.
    """
    if not content:
        return content, None

    blocks = list(_TOOL_CALL_RE.finditer(content))
    if not blocks:
        return content, None

    calls = [
        call
        for call in (parse_tool_call_block(m.group(1).strip(), tools) for m in blocks)
        if call
    ]
    if not calls:
        # Every block failed to parse/validate: treat the text as plain content.
        return content, None

    # Stitch together whatever text surrounds the tool_call blocks.
    segments: List[str] = []
    cursor = 0
    for m in blocks:
        leading = content[cursor : m.start()].strip()
        if leading:
            segments.append(leading)
        cursor = m.end()
    tail = content[cursor:].strip()
    if tail:
        segments.append(tail)

    return ("\n".join(segments) if segments else None), calls
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def format_tool_history(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Convert assistant messages with tool_calls and tool role messages into text format.

    Since Grok's web API only accepts a single message string, this converts
    tool-related messages back to a text representation for multi-turn conversations.

    Fixes over the previous version: the unused ``tc_id`` local is gone, and
    dict-valued ``arguments`` (sent by some clients) are serialized with
    ``json.dumps`` instead of being interpolated as a Python repr, which
    produced invalid JSON inside the reconstructed <tool_call> block.

    Args:
        messages: List of OpenAI-format messages that may contain tool_calls and tool roles.

    Returns:
        List of messages with tool content converted to text format.
    """
    result: List[Dict[str, Any]] = []
    for msg in messages:
        role = msg.get("role", "")
        content = msg.get("content")
        tool_calls = msg.get("tool_calls")

        if role == "assistant" and tool_calls:
            # Render the assistant's tool calls back into <tool_call> text blocks.
            parts = []
            if content:
                parts.append(content if isinstance(content, str) else str(content))
            for tc in tool_calls:
                func = tc.get("function", {})
                tc_name = func.get("name", "")
                tc_args = func.get("arguments", "{}")
                # Arguments may arrive as a dict; serialize so the block stays valid JSON.
                if not isinstance(tc_args, str):
                    tc_args = json.dumps(tc_args, ensure_ascii=False)
                parts.append(f'<tool_call>{{"name":"{tc_name}","arguments":{tc_args}}}</tool_call>')
            result.append({
                "role": "assistant",
                "content": "\n".join(parts),
            })

        elif role == "tool":
            # Tool results become plain user messages carrying name + call id.
            tool_name = msg.get("name") or "unknown"
            call_id = msg.get("tool_call_id") or ""
            content_str = content if isinstance(content, str) else json.dumps(content, ensure_ascii=False) if content else ""
            result.append({
                "role": "user",
                "content": f"tool ({tool_name}, {call_id}): {content_str}",
            })

        else:
            result.append(msg)

    return result
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
# Names exported via ``from ... import *``.
__all__ = [
    "build_tool_prompt",
    "parse_tool_calls",
    "format_tool_history",
    "parse_tool_call_block",
]
|