ricebug committed on
Commit
6e18b6a
·
verified ·
1 Parent(s): 3428636

Upload 126 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. Dockerfile +76 -0
  2. app/.DS_Store +0 -0
  3. app/api/pages/__init__.py +13 -0
  4. app/api/pages/admin.py +32 -0
  5. app/api/pages/public.py +51 -0
  6. app/api/v1/admin_api/__init__.py +15 -0
  7. app/api/v1/admin_api/cache.py +445 -0
  8. app/api/v1/admin_api/config.py +53 -0
  9. app/api/v1/admin_api/token.py +395 -0
  10. app/api/v1/chat.py +862 -0
  11. app/api/v1/files.py +69 -0
  12. app/api/v1/image.py +452 -0
  13. app/api/v1/models.py +28 -0
  14. app/api/v1/public_api/__init__.py +18 -0
  15. app/api/v1/public_api/imagine.py +505 -0
  16. app/api/v1/public_api/video.py +274 -0
  17. app/api/v1/public_api/voice.py +80 -0
  18. app/api/v1/response.py +81 -0
  19. app/api/v1/video.py +3 -0
  20. app/core/auth.py +198 -0
  21. app/core/batch.py +233 -0
  22. app/core/config.py +326 -0
  23. app/core/exceptions.py +232 -0
  24. app/core/logger.py +151 -0
  25. app/core/response_middleware.py +85 -0
  26. app/core/storage.py +1478 -0
  27. app/services/cf_refresh/README.md +49 -0
  28. app/services/cf_refresh/__init__.py +5 -0
  29. app/services/cf_refresh/config.py +41 -0
  30. app/services/cf_refresh/scheduler.py +98 -0
  31. app/services/cf_refresh/solver.py +122 -0
  32. app/services/grok/batch_services/assets.py +234 -0
  33. app/services/grok/batch_services/nsfw.py +112 -0
  34. app/services/grok/batch_services/usage.py +89 -0
  35. app/services/grok/defaults.py +34 -0
  36. app/services/grok/services/chat.py +1115 -0
  37. app/services/grok/services/image.py +794 -0
  38. app/services/grok/services/image_edit.py +567 -0
  39. app/services/grok/services/model.py +270 -0
  40. app/services/grok/services/responses.py +824 -0
  41. app/services/grok/services/video.py +688 -0
  42. app/services/grok/services/voice.py +31 -0
  43. app/services/grok/utils/cache.py +110 -0
  44. app/services/grok/utils/download.py +298 -0
  45. app/services/grok/utils/locks.py +86 -0
  46. app/services/grok/utils/process.py +152 -0
  47. app/services/grok/utils/response.py +144 -0
  48. app/services/grok/utils/retry.py +66 -0
  49. app/services/grok/utils/stream.py +46 -0
  50. app/services/grok/utils/tool_call.py +319 -0
Dockerfile ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# ---- Build stage: resolve and install dependencies with uv ----
FROM python:3.13-alpine AS builder

ENV PYTHONDONTWRITEBYTECODE=1 \
    PYTHONUNBUFFERED=1 \
    TZ=Asia/Shanghai \
    # Install uv-managed packages into a dedicated virtualenv
    UV_PROJECT_ENVIRONMENT=/opt/venv

# Make the venv's bin directory take precedence
ENV PATH="$UV_PROJECT_ENVIRONMENT/bin:$PATH"

RUN apk add --no-cache \
    tzdata \
    ca-certificates \
    build-base \
    linux-headers \
    libffi-dev \
    openssl-dev \
    curl-dev \
    cargo \
    rust

WORKDIR /app

# Install uv
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/

COPY pyproject.toml uv.lock ./

# Install locked deps, then trim caches/tests and strip native extensions.
# The strip step is best-effort (some .so files may not be strippable),
# so it is explicitly grouped with "|| true" to keep the layer succeeding.
RUN uv sync --frozen --no-dev --no-install-project \
    && find /opt/venv -type d -name "__pycache__" -prune -exec rm -rf {} + \
    && find /opt/venv -type f -name "*.pyc" -delete \
    && find /opt/venv -type d -name "tests" -prune -exec rm -rf {} + \
    && find /opt/venv -type d -name "test" -prune -exec rm -rf {} + \
    && find /opt/venv -type d -name "testing" -prune -exec rm -rf {} + \
    && { find /opt/venv -type f -name "*.so" -exec strip --strip-unneeded {} + || true; } \
    && rm -rf /root/.cache /tmp/uv-cache

# ---- Runtime stage: minimal image with only runtime libraries ----
FROM python:3.13-alpine

ENV PYTHONDONTWRITEBYTECODE=1 \
    PYTHONUNBUFFERED=1 \
    TZ=Asia/Shanghai \
    VIRTUAL_ENV=/opt/venv

ENV PATH="$VIRTUAL_ENV/bin:$PATH"

RUN apk add --no-cache \
    tzdata \
    ca-certificates \
    libffi \
    openssl \
    libgcc \
    libstdc++ \
    libcurl

WORKDIR /app

COPY --from=builder /opt/venv /opt/venv

COPY config.defaults.toml ./
COPY app ./app
COPY main.py ./
COPY scripts ./scripts

# Single layer: create runtime dirs and mark both scripts executable
# (the original repeated the entrypoint chmod across three RUN layers).
RUN mkdir -p /app/data /app/logs \
    && chmod +x /app/scripts/entrypoint.sh /app/scripts/init_storage.sh

EXPOSE 7860

ENTRYPOINT ["/app/scripts/entrypoint.sh"]

CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
app/.DS_Store ADDED
Binary file (6.15 kB). View file
 
app/api/pages/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """UI pages router."""
2
+
3
+ from fastapi import APIRouter
4
+
5
+ from app.api.pages.admin import router as admin_router
6
+ from app.api.pages.public import router as public_router
7
+
8
+ router = APIRouter()
9
+
10
+ router.include_router(public_router)
11
+ router.include_router(admin_router)
12
+
13
+ __all__ = ["router"]
app/api/pages/admin.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Admin UI page routes: static HTML shells behind /admin/*."""

from pathlib import Path

from fastapi import APIRouter
from fastapi.responses import FileResponse, RedirectResponse

router = APIRouter()
# app/static — two directory levels up from app/api/pages/.
STATIC_DIR = Path(__file__).resolve().parents[2] / "static"


def _admin_page(name: str) -> FileResponse:
    """Serve a static admin page by its bare file name."""
    return FileResponse(STATIC_DIR / f"admin/pages/{name}.html")


@router.get("/admin", include_in_schema=False)
async def admin_root():
    """Admin entry point: always bounce to the login page."""
    return RedirectResponse(url="/admin/login")


@router.get("/admin/login", include_in_schema=False)
async def admin_login():
    return _admin_page("login")


@router.get("/admin/config", include_in_schema=False)
async def admin_config():
    return _admin_page("config")


@router.get("/admin/cache", include_in_schema=False)
async def admin_cache():
    return _admin_page("cache")


@router.get("/admin/token", include_in_schema=False)
async def admin_token():
    return _admin_page("token")
app/api/pages/public.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Public UI page routes.

Every public page is available only while public access is enabled;
otherwise the routes behave as if they do not exist (404). The enabled
check was previously copy-pasted into each handler; it is factored into
a single helper here.
"""

from pathlib import Path

from fastapi import APIRouter, HTTPException
from fastapi.responses import FileResponse, RedirectResponse

from app.core.auth import is_public_enabled

router = APIRouter()
# app/static — two directory levels up from app/api/pages/.
STATIC_DIR = Path(__file__).resolve().parents[2] / "static"


def _public_page(name: str) -> FileResponse:
    """Serve a public page by bare name, or 404 when public access is off."""
    if not is_public_enabled():
        raise HTTPException(status_code=404, detail="Not Found")
    return FileResponse(STATIC_DIR / f"public/pages/{name}.html")


@router.get("/", include_in_schema=False)
async def root():
    """Send visitors to the public login when enabled, else admin login."""
    if is_public_enabled():
        return RedirectResponse(url="/login")
    return RedirectResponse(url="/admin/login")


@router.get("/login", include_in_schema=False)
async def public_login():
    return _public_page("login")


@router.get("/imagine", include_in_schema=False)
async def public_imagine():
    return _public_page("imagine")


@router.get("/voice", include_in_schema=False)
async def public_voice():
    return _public_page("voice")


@router.get("/video", include_in_schema=False)
async def public_video():
    return _public_page("video")


@router.get("/chat", include_in_schema=False)
async def public_chat():
    return _public_page("chat")
app/api/v1/admin_api/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Admin API router (app_key protected)."""
2
+
3
+ from fastapi import APIRouter
4
+
5
+ from app.api.v1.admin_api.cache import router as cache_router
6
+ from app.api.v1.admin_api.config import router as config_router
7
+ from app.api.v1.admin_api.token import router as tokens_router
8
+
9
+ router = APIRouter()
10
+
11
+ router.include_router(config_router)
12
+ router.include_router(tokens_router)
13
+ router.include_router(cache_router)
14
+
15
+ __all__ = ["router"]
app/api/v1/admin_api/cache.py ADDED
@@ -0,0 +1,445 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ from fastapi import APIRouter, Depends, HTTPException, Query, Request
4
+
5
+ from app.core.auth import verify_app_key
6
+ from app.core.batch import create_task, expire_task
7
+ from app.services.grok.batch_services.assets import ListService, DeleteService
8
+ from app.services.token.manager import get_token_manager
9
+ router = APIRouter()
10
+
11
+
12
+ @router.get("/cache", dependencies=[Depends(verify_app_key)])
13
+ async def cache_stats(request: Request):
14
+ """获取缓存统计"""
15
+ from app.services.grok.utils.cache import CacheService
16
+
17
+ try:
18
+ cache_service = CacheService()
19
+ image_stats = cache_service.get_stats("image")
20
+ video_stats = cache_service.get_stats("video")
21
+
22
+ mgr = await get_token_manager()
23
+ pools = mgr.pools
24
+ accounts = []
25
+ for pool_name, pool in pools.items():
26
+ for info in pool.list():
27
+ raw_token = (
28
+ info.token[4:] if info.token.startswith("sso=") else info.token
29
+ )
30
+ masked = (
31
+ f"{raw_token[:8]}...{raw_token[-16:]}"
32
+ if len(raw_token) > 24
33
+ else raw_token
34
+ )
35
+ accounts.append(
36
+ {
37
+ "token": raw_token,
38
+ "token_masked": masked,
39
+ "pool": pool_name,
40
+ "status": info.status,
41
+ "last_asset_clear_at": info.last_asset_clear_at,
42
+ }
43
+ )
44
+
45
+ scope = request.query_params.get("scope")
46
+ selected_token = request.query_params.get("token")
47
+ tokens_param = request.query_params.get("tokens")
48
+ selected_tokens = []
49
+ if tokens_param:
50
+ selected_tokens = [t.strip() for t in tokens_param.split(",") if t.strip()]
51
+
52
+ online_stats = {
53
+ "count": 0,
54
+ "status": "unknown",
55
+ "token": None,
56
+ "last_asset_clear_at": None,
57
+ }
58
+ online_details = []
59
+ account_map = {a["token"]: a for a in accounts}
60
+ if selected_tokens:
61
+ total = 0
62
+ raw_results = await ListService.fetch_assets_details(
63
+ selected_tokens,
64
+ account_map,
65
+ )
66
+ for token, res in raw_results.items():
67
+ if res.get("ok"):
68
+ data = res.get("data", {})
69
+ detail = data.get("detail")
70
+ total += data.get("count", 0)
71
+ else:
72
+ account = account_map.get(token)
73
+ detail = {
74
+ "token": token,
75
+ "token_masked": account["token_masked"] if account else token,
76
+ "count": 0,
77
+ "status": f"error: {res.get('error')}",
78
+ "last_asset_clear_at": account["last_asset_clear_at"]
79
+ if account
80
+ else None,
81
+ }
82
+ if detail:
83
+ online_details.append(detail)
84
+ online_stats = {
85
+ "count": total,
86
+ "status": "ok" if selected_tokens else "no_token",
87
+ "token": None,
88
+ "last_asset_clear_at": None,
89
+ }
90
+ scope = "selected"
91
+ elif scope == "all":
92
+ total = 0
93
+ tokens = list(dict.fromkeys([account["token"] for account in accounts]))
94
+ raw_results = await ListService.fetch_assets_details(
95
+ tokens,
96
+ account_map,
97
+ )
98
+ for token, res in raw_results.items():
99
+ if res.get("ok"):
100
+ data = res.get("data", {})
101
+ detail = data.get("detail")
102
+ total += data.get("count", 0)
103
+ else:
104
+ account = account_map.get(token)
105
+ detail = {
106
+ "token": token,
107
+ "token_masked": account["token_masked"] if account else token,
108
+ "count": 0,
109
+ "status": f"error: {res.get('error')}",
110
+ "last_asset_clear_at": account["last_asset_clear_at"]
111
+ if account
112
+ else None,
113
+ }
114
+ if detail:
115
+ online_details.append(detail)
116
+ online_stats = {
117
+ "count": total,
118
+ "status": "ok" if accounts else "no_token",
119
+ "token": None,
120
+ "last_asset_clear_at": None,
121
+ }
122
+ else:
123
+ token = selected_token
124
+ if token:
125
+ raw_results = await ListService.fetch_assets_details(
126
+ [token],
127
+ account_map,
128
+ )
129
+ res = raw_results.get(token, {})
130
+ data = res.get("data", {})
131
+ detail = data.get("detail") if res.get("ok") else None
132
+ if detail:
133
+ online_stats = {
134
+ "count": data.get("count", 0),
135
+ "status": detail.get("status", "ok"),
136
+ "token": detail.get("token"),
137
+ "token_masked": detail.get("token_masked"),
138
+ "last_asset_clear_at": detail.get("last_asset_clear_at"),
139
+ }
140
+ else:
141
+ match = next((a for a in accounts if a["token"] == token), None)
142
+ online_stats = {
143
+ "count": 0,
144
+ "status": f"error: {res.get('error')}",
145
+ "token": token,
146
+ "token_masked": match["token_masked"] if match else token,
147
+ "last_asset_clear_at": match["last_asset_clear_at"]
148
+ if match
149
+ else None,
150
+ }
151
+ else:
152
+ online_stats = {
153
+ "count": 0,
154
+ "status": "not_loaded",
155
+ "token": None,
156
+ "last_asset_clear_at": None,
157
+ }
158
+
159
+ response = {
160
+ "local_image": image_stats,
161
+ "local_video": video_stats,
162
+ "online": online_stats,
163
+ "online_accounts": accounts,
164
+ "online_scope": scope or "none",
165
+ "online_details": online_details,
166
+ }
167
+ return response
168
+ except Exception as e:
169
+ raise HTTPException(status_code=500, detail=str(e))
170
+
171
+
172
+ @router.get("/cache/list", dependencies=[Depends(verify_app_key)])
173
+ async def list_local(
174
+ cache_type: str = "image",
175
+ type_: str = Query(default=None, alias="type"),
176
+ page: int = 1,
177
+ page_size: int = 1000,
178
+ ):
179
+ """列出本地缓存文件"""
180
+ from app.services.grok.utils.cache import CacheService
181
+
182
+ try:
183
+ if type_:
184
+ cache_type = type_
185
+ cache_service = CacheService()
186
+ result = cache_service.list_files(cache_type, page, page_size)
187
+ return {"status": "success", **result}
188
+ except Exception as e:
189
+ raise HTTPException(status_code=500, detail=str(e))
190
+
191
+
192
+ @router.post("/cache/clear", dependencies=[Depends(verify_app_key)])
193
+ async def clear_local(data: dict):
194
+ """清理本地缓存"""
195
+ from app.services.grok.utils.cache import CacheService
196
+
197
+ cache_type = data.get("type", "image")
198
+
199
+ try:
200
+ cache_service = CacheService()
201
+ result = cache_service.clear(cache_type)
202
+ return {"status": "success", "result": result}
203
+ except Exception as e:
204
+ raise HTTPException(status_code=500, detail=str(e))
205
+
206
+
207
+ @router.post("/cache/item/delete", dependencies=[Depends(verify_app_key)])
208
+ async def delete_local_item(data: dict):
209
+ """删除单个本地缓存文件"""
210
+ from app.services.grok.utils.cache import CacheService
211
+
212
+ cache_type = data.get("type", "image")
213
+ name = data.get("name")
214
+ if not name:
215
+ raise HTTPException(status_code=400, detail="Missing file name")
216
+ try:
217
+ cache_service = CacheService()
218
+ result = cache_service.delete_file(cache_type, name)
219
+ return {"status": "success", "result": result}
220
+ except Exception as e:
221
+ raise HTTPException(status_code=500, detail=str(e))
222
+
223
+
224
+ @router.post("/cache/online/clear", dependencies=[Depends(verify_app_key)])
225
+ async def clear_online(data: dict):
226
+ """清理在线缓存"""
227
+ try:
228
+ mgr = await get_token_manager()
229
+ tokens = data.get("tokens")
230
+
231
+ if isinstance(tokens, list):
232
+ token_list = [t.strip() for t in tokens if isinstance(t, str) and t.strip()]
233
+ if not token_list:
234
+ raise HTTPException(status_code=400, detail="No tokens provided")
235
+
236
+ token_list = list(dict.fromkeys(token_list))
237
+
238
+ results = {}
239
+ raw_results = await DeleteService.clear_assets(
240
+ token_list,
241
+ mgr,
242
+ )
243
+ for token, res in raw_results.items():
244
+ if res.get("ok"):
245
+ results[token] = res.get("data", {})
246
+ else:
247
+ results[token] = {"status": "error", "error": res.get("error")}
248
+
249
+ return {"status": "success", "results": results}
250
+
251
+ token = data.get("token") or mgr.get_token()
252
+ if not token:
253
+ raise HTTPException(
254
+ status_code=400, detail="No available token to perform cleanup"
255
+ )
256
+
257
+ raw_results = await DeleteService.clear_assets(
258
+ [token],
259
+ mgr,
260
+ )
261
+ res = raw_results.get(token, {})
262
+ data = res.get("data", {})
263
+ if res.get("ok") and data.get("status") == "success":
264
+ return {"status": "success", "result": data.get("result")}
265
+ return {"status": "error", "error": data.get("error") or res.get("error")}
266
+ except Exception as e:
267
+ raise HTTPException(status_code=500, detail=str(e))
268
+
269
+
270
+ @router.post("/cache/online/clear/async", dependencies=[Depends(verify_app_key)])
271
+ async def clear_online_async(data: dict):
272
+ """清理在线缓存(异步批量 + SSE 进度)"""
273
+ mgr = await get_token_manager()
274
+ tokens = data.get("tokens")
275
+ if not isinstance(tokens, list):
276
+ raise HTTPException(status_code=400, detail="No tokens provided")
277
+
278
+ token_list = [t.strip() for t in tokens if isinstance(t, str) and t.strip()]
279
+ if not token_list:
280
+ raise HTTPException(status_code=400, detail="No tokens provided")
281
+
282
+ task = create_task(len(token_list))
283
+
284
+ async def _run():
285
+ try:
286
+ async def _on_item(item: str, res: dict):
287
+ ok = bool(res.get("data", {}).get("ok"))
288
+ task.record(ok)
289
+
290
+ raw_results = await DeleteService.clear_assets(
291
+ token_list,
292
+ mgr,
293
+ include_ok=True,
294
+ on_item=_on_item,
295
+ should_cancel=lambda: task.cancelled,
296
+ )
297
+
298
+ if task.cancelled:
299
+ task.finish_cancelled()
300
+ return
301
+
302
+ results = {}
303
+ ok_count = 0
304
+ fail_count = 0
305
+ for token, res in raw_results.items():
306
+ data = res.get("data", {})
307
+ if data.get("ok"):
308
+ ok_count += 1
309
+ results[token] = {"status": "success", "result": data.get("result")}
310
+ else:
311
+ fail_count += 1
312
+ results[token] = {"status": "error", "error": data.get("error")}
313
+
314
+ result = {
315
+ "status": "success",
316
+ "summary": {
317
+ "total": len(token_list),
318
+ "ok": ok_count,
319
+ "fail": fail_count,
320
+ },
321
+ "results": results,
322
+ }
323
+ task.finish(result)
324
+ except Exception as e:
325
+ task.fail_task(str(e))
326
+ finally:
327
+ import asyncio
328
+ asyncio.create_task(expire_task(task.id, 300))
329
+
330
+ import asyncio
331
+ asyncio.create_task(_run())
332
+
333
+ return {
334
+ "status": "success",
335
+ "task_id": task.id,
336
+ "total": len(token_list),
337
+ }
338
+
339
+
340
+ @router.post("/cache/online/load/async", dependencies=[Depends(verify_app_key)])
341
+ async def load_cache_async(data: dict):
342
+ """在线资产统计(异步批量 + SSE 进度)"""
343
+ from app.services.grok.utils.cache import CacheService
344
+
345
+ mgr = await get_token_manager()
346
+
347
+ accounts = []
348
+ for pool_name, pool in mgr.pools.items():
349
+ for info in pool.list():
350
+ raw_token = info.token[4:] if info.token.startswith("sso=") else info.token
351
+ masked = (
352
+ f"{raw_token[:8]}...{raw_token[-16:]}"
353
+ if len(raw_token) > 24
354
+ else raw_token
355
+ )
356
+ accounts.append(
357
+ {
358
+ "token": raw_token,
359
+ "token_masked": masked,
360
+ "pool": pool_name,
361
+ "status": info.status,
362
+ "last_asset_clear_at": info.last_asset_clear_at,
363
+ }
364
+ )
365
+
366
+ account_map = {a["token"]: a for a in accounts}
367
+
368
+ tokens = data.get("tokens")
369
+ scope = data.get("scope")
370
+ selected_tokens: List[str] = []
371
+ if isinstance(tokens, list):
372
+ selected_tokens = [str(t).strip() for t in tokens if str(t).strip()]
373
+
374
+ if not selected_tokens and scope == "all":
375
+ selected_tokens = [account["token"] for account in accounts]
376
+ scope = "all"
377
+ elif selected_tokens:
378
+ scope = "selected"
379
+ else:
380
+ raise HTTPException(status_code=400, detail="No tokens provided")
381
+
382
+ task = create_task(len(selected_tokens))
383
+
384
+ async def _run():
385
+ try:
386
+ cache_service = CacheService()
387
+ image_stats = cache_service.get_stats("image")
388
+ video_stats = cache_service.get_stats("video")
389
+
390
+ async def _on_item(item: str, res: dict):
391
+ ok = bool(res.get("data", {}).get("ok"))
392
+ task.record(ok)
393
+
394
+ raw_results = await ListService.fetch_assets_details(
395
+ selected_tokens,
396
+ account_map,
397
+ include_ok=True,
398
+ on_item=_on_item,
399
+ should_cancel=lambda: task.cancelled,
400
+ )
401
+
402
+ if task.cancelled:
403
+ task.finish_cancelled()
404
+ return
405
+
406
+ online_details = []
407
+ total = 0
408
+ for token, res in raw_results.items():
409
+ data = res.get("data", {})
410
+ detail = data.get("detail")
411
+ if detail:
412
+ online_details.append(detail)
413
+ total += data.get("count", 0)
414
+
415
+ online_stats = {
416
+ "count": total,
417
+ "status": "ok" if selected_tokens else "no_token",
418
+ "token": None,
419
+ "last_asset_clear_at": None,
420
+ }
421
+
422
+ result = {
423
+ "local_image": image_stats,
424
+ "local_video": video_stats,
425
+ "online": online_stats,
426
+ "online_accounts": accounts,
427
+ "online_scope": scope or "none",
428
+ "online_details": online_details,
429
+ }
430
+ task.finish(result)
431
+ except Exception as e:
432
+ task.fail_task(str(e))
433
+ finally:
434
+ import asyncio
435
+ asyncio.create_task(expire_task(task.id, 300))
436
+
437
+ import asyncio
438
+ asyncio.create_task(_run())
439
+
440
+ return {
441
+ "status": "success",
442
+ "task_id": task.id,
443
+ "total": len(selected_tokens),
444
+ }
445
+
app/api/v1/admin_api/config.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from fastapi import APIRouter, Depends, HTTPException
4
+
5
+ from app.core.auth import verify_app_key
6
+ from app.core.config import config
7
+ from app.core.storage import get_storage as resolve_storage, LocalStorage, RedisStorage, SQLStorage
8
+
9
+ router = APIRouter()
10
+
11
+
12
+ @router.get("/verify", dependencies=[Depends(verify_app_key)])
13
+ async def admin_verify():
14
+ """验证后台访问密钥(app_key)"""
15
+ return {"status": "success"}
16
+
17
+
18
+ @router.get("/config", dependencies=[Depends(verify_app_key)])
19
+ async def get_config():
20
+ """获取当前配置"""
21
+ # 暴露原始配置字典
22
+ return config._config
23
+
24
+
25
+ @router.post("/config", dependencies=[Depends(verify_app_key)])
26
+ async def update_config(data: dict):
27
+ """更新配置"""
28
+ try:
29
+ await config.update(data)
30
+ return {"status": "success", "message": "配置已更新"}
31
+ except Exception as e:
32
+ raise HTTPException(status_code=500, detail=str(e))
33
+
34
+
35
+ @router.get("/storage", dependencies=[Depends(verify_app_key)])
36
+ async def get_storage_mode():
37
+ """获取当前存储模式"""
38
+ storage_type = os.getenv("SERVER_STORAGE_TYPE", "").lower()
39
+ if not storage_type:
40
+ storage = resolve_storage()
41
+ if isinstance(storage, LocalStorage):
42
+ storage_type = "local"
43
+ elif isinstance(storage, RedisStorage):
44
+ storage_type = "redis"
45
+ elif isinstance(storage, SQLStorage):
46
+ storage_type = {
47
+ "mysql": "mysql",
48
+ "mariadb": "mysql",
49
+ "postgres": "pgsql",
50
+ "postgresql": "pgsql",
51
+ "pgsql": "pgsql",
52
+ }.get(storage.dialect, storage.dialect)
53
+ return {"type": storage_type or "local"}
app/api/v1/admin_api/token.py ADDED
@@ -0,0 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+
3
+ import orjson
4
+ from fastapi import APIRouter, Depends, HTTPException, Request
5
+ from fastapi.responses import StreamingResponse
6
+
7
+ from app.core.auth import get_app_key, verify_app_key
8
+ from app.core.batch import create_task, expire_task, get_task
9
+ from app.core.logger import logger
10
+ from app.core.storage import get_storage
11
+ from app.services.grok.batch_services.usage import UsageService
12
+ from app.services.grok.batch_services.nsfw import NSFWService
13
+ from app.services.token.manager import get_token_manager
14
+
15
+ router = APIRouter()
16
+
17
+
18
+ @router.get("/tokens", dependencies=[Depends(verify_app_key)])
19
+ async def get_tokens():
20
+ """获取所有 Token"""
21
+ storage = get_storage()
22
+ tokens = await storage.load_tokens()
23
+ return tokens or {}
24
+
25
+
26
+ @router.post("/tokens", dependencies=[Depends(verify_app_key)])
27
+ async def update_tokens(data: dict):
28
+ """更新 Token 信息"""
29
+ storage = get_storage()
30
+ try:
31
+ from app.services.token.models import TokenInfo
32
+
33
+ async with storage.acquire_lock("tokens_save", timeout=10):
34
+ existing = await storage.load_tokens() or {}
35
+ normalized = {}
36
+ allowed_fields = set(TokenInfo.model_fields.keys())
37
+ existing_map = {}
38
+ for pool_name, tokens in existing.items():
39
+ if not isinstance(tokens, list):
40
+ continue
41
+ pool_map = {}
42
+ for item in tokens:
43
+ if isinstance(item, str):
44
+ token_data = {"token": item}
45
+ elif isinstance(item, dict):
46
+ token_data = dict(item)
47
+ else:
48
+ continue
49
+ raw_token = token_data.get("token")
50
+ if isinstance(raw_token, str) and raw_token.startswith("sso="):
51
+ token_data["token"] = raw_token[4:]
52
+ token_key = token_data.get("token")
53
+ if isinstance(token_key, str):
54
+ pool_map[token_key] = token_data
55
+ existing_map[pool_name] = pool_map
56
+ for pool_name, tokens in (data or {}).items():
57
+ if not isinstance(tokens, list):
58
+ continue
59
+ pool_list = []
60
+ for item in tokens:
61
+ if isinstance(item, str):
62
+ token_data = {"token": item}
63
+ elif isinstance(item, dict):
64
+ token_data = dict(item)
65
+ else:
66
+ continue
67
+
68
+ raw_token = token_data.get("token")
69
+ if isinstance(raw_token, str) and raw_token.startswith("sso="):
70
+ token_data["token"] = raw_token[4:]
71
+
72
+ base = existing_map.get(pool_name, {}).get(
73
+ token_data.get("token"), {}
74
+ )
75
+ merged = dict(base)
76
+ merged.update(token_data)
77
+ if merged.get("tags") is None:
78
+ merged["tags"] = []
79
+
80
+ filtered = {k: v for k, v in merged.items() if k in allowed_fields}
81
+ try:
82
+ info = TokenInfo(**filtered)
83
+ pool_list.append(info.model_dump())
84
+ except Exception as e:
85
+ logger.warning(f"Skip invalid token in pool '{pool_name}': {e}")
86
+ continue
87
+ normalized[pool_name] = pool_list
88
+
89
+ await storage.save_tokens(normalized)
90
+ mgr = await get_token_manager()
91
+ await mgr.reload()
92
+ return {"status": "success", "message": "Token 已更新"}
93
+ except Exception as e:
94
+ raise HTTPException(status_code=500, detail=str(e))
95
+
96
+
97
+ @router.post("/tokens/refresh", dependencies=[Depends(verify_app_key)])
98
+ async def refresh_tokens(data: dict):
99
+ """刷新 Token 状态"""
100
+ try:
101
+ mgr = await get_token_manager()
102
+ tokens = []
103
+ if isinstance(data.get("token"), str) and data["token"].strip():
104
+ tokens.append(data["token"].strip())
105
+ if isinstance(data.get("tokens"), list):
106
+ tokens.extend([str(t).strip() for t in data["tokens"] if str(t).strip()])
107
+
108
+ if not tokens:
109
+ raise HTTPException(status_code=400, detail="No tokens provided")
110
+
111
+ unique_tokens = list(dict.fromkeys(tokens))
112
+
113
+ raw_results = await UsageService.batch(
114
+ unique_tokens,
115
+ mgr,
116
+ )
117
+
118
+ results = {}
119
+ for token, res in raw_results.items():
120
+ if res.get("ok"):
121
+ results[token] = res.get("data", False)
122
+ else:
123
+ results[token] = False
124
+
125
+ response = {"status": "success", "results": results}
126
+ return response
127
+ except Exception as e:
128
+ raise HTTPException(status_code=500, detail=str(e))
129
+
130
+
131
+ @router.post("/tokens/refresh/async", dependencies=[Depends(verify_app_key)])
132
+ async def refresh_tokens_async(data: dict):
133
+ """刷新 Token 状态(异步批量 + SSE 进度)"""
134
+ mgr = await get_token_manager()
135
+ tokens = []
136
+ if isinstance(data.get("token"), str) and data["token"].strip():
137
+ tokens.append(data["token"].strip())
138
+ if isinstance(data.get("tokens"), list):
139
+ tokens.extend([str(t).strip() for t in data["tokens"] if str(t).strip()])
140
+
141
+ if not tokens:
142
+ raise HTTPException(status_code=400, detail="No tokens provided")
143
+
144
+ unique_tokens = list(dict.fromkeys(tokens))
145
+
146
+ task = create_task(len(unique_tokens))
147
+
148
+ async def _run():
149
+ try:
150
+
151
+ async def _on_item(item: str, res: dict):
152
+ task.record(bool(res.get("ok")))
153
+
154
+ raw_results = await UsageService.batch(
155
+ unique_tokens,
156
+ mgr,
157
+ on_item=_on_item,
158
+ should_cancel=lambda: task.cancelled,
159
+ )
160
+
161
+ if task.cancelled:
162
+ task.finish_cancelled()
163
+ return
164
+
165
+ results: dict[str, bool] = {}
166
+ ok_count = 0
167
+ fail_count = 0
168
+ for token, res in raw_results.items():
169
+ if res.get("ok") and res.get("data") is True:
170
+ ok_count += 1
171
+ results[token] = True
172
+ else:
173
+ fail_count += 1
174
+ results[token] = False
175
+
176
+ await mgr._save(force=True)
177
+
178
+ result = {
179
+ "status": "success",
180
+ "summary": {
181
+ "total": len(unique_tokens),
182
+ "ok": ok_count,
183
+ "fail": fail_count,
184
+ },
185
+ "results": results,
186
+ }
187
+ task.finish(result)
188
+ except Exception as e:
189
+ task.fail_task(str(e))
190
+ finally:
191
+ import asyncio
192
+ asyncio.create_task(expire_task(task.id, 300))
193
+
194
+ import asyncio
195
+ asyncio.create_task(_run())
196
+
197
+ return {
198
+ "status": "success",
199
+ "task_id": task.id,
200
+ "total": len(unique_tokens),
201
+ }
202
+
203
+
204
@router.get("/batch/{task_id}/stream")
async def batch_stream(task_id: str, request: Request):
    """SSE progress stream for a batch task.

    Authentication uses an ``app_key`` query parameter because EventSource
    clients cannot set custom headers. Emits a snapshot first, then live
    events; a ``: ping`` comment is sent every 15s as keep-alive.
    """
    app_key = get_app_key()
    if app_key:
        import hmac

        supplied = request.query_params.get("app_key") or ""
        # Constant-time comparison so the key cannot be probed via timing.
        if not hmac.compare_digest(supplied, app_key):
            raise HTTPException(status_code=401, detail="Invalid authentication token")
    task = get_task(task_id)
    if not task:
        raise HTTPException(status_code=404, detail="Task not found")

    async def event_stream():
        queue = task.attach()
        try:
            # Initial snapshot so late subscribers see the current progress.
            yield f"data: {orjson.dumps({'type': 'snapshot', **task.snapshot()}).decode()}\n\n"

            # Task may already be finished; deliver the final event and stop.
            final = task.final_event()
            if final:
                yield f"data: {orjson.dumps(final).decode()}\n\n"
                return

            while True:
                try:
                    event = await asyncio.wait_for(queue.get(), timeout=15)
                except asyncio.TimeoutError:
                    # Keep-alive comment; also re-check for a missed final event.
                    yield ": ping\n\n"
                    final = task.final_event()
                    if final:
                        yield f"data: {orjson.dumps(final).decode()}\n\n"
                        return
                    continue

                yield f"data: {orjson.dumps(event).decode()}\n\n"
                if event.get("type") in ("done", "error", "cancelled"):
                    return
        finally:
            # Always detach the queue so the task does not leak subscribers.
            task.detach(queue)

    return StreamingResponse(event_stream(), media_type="text/event-stream")
243
+
244
+
245
@router.post("/batch/{task_id}/cancel", dependencies=[Depends(verify_app_key)])
async def batch_cancel(task_id: str):
    """Mark a batch task as cancelled; workers observe the flag cooperatively."""
    task = get_task(task_id)
    if not task:
        raise HTTPException(status_code=404, detail="Task not found")
    task.cancel()
    return {"status": "success"}
252
+
253
+
254
@router.post("/tokens/nsfw/enable", dependencies=[Depends(verify_app_key)])
async def enable_nsfw(data: dict):
    """Batch-enable NSFW (Unhinged) mode.

    Accepts a single ``token`` and/or a ``tokens`` list in the body; when
    neither is provided, every token currently held in the manager's pools
    is used. Returns a summary plus per-token (masked) results.
    """
    try:
        mgr = await get_token_manager()

        tokens = []
        if isinstance(data.get("token"), str) and data["token"].strip():
            tokens.append(data["token"].strip())
        if isinstance(data.get("tokens"), list):
            tokens.extend(str(t).strip() for t in data["tokens"] if str(t).strip())

        if not tokens:
            # Fall back to every pooled token, stripping the cookie prefix.
            for pool in mgr.pools.values():
                tokens.extend(info.token.removeprefix("sso=") for info in pool.list())

        if not tokens:
            raise HTTPException(status_code=400, detail="No tokens available")

        # De-duplicate while preserving order.
        unique_tokens = list(dict.fromkeys(tokens))

        raw_results = await NSFWService.batch(unique_tokens, mgr)

        results = {}
        ok_count = 0
        fail_count = 0

        for token, res in raw_results.items():
            # Mask long tokens so secrets never appear verbatim in responses.
            masked = f"{token[:8]}...{token[-8:]}" if len(token) > 20 else token
            if res.get("ok") and res.get("data", {}).get("success"):
                ok_count += 1
                results[masked] = res.get("data", {})
            else:
                fail_count += 1
                results[masked] = res.get("data") or {"error": res.get("error")}

        return {
            "status": "success",
            "summary": {
                "total": len(unique_tokens),
                "ok": ok_count,
                "fail": fail_count,
            },
            "results": results,
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Enable NSFW failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
314
+
315
+
316
@router.post("/tokens/nsfw/enable/async", dependencies=[Depends(verify_app_key)])
async def enable_nsfw_async(data: dict):
    """Batch-enable NSFW (Unhinged) mode asynchronously (progress via SSE).

    Creates a background task and returns its id immediately; progress and
    the final summary are then available from ``/batch/{task_id}/stream``.
    """
    mgr = await get_token_manager()

    tokens = []
    if isinstance(data.get("token"), str) and data["token"].strip():
        tokens.append(data["token"].strip())
    if isinstance(data.get("tokens"), list):
        tokens.extend(str(t).strip() for t in data["tokens"] if str(t).strip())

    if not tokens:
        # Fall back to every pooled token, stripping the cookie prefix.
        for pool in mgr.pools.values():
            tokens.extend(info.token.removeprefix("sso=") for info in pool.list())

    if not tokens:
        raise HTTPException(status_code=400, detail="No tokens available")

    # De-duplicate while preserving order.
    unique_tokens = list(dict.fromkeys(tokens))

    task = create_task(len(unique_tokens))

    async def _run():
        try:

            async def _on_item(item: str, res: dict):
                # Per-item progress: only a successful enable counts as ok.
                task.record(bool(res.get("ok") and res.get("data", {}).get("success")))

            raw_results = await NSFWService.batch(
                unique_tokens,
                mgr,
                on_item=_on_item,
                should_cancel=lambda: task.cancelled,
            )

            if task.cancelled:
                task.finish_cancelled()
                return

            results = {}
            ok_count = 0
            fail_count = 0
            for token, res in raw_results.items():
                # Mask long tokens so secrets never appear verbatim in results.
                masked = f"{token[:8]}...{token[-8:]}" if len(token) > 20 else token
                if res.get("ok") and res.get("data", {}).get("success"):
                    ok_count += 1
                    results[masked] = res.get("data", {})
                else:
                    fail_count += 1
                    results[masked] = res.get("data") or {"error": res.get("error")}

            # Persist any token-state changes made during the batch.
            await mgr._save(force=True)

            task.finish(
                {
                    "status": "success",
                    "summary": {
                        "total": len(unique_tokens),
                        "ok": ok_count,
                        "fail": fail_count,
                    },
                    "results": results,
                }
            )
        except Exception as e:
            task.fail_task(str(e))
        finally:
            # asyncio is imported at module level (used bare in batch_stream);
            # schedule expiry so finished task records do not leak.
            asyncio.create_task(expire_task(task.id, 300))

    asyncio.create_task(_run())

    return {
        "status": "success",
        "task_id": task.id,
        "total": len(unique_tokens),
    }
app/api/v1/chat.py ADDED
@@ -0,0 +1,862 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Chat Completions API 路由
3
+ """
4
+
5
+ from typing import Any, AsyncGenerator, AsyncIterable, Dict, List, Optional, Union
6
+ import base64
7
+ import binascii
8
+ import time
9
+ import uuid
10
+
11
+ from fastapi import APIRouter
12
+ from fastapi.responses import StreamingResponse, JSONResponse
13
+ from pydantic import BaseModel, Field
14
+ import orjson
15
+
16
+ from app.services.grok.services.chat import ChatService
17
+ from app.services.grok.services.image import ImageGenerationService
18
+ from app.services.grok.services.image_edit import ImageEditService
19
+ from app.services.grok.services.model import ModelService
20
+ from app.services.grok.services.video import VideoService
21
+ from app.services.grok.utils.response import make_chat_response
22
+ from app.services.token import get_token_manager
23
+ from app.core.config import get_config
24
+ from app.core.exceptions import ValidationException, AppException, ErrorType
25
+
26
+
27
class MessageItem(BaseModel):
    """A single chat message (OpenAI chat format)."""

    # Sender role; validate_request restricts this to VALID_ROLES.
    role: str
    # Plain text, a single content block, or a list of content blocks; may be
    # None for assistant/tool messages (tool-call interim states).
    content: Optional[Union[str, Dict[str, Any], List[Dict[str, Any]]]]
    # Tool invocations attached to an assistant message.
    tool_calls: Optional[List[Dict[str, Any]]] = None
    # Required when role == "tool" (checked in validate_request).
    tool_call_id: Optional[str] = None
    name: Optional[str] = None
35
+
36
+
37
class VideoConfig(BaseModel):
    """Video generation options (accepted values enforced in validate_request)."""

    aspect_ratio: Optional[str] = Field("3:2", description="视频比例: 1280x720(16:9), 720x1280(9:16), 1792x1024(3:2), 1024x1792(2:3), 1024x1024(1:1)")
    video_length: Optional[int] = Field(6, description="视频时长(秒): 6 / 10 / 15")
    resolution_name: Optional[str] = Field("480p", description="视频分辨率: 480p, 720p")
    preset: Optional[str] = Field("custom", description="风格预设: fun, normal, spicy")
44
+
45
+
46
class ImageConfig(BaseModel):
    """Image generation options (accepted values enforced in _validate_image_config)."""

    n: Optional[int] = Field(1, ge=1, le=10, description="生成数量 (1-10)")
    size: Optional[str] = Field("1024x1024", description="图片尺寸")
    response_format: Optional[str] = Field(None, description="响应格式")
52
+
53
+
54
class ChatCompletionRequest(BaseModel):
    """Chat Completions request body (OpenAI-compatible, plus media extensions)."""

    model: str = Field(..., description="模型名称")
    messages: List[MessageItem] = Field(..., description="消息数组")
    stream: Optional[bool] = Field(None, description="是否流式输出")
    reasoning_effort: Optional[str] = Field(None, description="推理强度: none/minimal/low/medium/high/xhigh")
    temperature: Optional[float] = Field(0.8, description="采样温度: 0-2")
    top_p: Optional[float] = Field(0.95, description="nucleus 采样: 0-1")
    # Video generation config (video models only).
    video_config: Optional[VideoConfig] = Field(None, description="视频生成参数")
    # Image generation config (image / image-edit models only).
    image_config: Optional[ImageConfig] = Field(None, description="图片生成参数")
    # Tool calling (OpenAI function-calling schema).
    tools: Optional[List[Dict[str, Any]]] = Field(None, description="Tool definitions")
    tool_choice: Optional[Union[str, Dict[str, Any]]] = Field(None, description="Tool choice: auto/required/none/specific")
    parallel_tool_calls: Optional[bool] = Field(True, description="Allow parallel tool calls")
71
+
72
+
73
# Message roles accepted by the OpenAI-compatible API.
VALID_ROLES = {"developer", "system", "user", "assistant", "tool"}
# Content-block types permitted inside user messages; other roles are text-only.
USER_CONTENT_TYPES = {"text", "image_url", "input_audio", "file"}
# Image sizes (width x height) accepted for generation/edit requests.
ALLOWED_IMAGE_SIZES = {
    "1280x720",
    "720x1280",
    "1792x1024",
    "1024x1792",
    "1024x1024",
}
# Model id whose image parameters come from server config, not the client.
IMAGINE_FAST_MODEL_ID = "grok-imagine-1.0-fast"
83
+
84
+
85
+ def _validate_media_input(value: str, field_name: str, param: str):
86
+ """Verify media input is a valid URL or data URI"""
87
+ if not isinstance(value, str) or not value.strip():
88
+ raise ValidationException(
89
+ message=f"{field_name} cannot be empty",
90
+ param=param,
91
+ code="empty_media",
92
+ )
93
+ value = value.strip()
94
+ if value.startswith("data:"):
95
+ return
96
+ if value.startswith("http://") or value.startswith("https://"):
97
+ return
98
+ candidate = "".join(value.split())
99
+ if len(candidate) >= 32 and len(candidate) % 4 == 0:
100
+ try:
101
+ base64.b64decode(candidate, validate=True)
102
+ raise ValidationException(
103
+ message=f"{field_name} base64 must be provided as a data URI (data:<mime>;base64,...)",
104
+ param=param,
105
+ code="invalid_media",
106
+ )
107
+ except binascii.Error:
108
+ pass
109
+ raise ValidationException(
110
+ message=f"{field_name} must be a URL or data URI",
111
+ param=param,
112
+ code="invalid_media",
113
+ )
114
+
115
+
116
def _extract_prompt_images(messages: List[MessageItem]) -> tuple[str, List[str]]:
    """Return (last non-empty text, user-supplied image URLs) from the messages."""
    prompt = ""
    images: List[str] = []

    for msg in messages:
        role = msg.role or "user"
        content = msg.content

        # Plain string content: remember the latest non-empty text.
        if isinstance(content, str):
            stripped = content.strip()
            if stripped:
                prompt = stripped
            continue

        # Normalize a single block to a one-element list; skip anything else.
        blocks = [content] if isinstance(content, dict) else content
        if not isinstance(blocks, list):
            continue

        for block in blocks:
            if not isinstance(block, dict):
                continue
            kind = block.get("type")
            if kind == "text":
                text = block.get("text", "")
                if isinstance(text, str) and text.strip():
                    prompt = text.strip()
            elif kind == "image_url" and role == "user":
                # Only user messages contribute images.
                url = (block.get("image_url") or {}).get("url", "")
                if isinstance(url, str) and url.strip():
                    images.append(url.strip())

    return prompt, images
148
+
149
+
150
+ def _resolve_image_format(value: Optional[str]) -> str:
151
+ fmt = value or get_config("app.image_format") or "url"
152
+ if isinstance(fmt, str):
153
+ fmt = fmt.lower()
154
+ if fmt == "base64":
155
+ return "b64_json"
156
+ if fmt in ("b64_json", "url"):
157
+ return fmt
158
+ raise ValidationException(
159
+ message="image_format must be one of url, base64, b64_json",
160
+ param="image_format",
161
+ code="invalid_image_format",
162
+ )
163
+
164
+
165
+ def _image_field(response_format: str) -> str:
166
+ if response_format == "url":
167
+ return "url"
168
+ return "b64_json"
169
+
170
+
171
def _imagine_fast_server_image_config() -> ImageConfig:
    """Load server-side image generation parameters for grok-imagine-1.0-fast.

    Values come from the imagine_fast.* config keys, with app.image_format
    (then "url") as the response-format fallback.
    """
    count = int(get_config("imagine_fast.n", 1) or 1)
    size = str(get_config("imagine_fast.size", "1024x1024") or "1024x1024")
    fallback_format = get_config("app.image_format") or "url"
    fmt = str(get_config("imagine_fast.response_format", fallback_format) or "url")
    return ImageConfig(n=count, size=size, response_format=fmt)
180
+
181
+
182
+ async def _safe_sse_stream(stream: AsyncIterable[str]) -> AsyncGenerator[str, None]:
183
+ """Ensure streaming endpoints return SSE error payloads instead of transport-level 5xx breaks."""
184
+ try:
185
+ async for chunk in stream:
186
+ yield chunk
187
+ except AppException as e:
188
+ payload = {
189
+ "error": {
190
+ "message": e.message,
191
+ "type": e.error_type,
192
+ "code": e.code,
193
+ }
194
+ }
195
+ yield f"event: error\ndata: {orjson.dumps(payload).decode()}\n\n"
196
+ yield "data: [DONE]\n\n"
197
+ except Exception as e:
198
+ payload = {
199
+ "error": {
200
+ "message": str(e) or "stream_error",
201
+ "type": "server_error",
202
+ "code": "stream_error",
203
+ }
204
+ }
205
+ yield f"event: error\ndata: {orjson.dumps(payload).decode()}\n\n"
206
+ yield "data: [DONE]\n\n"
207
+
208
+
209
def _streaming_error_response(exc: Exception) -> StreamingResponse:
    """Wrap an exception as a one-shot SSE error stream.

    Used when a streaming client would otherwise receive a transport-level
    5xx before the stream started.
    """
    if isinstance(exc, AppException):
        payload = {
            "error": {
                "message": exc.message,
                "type": exc.error_type,
                "code": exc.code,
            }
        }
    else:
        payload = {
            "error": {
                "message": str(exc) or "stream_error",
                "type": "server_error",
                "code": "stream_error",
            }
        }

    async def _one_shot_error():
        yield f"event: error\ndata: {orjson.dumps(payload).decode()}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(
        _one_shot_error(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
    )
236
+
237
def _validate_image_config(image_conf: ImageConfig, *, stream: bool):
    """Validate image_config fields (n, response_format, size); raise on violation."""
    count = image_conf.n or 1
    if count < 1 or count > 10:
        raise ValidationException(
            message="n must be between 1 and 10",
            param="image_config.n",
            code="invalid_n",
        )
    # Streaming image generation only supports 1 or 2 images.
    if stream and count not in (1, 2):
        raise ValidationException(
            message="Streaming is only supported when n=1 or n=2",
            param="image_config.n",
            code="invalid_stream_n",
        )
    fmt = image_conf.response_format
    if fmt and fmt not in {"b64_json", "base64", "url"}:
        raise ValidationException(
            message="response_format must be one of b64_json, base64, url",
            param="image_config.response_format",
            code="invalid_response_format",
        )
    if image_conf.size and image_conf.size not in ALLOWED_IMAGE_SIZES:
        raise ValidationException(
            message=f"size must be one of {sorted(ALLOWED_IMAGE_SIZES)}",
            param="image_config.size",
            code="invalid_size",
        )
265
def validate_request(request: ChatCompletionRequest):
    """Validate and normalize a chat completion request in place.

    Checks the model, every message, sampling parameters, tool definitions
    and the media (image / image-edit / video) configs; raises
    ValidationException on the first violation. Normalized defaults
    (temperature, top_p, image/video configs) are written back onto request.
    """
    # Model must exist and be accessible.
    if not ModelService.valid(request.model):
        raise ValidationException(
            message=f"The model `{request.model}` does not exist or you do not have access to it.",
            param="model",
            code="model_not_found",
        )

    # Validate every message.
    for idx, msg in enumerate(request.messages):
        if not isinstance(msg.role, str) or msg.role not in VALID_ROLES:
            raise ValidationException(
                message=f"role must be one of {sorted(VALID_ROLES)}",
                param=f"messages.{idx}.role",
                code="invalid_role",
            )

        # tool role: requires tool_call_id, content can be None/empty.
        if msg.role == "tool":
            if not msg.tool_call_id:
                raise ValidationException(
                    message="tool messages must have a 'tool_call_id' field",
                    param=f"messages.{idx}.tool_call_id",
                    code="missing_tool_call_id",
                )
            continue

        # assistant with tool_calls: content can be None.
        if msg.role == "assistant" and msg.tool_calls:
            continue

        content = msg.content

        # Some clients send empty assistant/tool content (tool-call interim states).
        if content is None:
            if msg.role in {"assistant", "tool"}:
                continue
            raise ValidationException(
                message="Message content cannot be null",
                param=f"messages.{idx}.content",
                code="empty_content",
            )

        # String content.
        if isinstance(content, str):
            if not content.strip():
                raise ValidationException(
                    message="Message content cannot be empty",
                    param=f"messages.{idx}.content",
                    code="empty_content",
                )

        # Single object content: must be a non-empty text block.
        elif isinstance(content, dict):
            content = [content]
            for c_idx, item in enumerate(content):
                if not isinstance(item, dict):
                    raise ValidationException(
                        message="Message content items must be objects",
                        param=f"messages.{idx}.content.{c_idx}",
                        code="invalid_content_item",
                    )
                item_type = item.get("type")
                if item_type != "text":
                    raise ValidationException(
                        message="When content is an object, type must be 'text'",
                        param=f"messages.{idx}.content.{c_idx}.type",
                        code="invalid_content_type",
                    )
                text = item.get("text", "")
                if not isinstance(text, str) or not text.strip():
                    raise ValidationException(
                        message="messages.%d.content.%d.text must be a non-empty string"
                        % (idx, c_idx),
                        param=f"messages.{idx}.content.{c_idx}.text",
                        code="empty_content",
                    )

        # List content.
        elif isinstance(content, list):
            if not content:
                raise ValidationException(
                    message="Message content cannot be an empty array",
                    param=f"messages.{idx}.content",
                    code="empty_content",
                )

            for block_idx, block in enumerate(content):
                # Must be a non-empty object.
                if not isinstance(block, dict):
                    raise ValidationException(
                        message="Content block must be an object",
                        param=f"messages.{idx}.content.{block_idx}",
                        code="invalid_block",
                    )
                if not block:
                    raise ValidationException(
                        message="Content block cannot be empty",
                        param=f"messages.{idx}.content.{block_idx}",
                        code="empty_block",
                    )

                # 'type' must be present and a non-empty string.
                if "type" not in block:
                    raise ValidationException(
                        message="Content block must have a 'type' field",
                        param=f"messages.{idx}.content.{block_idx}",
                        code="missing_type",
                    )

                block_type = block.get("type")

                if (
                    not block_type
                    or not isinstance(block_type, str)
                    or not block_type.strip()
                ):
                    raise ValidationException(
                        message="Content block 'type' cannot be empty",
                        param=f"messages.{idx}.content.{block_idx}.type",
                        code="empty_type",
                    )

                # Only user messages may carry non-text blocks.
                if msg.role == "user":
                    if block_type not in USER_CONTENT_TYPES:
                        raise ValidationException(
                            message=f"Invalid content block type: '{block_type}'",
                            param=f"messages.{idx}.content.{block_idx}.type",
                            code="invalid_type",
                        )
                else:
                    if block_type != "text":
                        raise ValidationException(
                            message=f"The `{msg.role}` role only supports 'text' type, got '{block_type}'",
                            param=f"messages.{idx}.content.{block_idx}.type",
                            code="invalid_type",
                        )

                # Per-type payload presence / non-emptiness checks.
                if block_type == "text":
                    text = block.get("text", "")
                    if not isinstance(text, str) or not text.strip():
                        raise ValidationException(
                            message="Text content cannot be empty",
                            param=f"messages.{idx}.content.{block_idx}.text",
                            code="empty_text",
                        )
                elif block_type == "image_url":
                    image_url = block.get("image_url")
                    if not image_url or not isinstance(image_url, dict):
                        raise ValidationException(
                            message="image_url must have a 'url' field",
                            param=f"messages.{idx}.content.{block_idx}.image_url",
                            code="missing_url",
                        )
                    _validate_media_input(
                        image_url.get("url", ""),
                        "image_url.url",
                        f"messages.{idx}.content.{block_idx}.image_url.url",
                    )
                elif block_type == "input_audio":
                    audio = block.get("input_audio")
                    if not audio or not isinstance(audio, dict):
                        raise ValidationException(
                            message="input_audio must have a 'data' field",
                            param=f"messages.{idx}.content.{block_idx}.input_audio",
                            code="missing_audio",
                        )
                    _validate_media_input(
                        audio.get("data", ""),
                        "input_audio.data",
                        f"messages.{idx}.content.{block_idx}.input_audio.data",
                    )
                elif block_type == "file":
                    file_data = block.get("file")
                    if not file_data or not isinstance(file_data, dict):
                        raise ValidationException(
                            message="file must have a 'file_data' field",
                            param=f"messages.{idx}.content.{block_idx}.file",
                            code="missing_file",
                        )
                    _validate_media_input(
                        file_data.get("file_data", ""),
                        "file.file_data",
                        f"messages.{idx}.content.{block_idx}.file.file_data",
                    )
        # NOTE: a former `elif content is None:` branch here was unreachable
        # (None is handled before the isinstance chain) and has been removed.
        else:
            raise ValidationException(
                message="Message content must be a string or array",
                param=f"messages.{idx}.content",
                code="invalid_content",
            )

    # Normalize stream to a bool (accepts "true"/"false"-style strings).
    if request.stream is not None:
        if isinstance(request.stream, bool):
            pass
        elif isinstance(request.stream, str):
            if request.stream.lower() in ("true", "1", "yes"):
                request.stream = True
            elif request.stream.lower() in ("false", "0", "no"):
                request.stream = False
            else:
                raise ValidationException(
                    message="stream must be a boolean",
                    param="stream",
                    code="invalid_stream",
                )
        else:
            raise ValidationException(
                message="stream must be a boolean",
                param="stream",
                code="invalid_stream",
            )

    allowed_efforts = {"none", "minimal", "low", "medium", "high", "xhigh"}
    if request.reasoning_effort is not None:
        if not isinstance(request.reasoning_effort, str) or (
            request.reasoning_effort not in allowed_efforts
        ):
            raise ValidationException(
                message=f"reasoning_effort must be one of {sorted(allowed_efforts)}",
                param="reasoning_effort",
                code="invalid_reasoning_effort",
            )

    # temperature: default 0.8, range [0, 2].
    if request.temperature is None:
        request.temperature = 0.8
    else:
        try:
            request.temperature = float(request.temperature)
        except Exception:
            raise ValidationException(
                message="temperature must be a float",
                param="temperature",
                code="invalid_temperature",
            )
        if not (0 <= request.temperature <= 2):
            raise ValidationException(
                message="temperature must be between 0 and 2",
                param="temperature",
                code="invalid_temperature",
            )

    # top_p: default 0.95, range [0, 1].
    if request.top_p is None:
        request.top_p = 0.95
    else:
        try:
            request.top_p = float(request.top_p)
        except Exception:
            raise ValidationException(
                message="top_p must be a float",
                param="top_p",
                code="invalid_top_p",
            )
        if not (0 <= request.top_p <= 1):
            raise ValidationException(
                message="top_p must be between 0 and 1",
                param="top_p",
                code="invalid_top_p",
            )

    # Validate tool definitions.
    if request.tools is not None:
        if not isinstance(request.tools, list):
            raise ValidationException(
                message="tools must be an array",
                param="tools",
                code="invalid_tools",
            )
        for t_idx, tool in enumerate(request.tools):
            if not isinstance(tool, dict) or tool.get("type") != "function":
                raise ValidationException(
                    message="Each tool must have type='function'",
                    param=f"tools.{t_idx}.type",
                    code="invalid_tool_type",
                )
            func = tool.get("function")
            if not isinstance(func, dict) or not func.get("name"):
                raise ValidationException(
                    message="Each tool function must have a 'name'",
                    param=f"tools.{t_idx}.function.name",
                    code="missing_function_name",
                )

    # Validate tool_choice.
    if request.tool_choice is not None:
        if isinstance(request.tool_choice, str):
            if request.tool_choice not in ("auto", "required", "none"):
                raise ValidationException(
                    message="tool_choice must be 'auto', 'required', 'none', or a specific function object",
                    param="tool_choice",
                    code="invalid_tool_choice",
                )
        elif isinstance(request.tool_choice, dict):
            if request.tool_choice.get("type") != "function" or not request.tool_choice.get("function", {}).get("name"):
                raise ValidationException(
                    message="tool_choice object must have type='function' and function.name",
                    param="tool_choice",
                    code="invalid_tool_choice",
                )

    model_info = ModelService.get(request.model)
    # Image (generation / edit) validation.
    if model_info and (model_info.is_image or model_info.is_image_edit):
        prompt, _ = _extract_prompt_images(request.messages)
        if not prompt:
            raise ValidationException(
                message="Prompt cannot be empty",
                param="messages",
                code="empty_prompt",
            )
        image_conf = (
            _imagine_fast_server_image_config()
            if request.model == IMAGINE_FAST_MODEL_ID
            else (request.image_config or ImageConfig())
        )
        n = image_conf.n or 1
        if not (1 <= n <= 10):
            raise ValidationException(
                message="n must be between 1 and 10",
                param="image_config.n",
                code="invalid_n",
            )
        if request.stream and n not in (1, 2):
            raise ValidationException(
                message="Streaming is only supported when n=1 or n=2",
                param="stream",
                code="invalid_stream_n",
            )

        response_format = _resolve_image_format(image_conf.response_format)
        image_conf.n = n
        image_conf.response_format = response_format
        if not image_conf.size:
            image_conf.size = "1024x1024"
        # Use the module-level constant instead of a duplicated literal.
        if image_conf.size not in ALLOWED_IMAGE_SIZES:
            raise ValidationException(
                message=f"size must be one of {sorted(ALLOWED_IMAGE_SIZES)}",
                param="image_config.size",
                code="invalid_size",
            )
        request.image_config = image_conf

    # Image edit requires at least one input image.
    if model_info and model_info.is_image_edit:
        _, image_urls = _extract_prompt_images(request.messages)
        if not image_urls:
            raise ValidationException(
                message="image_url is required for image edits",
                param="messages",
                code="missing_image",
            )

    # Video validation.
    if model_info and model_info.is_video:
        config = request.video_config or VideoConfig()
        # Accept both pixel sizes and ratio strings; normalize to ratios.
        ratio_map = {
            "1280x720": "16:9",
            "720x1280": "9:16",
            "1792x1024": "3:2",
            "1024x1792": "2:3",
            "1024x1024": "1:1",
            "16:9": "16:9",
            "9:16": "9:16",
            "3:2": "3:2",
            "2:3": "2:3",
            "1:1": "1:1",
        }
        if config.aspect_ratio is None:
            config.aspect_ratio = "3:2"
        if config.aspect_ratio not in ratio_map:
            raise ValidationException(
                message=f"aspect_ratio must be one of {list(ratio_map.keys())}",
                param="video_config.aspect_ratio",
                code="invalid_aspect_ratio",
            )
        config.aspect_ratio = ratio_map[config.aspect_ratio]

        if config.video_length not in (6, 10, 15):
            raise ValidationException(
                message="video_length must be 6, 10, or 15 seconds",
                param="video_config.video_length",
                code="invalid_video_length",
            )
        if config.resolution_name not in ("480p", "720p"):
            raise ValidationException(
                message="resolution_name must be one of ['480p', '720p']",
                param="video_config.resolution_name",
                code="invalid_resolution",
            )
        if config.preset not in ("fun", "normal", "spicy", "custom"):
            raise ValidationException(
                message="preset must be one of ['fun', 'normal', 'spicy', 'custom']",
                param="video_config.preset",
                code="invalid_preset",
            )
        request.video_config = config
675
+
676
+
677
# Router for the OpenAI-compatible chat endpoints.
router = APIRouter(tags=["Chat"])
678
+
679
+
680
def _acquire_model_token(token_mgr, model: str):
    """Return the first available token across the model's candidate pools.

    Raises a 429 AppException when every pool is exhausted. Extracted to
    remove the verbatim duplication between the image-edit and image paths.
    """
    for pool_name in ModelService.pool_candidates_for_model(model):
        token = token_mgr.get_token(pool_name)
        if token:
            return token
    raise AppException(
        message="No available tokens. Please try again later.",
        error_type=ErrorType.RATE_LIMIT.value,
        code="rate_limit_exceeded",
        status_code=429,
    )


@router.post("/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """Chat Completions API - OpenAI compatible.

    Dispatches to image-edit, image-generation, video or plain chat handling
    based on the resolved model's capabilities.
    """
    from app.core.logger import logger

    # Parameter validation (also normalizes defaults on the request).
    validate_request(request)

    logger.debug(f"Chat request: model={request.model}, stream={request.stream}")

    # Detect the model type and branch accordingly.
    model_info = ModelService.get(request.model)
    if model_info and model_info.is_image_edit:
        prompt, image_urls = _extract_prompt_images(request.messages)
        if not image_urls:
            raise ValidationException(
                message="Image is required",
                param="image",
                code="missing_image",
            )

        is_stream = (
            request.stream if request.stream is not None else get_config("app.stream")
        )
        image_conf = request.image_config or ImageConfig()
        _validate_image_config(image_conf, stream=bool(is_stream))
        response_format = _resolve_image_format(image_conf.response_format)
        n = image_conf.n or 1

        token_mgr = await get_token_manager()
        await token_mgr.reload_if_stale()
        token = _acquire_model_token(token_mgr, request.model)

        result = await ImageEditService().edit(
            token_mgr=token_mgr,
            token=token,
            model_info=model_info,
            prompt=prompt,
            images=image_urls,
            n=n,
            response_format=response_format,
            stream=bool(is_stream),
            chat_format=True,
        )

        if result.stream:
            return StreamingResponse(
                _safe_sse_stream(result.data),
                media_type="text/event-stream",
                headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
            )

        content = result.data[0] if result.data else ""
        return JSONResponse(
            content=make_chat_response(request.model, content)
        )

    if model_info and model_info.is_image:
        prompt, _ = _extract_prompt_images(request.messages)

        is_stream = (
            request.stream if request.stream is not None else get_config("app.stream")
        )
        image_conf = (
            _imagine_fast_server_image_config()
            if request.model == IMAGINE_FAST_MODEL_ID
            else (request.image_config or ImageConfig())
        )
        _validate_image_config(image_conf, stream=bool(is_stream))
        response_format = _resolve_image_format(image_conf.response_format)
        n = image_conf.n or 1
        size = image_conf.size or "1024x1024"
        aspect_ratio_map = {
            "1280x720": "16:9",
            "720x1280": "9:16",
            "1792x1024": "3:2",
            "1024x1792": "2:3",
            "1024x1024": "1:1",
        }
        aspect_ratio = aspect_ratio_map.get(size, "2:3")

        token_mgr = await get_token_manager()
        await token_mgr.reload_if_stale()
        token = _acquire_model_token(token_mgr, request.model)

        result = await ImageGenerationService().generate(
            token_mgr=token_mgr,
            token=token,
            model_info=model_info,
            prompt=prompt,
            n=n,
            response_format=response_format,
            size=size,
            aspect_ratio=aspect_ratio,
            stream=bool(is_stream),
            chat_format=True,
        )

        if result.stream:
            return StreamingResponse(
                _safe_sse_stream(result.data),
                media_type="text/event-stream",
                headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
            )

        content = result.data[0] if result.data else ""
        usage = result.usage_override
        return JSONResponse(
            content=make_chat_response(request.model, content, usage=usage)
        )

    if model_info and model_info.is_video:
        # Video config defaults are handled by the pydantic model.
        v_conf = request.video_config or VideoConfig()

        try:
            result = await VideoService.completions(
                model=request.model,
                messages=[msg.model_dump() for msg in request.messages],
                stream=request.stream,
                reasoning_effort=request.reasoning_effort,
                aspect_ratio=v_conf.aspect_ratio,
                video_length=v_conf.video_length,
                resolution=v_conf.resolution_name,
                preset=v_conf.preset,
            )
        except Exception as e:
            # Clients that asked for (or defaulted to) streaming expect SSE errors.
            if request.stream is not False:
                return _streaming_error_response(e)
            raise
    else:
        try:
            result = await ChatService.completions(
                model=request.model,
                messages=[msg.model_dump() for msg in request.messages],
                stream=request.stream,
                reasoning_effort=request.reasoning_effort,
                temperature=request.temperature,
                top_p=request.top_p,
                tools=request.tools,
                tool_choice=request.tool_choice,
                parallel_tool_calls=request.parallel_tool_calls,
            )
        except Exception as e:
            if request.stream is not False:
                return _streaming_error_response(e)
            raise

    if isinstance(result, dict):
        return JSONResponse(content=result)
    return StreamingResponse(
        _safe_sse_stream(result),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
    )
860
+
861
+
862
# Explicit public API of this module.
__all__ = ["router"]
app/api/v1/files.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 文件服务 API 路由
3
+ """
4
+
5
+ import aiofiles.os
6
+ from pathlib import Path
7
+ from fastapi import APIRouter, HTTPException
8
+ from fastapi.responses import FileResponse
9
+
10
+ from app.core.logger import logger
11
+ from app.core.storage import DATA_DIR
12
+
13
router = APIRouter(tags=["Files"])

# Cache root directories: media files are looked up under DATA_DIR/tmp in
# flat per-type folders (presumably populated by the image/video services —
# confirm against the writers in app.services).
BASE_DIR = DATA_DIR / "tmp"
IMAGE_DIR = BASE_DIR / "image"
VIDEO_DIR = BASE_DIR / "video"
19
+
20
+
21
@router.get("/image/{filename:path}")
async def get_image(filename: str):
    """
    Serve a cached image file from the local image cache directory.

    Args:
        filename: Requested file name. Embedded "/" separators are folded
            into "-" so lookups always target the flat cache layout.

    Returns:
        FileResponse with a long-lived immutable cache header.

    Raises:
        HTTPException: 404 when the file is missing, not a regular file,
            or resolves outside the image cache directory.
    """
    # Flatten path separators so nested request paths map onto flat names.
    if "/" in filename:
        filename = filename.replace("/", "-")

    file_path = IMAGE_DIR / filename

    # Hardening: reject anything that escapes the cache directory
    # (e.g. a bare ".." component that survives the flattening above).
    try:
        file_path.resolve().relative_to(IMAGE_DIR.resolve())
    except ValueError:
        logger.warning(f"Image path rejected: {filename}")
        raise HTTPException(status_code=404, detail="Image not found")

    # isfile() implies existence, so a single check suffices.
    if await aiofiles.os.path.isfile(file_path):
        # Pick the media type from the extension; JPEG is the default.
        content_type = {
            ".png": "image/png",
            ".webp": "image/webp",
        }.get(file_path.suffix.lower(), "image/jpeg")

        # Long-lived immutable caching so browsers/CDNs can absorb load
        # in high-concurrency scenarios.
        return FileResponse(
            file_path,
            media_type=content_type,
            headers={"Cache-Control": "public, max-age=31536000, immutable"},
        )

    # Fix: the original logged a constant f-string and dropped the name.
    logger.warning(f"Image not found: {filename}")
    raise HTTPException(status_code=404, detail="Image not found")
48
+
49
+
50
@router.get("/video/{filename:path}")
async def get_video(filename: str):
    """
    Serve a cached MP4 video file from the local video cache directory.

    Args:
        filename: Requested file name; "/" separators are folded into "-".

    Returns:
        FileResponse (video/mp4) with a long-lived immutable cache header.

    Raises:
        HTTPException: 404 when the file is missing, not a regular file,
            or resolves outside the video cache directory.
    """
    # Flatten path separators so nested request paths map onto flat names.
    if "/" in filename:
        filename = filename.replace("/", "-")

    file_path = VIDEO_DIR / filename

    # Hardening: reject anything that escapes the cache directory.
    try:
        file_path.resolve().relative_to(VIDEO_DIR.resolve())
    except ValueError:
        logger.warning(f"Video path rejected: {filename}")
        raise HTTPException(status_code=404, detail="Video not found")

    # isfile() implies existence, so a single check suffices.
    if await aiofiles.os.path.isfile(file_path):
        return FileResponse(
            file_path,
            media_type="video/mp4",
            headers={"Cache-Control": "public, max-age=31536000, immutable"},
        )

    # Fix: the original logged a constant f-string and dropped the name.
    logger.warning(f"Video not found: {filename}")
    raise HTTPException(status_code=404, detail="Video not found")
app/api/v1/image.py ADDED
@@ -0,0 +1,452 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Image Generation API 路由
3
+ """
4
+
5
+ import base64
6
+ import time
7
+ from pathlib import Path
8
+ from typing import List, Optional, Union
9
+
10
+ from fastapi import APIRouter, File, Form, UploadFile
11
+ from fastapi.responses import StreamingResponse, JSONResponse
12
+ from pydantic import BaseModel, Field, ValidationError
13
+
14
+ from app.services.grok.services.image import ImageGenerationService
15
+ from app.services.grok.services.image_edit import ImageEditService
16
+ from app.services.grok.services.model import ModelService
17
+ from app.services.token import get_token_manager
18
+ from app.core.exceptions import ValidationException, AppException, ErrorType
19
+ from app.core.config import get_config
20
+
21
+
22
router = APIRouter(tags=["Images"])

# Pixel sizes accepted by the OpenAI-compatible `size` parameter.
ALLOWED_IMAGE_SIZES = {
    "1280x720",
    "720x1280",
    "1792x1024",
    "1024x1792",
    "1024x1024",
}

# Mapping from OpenAI-style pixel sizes to Grok Imagine aspect ratios.
SIZE_TO_ASPECT = {
    "1280x720": "16:9",
    "720x1280": "9:16",
    "1792x1024": "3:2",
    "1024x1792": "2:3",
    "1024x1024": "1:1",
}
# Aspect ratios accepted directly (used by resolve_aspect_ratio).
ALLOWED_ASPECT_RATIOS = {"1:1", "2:3", "3:2", "9:16", "16:9"}
40
+
41
+
42
class ImageGenerationRequest(BaseModel):
    """Image generation request payload (OpenAI `images/generations` compatible)."""

    # Text prompt describing the desired image (required).
    prompt: str = Field(..., description="图片描述")
    model: Optional[str] = Field("grok-imagine-1.0", description="模型名称")
    n: Optional[int] = Field(1, ge=1, le=10, description="生成数量 (1-10)")
    size: Optional[str] = Field(
        "1024x1024",
        description="图片尺寸: 1280x720, 720x1280, 1792x1024, 1024x1792, 1024x1024",
    )
    # quality/style are accepted for API compatibility but marked unsupported.
    quality: Optional[str] = Field("standard", description="图片质量 (暂不支持)")
    response_format: Optional[str] = Field(None, description="响应格式")
    style: Optional[str] = Field(None, description="风格 (暂不支持)")
    stream: Optional[bool] = Field(False, description="是否流式输出")
56
+
57
+
58
class ImageEditRequest(BaseModel):
    """Image edit request payload (OpenAI `images/edits` compatible)."""

    # Text prompt describing the desired edit (required).
    prompt: str = Field(..., description="编辑描述")
    model: Optional[str] = Field("grok-imagine-1.0-edit", description="模型名称")
    # NOTE: the HTTP endpoint receives images via multipart upload and does
    # not populate this field; it exists for schema completeness.
    image: Optional[Union[str, List[str]]] = Field(None, description="待编辑图片文件")
    n: Optional[int] = Field(1, ge=1, le=10, description="生成数量 (1-10)")
    size: Optional[str] = Field(
        "1024x1024",
        description="图片尺寸: 1280x720, 720x1280, 1792x1024, 1024x1792, 1024x1024",
    )
    # quality/style are accepted for API compatibility but marked unsupported.
    quality: Optional[str] = Field("standard", description="图片质量 (暂不支持)")
    response_format: Optional[str] = Field(None, description="响应格式")
    style: Optional[str] = Field(None, description="风格 (暂不支持)")
    stream: Optional[bool] = Field(False, description="是否流式输出")
73
+
74
+
75
def _validate_common_request(
    request: Union[ImageGenerationRequest, ImageEditRequest],
    *,
    allow_ws_stream: bool = False,
):
    """Shared validation for generation and edit requests.

    Checks prompt, count, streaming constraints, response format and size,
    raising ``ValidationException`` on the first violated rule.
    """
    prompt = request.prompt
    if not prompt or not prompt.strip():
        raise ValidationException(
            message="Prompt cannot be empty", param="prompt", code="empty_prompt"
        )

    if not 1 <= request.n <= 10:
        raise ValidationException(
            message="n must be between 1 and 10", param="n", code="invalid_n"
        )

    # Streaming limits how many images may be produced at once.
    if request.stream and request.n not in (1, 2):
        raise ValidationException(
            message="Streaming is only supported when n=1 or n=2",
            param="stream",
            code="invalid_stream_n",
        )

    valid_formats = {"b64_json", "base64", "url"}

    # WebSocket-capable endpoints apply the whitelist while streaming too.
    if allow_ws_stream and request.stream and request.response_format:
        if request.response_format not in valid_formats:
            raise ValidationException(
                message="Streaming only supports response_format=b64_json/base64/url",
                param="response_format",
                code="invalid_response_format",
            )

    if request.response_format and request.response_format not in valid_formats:
        raise ValidationException(
            message=f"response_format must be one of {sorted(valid_formats)}",
            param="response_format",
            code="invalid_response_format",
        )

    if request.size and request.size not in ALLOWED_IMAGE_SIZES:
        raise ValidationException(
            message=f"size must be one of {sorted(ALLOWED_IMAGE_SIZES)}",
            param="size",
            code="invalid_size",
        )
126
+
127
+
128
def validate_generation_request(request: ImageGenerationRequest):
    """Validate an image generation request.

    Enforces the fixed generation model, confirms it is registered as an
    image-capable model, then applies the shared parameter checks.
    """
    if request.model != "grok-imagine-1.0":
        raise ValidationException(
            message="The model `grok-imagine-1.0` is required for image generation.",
            param="model",
            code="model_not_supported",
        )
    # Double-check the model is registered and flagged as an image model.
    model_info = ModelService.get(request.model)
    if model_info is None or not model_info.is_image:
        # Surface which models would have been accepted.
        supported = [m.model_id for m in ModelService.MODELS if m.is_image]
        raise ValidationException(
            message=(
                f"The model `{request.model}` is not supported for image generation. "
                f"Supported: {supported}"
            ),
            param="model",
            code="model_not_supported",
        )
    _validate_common_request(request, allow_ws_stream=True)
150
+
151
+
152
def resolve_response_format(response_format: Optional[str]) -> str:
    """Normalize the response format, falling back to the configured default.

    Returns one of ``b64_json``, ``base64`` or ``url`` (lower-cased);
    raises ``ValidationException`` for anything else.
    """
    candidate = response_format or get_config("app.image_format")
    if isinstance(candidate, str):
        normalized = candidate.lower()
        if normalized in {"b64_json", "base64", "url"}:
            return normalized
    raise ValidationException(
        message="response_format must be one of b64_json, base64, url",
        param="response_format",
        code="invalid_response_format",
    )
164
+
165
+
166
def response_field_name(response_format: str) -> str:
    """Return the JSON field key used for image payloads of this format."""
    if response_format == "url":
        return "url"
    if response_format == "base64":
        return "base64"
    return "b64_json"
169
+
170
+
171
def resolve_aspect_ratio(size: str) -> str:
    """Map an OpenAI-style size (or a bare "W:H" ratio) to a Grok Imagine ratio.

    Unknown or malformed values fall back to the default "2:3".
    """
    text = (size or "").strip()
    if not text:
        return "2:3"
    mapped = SIZE_TO_ASPECT.get(text)
    if mapped:
        return mapped
    if ":" in text:
        # Accept a raw "W:H" pair, but only if it is a whitelisted ratio.
        try:
            w_raw, h_raw = text.split(":", 1)
            width = int(w_raw.strip())
            height = int(h_raw.strip())
        except (TypeError, ValueError):
            return "2:3"
        if width > 0 and height > 0:
            candidate = f"{width}:{height}"
            if candidate in ALLOWED_ASPECT_RATIOS:
                return candidate
    return "2:3"
190
+
191
+
192
def validate_edit_request(request: ImageEditRequest, images: List[UploadFile]):
    """Validate an image edit request and its uploaded files.

    Enforces the fixed edit model, shared parameter rules, and the
    1..16 uploaded-image count.
    """
    if request.model != "grok-imagine-1.0-edit":
        raise ValidationException(
            message=("The model `grok-imagine-1.0-edit` is required for image edits."),
            param="model",
            code="model_not_supported",
        )
    # Double-check the model is registered and flagged for image edits.
    model_info = ModelService.get(request.model)
    if not (model_info and model_info.is_image_edit):
        supported = [m.model_id for m in ModelService.MODELS if m.is_image_edit]
        raise ValidationException(
            message=(
                f"The model `{request.model}` is not supported for image edits. "
                f"Supported: {supported}"
            ),
            param="model",
            code="model_not_supported",
        )
    _validate_common_request(request, allow_ws_stream=False)
    if not images:
        raise ValidationException(
            message="Image is required",
            param="image",
            code="missing_image",
        )
    if len(images) > 16:
        raise ValidationException(
            message="Too many images. Maximum is 16.",
            param="image",
            code="invalid_image_count",
        )
224
+
225
+
226
async def _get_token(model: str):
    """Acquire a usable token for *model* from the first non-empty pool.

    Returns ``(token_manager, token)``; raises a 429 ``AppException``
    when every candidate pool is exhausted.
    """
    token_mgr = await get_token_manager()
    await token_mgr.reload_if_stale()

    # Walk the candidate pools in priority order, stopping at the first hit.
    token = next(
        (
            candidate
            for pool in ModelService.pool_candidates_for_model(model)
            if (candidate := token_mgr.get_token(pool))
        ),
        None,
    )

    if not token:
        raise AppException(
            message="No available tokens. Please try again later.",
            error_type=ErrorType.RATE_LIMIT.value,
            code="rate_limit_exceeded",
            status_code=429,
        )

    return token_mgr, token
246
+
247
+
248
@router.post("/images/generations")
async def create_image(request: ImageGenerationRequest):
    """
    Image Generation API

    Streaming response format:
    - event: image_generation.partial_image
    - event: image_generation.completed

    Non-streaming response format:
    - {"created": ..., "data": [{"b64_json": "..."}], "usage": {...}}
    """
    # stream defaults to False when the client omits it.
    if request.stream is None:
        request.stream = False

    if request.response_format is None:
        request.response_format = resolve_response_format(None)

    # Parameter validation (model, prompt, n, stream/format/size rules).
    validate_generation_request(request)

    # Accept "base64" as an alias of "b64_json".
    if request.response_format == "base64":
        request.response_format = "b64_json"

    response_format = resolve_response_format(request.response_format)
    response_field = response_field_name(response_format)

    # Acquire a token and resolve model info; maps the OpenAI size to a ratio.
    token_mgr, token = await _get_token(request.model)
    model_info = ModelService.get(request.model)
    aspect_ratio = resolve_aspect_ratio(request.size)

    result = await ImageGenerationService().generate(
        token_mgr=token_mgr,
        token=token,
        model_info=model_info,
        prompt=request.prompt,
        n=request.n,
        response_format=response_format,
        size=request.size,
        aspect_ratio=aspect_ratio,
        stream=bool(request.stream),
    )

    # Streaming mode: pass the service's SSE iterator straight through.
    if result.stream:
        return StreamingResponse(
            result.data,
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
        )

    # Non-streaming: wrap each image under the format-specific field name.
    data = [{response_field: img} for img in result.data]
    usage = result.usage_override or {
        "total_tokens": 0,
        "input_tokens": 0,
        "output_tokens": 0,
        "input_tokens_details": {"text_tokens": 0, "image_tokens": 0},
    }

    return JSONResponse(
        content={
            "created": int(time.time()),
            "data": data,
            "usage": usage,
        }
    )
316
+
317
+
318
@router.post("/images/edits")
async def edit_image(
    prompt: str = Form(...),
    image: List[UploadFile] = File(...),
    model: Optional[str] = Form("grok-imagine-1.0-edit"),
    n: int = Form(1),
    size: str = Form("1024x1024"),
    quality: str = Form("standard"),
    response_format: Optional[str] = Form(None),
    style: Optional[str] = Form(None),
    stream: Optional[bool] = Form(False),
):
    """
    Image Edits API

    Mirrors the official API format; only multipart/form-data uploads
    are supported.
    """
    if response_format is None:
        response_format = resolve_response_format(None)

    # Re-validate the form fields through the Pydantic model so errors use
    # the same ValidationException shape as the JSON endpoints.
    try:
        edit_request = ImageEditRequest(
            prompt=prompt,
            model=model,
            n=n,
            size=size,
            quality=quality,
            response_format=response_format,
            style=style,
            stream=stream,
        )
    except ValidationError as exc:
        errors = exc.errors()
        if errors:
            # Convert the first Pydantic error into the API error shape,
            # dropping numeric path segments (list indices) from `param`.
            first = errors[0]
            loc = first.get("loc", [])
            msg = first.get("msg", "Invalid request")
            code = first.get("type", "invalid_value")
            param_parts = [
                str(x) for x in loc if not (isinstance(x, int) or str(x).isdigit())
            ]
            param = ".".join(param_parts) if param_parts else None
            raise ValidationException(message=msg, param=param, code=code)
        raise ValidationException(message="Invalid request", code="invalid_value")

    if edit_request.stream is None:
        edit_request.stream = False

    # Accept "base64" as an alias of "b64_json".
    response_format = resolve_response_format(edit_request.response_format)
    if response_format == "base64":
        response_format = "b64_json"
    edit_request.response_format = response_format
    response_field = response_field_name(response_format)

    # Parameter validation (model, shared rules, upload count 1..16).
    validate_edit_request(edit_request, image)

    max_image_bytes = 50 * 1024 * 1024
    allowed_types = {"image/png", "image/jpeg", "image/webp", "image/jpg"}

    # Read each upload, validate size/type, and encode as a data URI.
    images: List[str] = []
    for item in image:
        content = await item.read()
        await item.close()
        if not content:
            raise ValidationException(
                message="File content is empty",
                param="image",
                code="empty_file",
            )
        if len(content) > max_image_bytes:
            raise ValidationException(
                message="Image file too large. Maximum is 50MB.",
                param="image",
                code="file_too_large",
            )
        mime = (item.content_type or "").lower()
        if mime == "image/jpg":
            mime = "image/jpeg"
        ext = Path(item.filename or "").suffix.lower()
        if mime not in allowed_types:
            # Fall back to the file extension when the declared MIME type
            # is missing or unrecognized.
            if ext in (".jpg", ".jpeg"):
                mime = "image/jpeg"
            elif ext == ".png":
                mime = "image/png"
            elif ext == ".webp":
                mime = "image/webp"
            else:
                raise ValidationException(
                    message="Unsupported image type. Supported: png, jpg, webp.",
                    param="image",
                    code="invalid_image_type",
                )
        b64 = base64.b64encode(content).decode()
        images.append(f"data:{mime};base64,{b64}")

    # Acquire a token and resolve model info.
    token_mgr, token = await _get_token(edit_request.model)
    model_info = ModelService.get(edit_request.model)

    result = await ImageEditService().edit(
        token_mgr=token_mgr,
        token=token,
        model_info=model_info,
        prompt=edit_request.prompt,
        images=images,
        n=edit_request.n,
        response_format=response_format,
        stream=bool(edit_request.stream),
    )

    # Streaming mode: pass the service's SSE iterator straight through.
    if result.stream:
        return StreamingResponse(
            result.data,
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
        )

    data = [{response_field: img} for img in result.data]

    return JSONResponse(
        content={
            "created": int(time.time()),
            "data": data,
            "usage": {
                "total_tokens": 0,
                "input_tokens": 0,
                "output_tokens": 0,
                "input_tokens_details": {"text_tokens": 0, "image_tokens": 0},
            },
        }
    )
450
+
451
+
452
+ __all__ = ["router"]
app/api/v1/models.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Models API 路由
3
+ """
4
+
5
+ from fastapi import APIRouter
6
+
7
+ from app.services.grok.services.model import ModelService
8
+
9
+
10
+ router = APIRouter(tags=["Models"])
11
+
12
+
13
@router.get("/models")
async def list_models():
    """OpenAI-compatible model listing endpoint."""
    entries = []
    for model in ModelService.list():
        entries.append(
            {
                "id": model.model_id,
                "object": "model",
                "created": 0,
                "owned_by": "grok2api@chenyme",
            }
        )
    return {"object": "list", "data": entries}
26
+
27
+
28
+ __all__ = ["router"]
app/api/v1/public_api/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Public API router (public_key protected)."""
2
+
3
+ from fastapi import APIRouter, Depends
4
+
5
+ from app.api.v1.chat import router as chat_router
6
+ from app.api.v1.public_api.imagine import router as imagine_router
7
+ from app.api.v1.public_api.video import router as video_router
8
+ from app.api.v1.public_api.voice import router as voice_router
9
+ from app.core.auth import verify_public_key
10
+
11
router = APIRouter()

# Chat completions reuse the OpenAI-compatible router and are gated here by
# the public-key dependency. The imagine router performs its own
# per-endpoint auth (query-param keys / session capabilities); video and
# voice presumably follow the same pattern — confirm in their modules.
router.include_router(chat_router, dependencies=[Depends(verify_public_key)])
router.include_router(imagine_router)
router.include_router(video_router)
router.include_router(voice_router)

__all__ = ["router"]
app/api/v1/public_api/imagine.py ADDED
@@ -0,0 +1,505 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import time
3
+ import uuid
4
+ from typing import Optional, List, Dict, Any
5
+
6
+ import orjson
7
+ from fastapi import APIRouter, Depends, HTTPException, Query, Request, WebSocket, WebSocketDisconnect
8
+ from fastapi.responses import StreamingResponse
9
+ from pydantic import BaseModel
10
+
11
+ from app.core.auth import verify_public_key, get_public_api_key, is_public_enabled
12
+ from app.core.config import get_config
13
+ from app.core.logger import logger
14
+ from app.api.v1.image import resolve_aspect_ratio
15
+ from app.services.grok.services.image import ImageGenerationService
16
+ from app.services.grok.services.model import ModelService
17
+ from app.services.token.manager import get_token_manager
18
+
19
router = APIRouter()

# In-memory imagine sessions: task_id -> {prompt, aspect_ratio, nsfw,
# created_at}. Sessions act as short-lived capabilities for the WS/SSE
# endpoints and are garbage-collected after IMAGINE_SESSION_TTL seconds.
IMAGINE_SESSION_TTL = 600
_IMAGINE_SESSIONS: dict[str, dict] = {}
# Guards all reads/writes of _IMAGINE_SESSIONS across coroutines.
_IMAGINE_SESSIONS_LOCK = asyncio.Lock()
24
+
25
+
26
async def _clean_sessions(now: float) -> None:
    """Drop every imagine session older than ``IMAGINE_SESSION_TTL`` seconds.

    Callers hold ``_IMAGINE_SESSIONS_LOCK`` before invoking this.
    """
    stale_keys = []
    for key, info in _IMAGINE_SESSIONS.items():
        created = float(info.get("created_at") or 0)
        if now - created > IMAGINE_SESSION_TTL:
            stale_keys.append(key)
    for key in stale_keys:
        _IMAGINE_SESSIONS.pop(key, None)
34
+
35
+
36
+ def _parse_sse_chunk(chunk: str) -> Optional[Dict[str, Any]]:
37
+ if not chunk:
38
+ return None
39
+ event = None
40
+ data_lines: List[str] = []
41
+ for raw in str(chunk).splitlines():
42
+ line = raw.strip()
43
+ if not line:
44
+ continue
45
+ if line.startswith("event:"):
46
+ event = line[6:].strip()
47
+ continue
48
+ if line.startswith("data:"):
49
+ data_lines.append(line[5:].strip())
50
+ if not data_lines:
51
+ return None
52
+ data_str = "\n".join(data_lines)
53
+ if data_str == "[DONE]":
54
+ return None
55
+ try:
56
+ payload = orjson.loads(data_str)
57
+ except orjson.JSONDecodeError:
58
+ return None
59
+ if event and isinstance(payload, dict) and "type" not in payload:
60
+ payload["type"] = event
61
+ return payload
62
+
63
+
64
async def _new_session(prompt: str, aspect_ratio: str, nsfw: Optional[bool]) -> str:
    """Register a new imagine session and return its opaque task id.

    Expired sessions are pruned opportunistically while the lock is held.
    """
    session_key = uuid.uuid4().hex
    timestamp = time.time()
    async with _IMAGINE_SESSIONS_LOCK:
        await _clean_sessions(timestamp)
        _IMAGINE_SESSIONS[session_key] = {
            "prompt": prompt,
            "aspect_ratio": aspect_ratio,
            "nsfw": nsfw,
            "created_at": timestamp,
        }
    return session_key
76
+
77
+
78
async def _get_session(task_id: str) -> Optional[dict]:
    """Look up a live session by id; returns a copy or ``None`` when absent/expired."""
    if not task_id:
        return None
    now = time.time()
    async with _IMAGINE_SESSIONS_LOCK:
        await _clean_sessions(now)
        entry = _IMAGINE_SESSIONS.get(task_id)
        if not entry:
            return None
        # Defensive re-check: drop the entry if it aged past the TTL.
        if now - float(entry.get("created_at") or 0) > IMAGINE_SESSION_TTL:
            _IMAGINE_SESSIONS.pop(task_id, None)
            return None
        # Hand back a copy so callers cannot mutate shared state.
        return dict(entry)
92
+
93
+
94
async def _drop_session(task_id: str) -> None:
    """Remove one imagine session; no-op for falsy or unknown ids."""
    if task_id:
        async with _IMAGINE_SESSIONS_LOCK:
            _IMAGINE_SESSIONS.pop(task_id, None)
99
+
100
+
101
async def _drop_sessions(task_ids: List[str]) -> int:
    """Remove every listed session, returning how many were actually present."""
    if not task_ids:
        return 0
    async with _IMAGINE_SESSIONS_LOCK:
        removed = 0
        for candidate in task_ids:
            if candidate and candidate in _IMAGINE_SESSIONS:
                _IMAGINE_SESSIONS.pop(candidate, None)
                removed += 1
        return removed
111
+
112
+
113
@router.websocket("/imagine/ws")
async def public_imagine_ws(websocket: WebSocket):
    """Imagine image waterfall over WebSocket.

    Auth: either a valid ``task_id`` session (created via /imagine/start)
    or a ``public_key`` query parameter matching the configured key; when
    no key is configured, access follows the global public-access switch.

    Client -> server JSON messages:
      {"type": "start", "prompt": ..., "aspect_ratio": ..., "nsfw": ...}
      {"type": "stop"}
    The server pushes ``status`` / ``image`` / ``error`` payloads until
    stopped or disconnected.
    """
    session_id = None
    task_id = websocket.query_params.get("task_id")
    if task_id:
        info = await _get_session(task_id)
        if info:
            session_id = task_id

    # A valid session acts as a capability; otherwise fall back to key auth.
    ok = True
    if session_id is None:
        public_key = get_public_api_key()
        public_enabled = is_public_enabled()
        if not public_key:
            ok = public_enabled
        else:
            key = websocket.query_params.get("public_key")
            ok = key == public_key

    if not ok:
        # 1008 = policy violation (auth failure).
        await websocket.close(code=1008)
        return

    await websocket.accept()
    stop_event = asyncio.Event()
    run_task: Optional[asyncio.Task] = None

    async def _send(payload: dict) -> bool:
        # Best-effort send; a False return means the socket is gone.
        try:
            await websocket.send_text(orjson.dumps(payload).decode())
            return True
        except Exception:
            return False

    async def _stop_run():
        # Cancel the current generation loop (if any) and reset the flag.
        nonlocal run_task
        stop_event.set()
        if run_task and not run_task.done():
            run_task.cancel()
            try:
                await run_task
            except Exception:
                pass
        run_task = None
        stop_event.clear()

    async def _run(prompt: str, aspect_ratio: str, nsfw: Optional[bool]):
        # Generation loop: repeatedly request batches of 6 images and push
        # them to the client until stop_event is set or the task is cancelled.
        model_id = "grok-imagine-1.0"
        model_info = ModelService.get(model_id)
        if not model_info or not model_info.is_image:
            await _send(
                {
                    "type": "error",
                    "message": "Image model is not available.",
                    "code": "model_not_supported",
                }
            )
            return

        token_mgr = await get_token_manager()
        run_id = uuid.uuid4().hex

        await _send(
            {
                "type": "status",
                "status": "running",
                "prompt": prompt,
                "aspect_ratio": aspect_ratio,
                "run_id": run_id,
            }
        )

        while not stop_event.is_set():
            try:
                await token_mgr.reload_if_stale()
                # Pick the first available token across candidate pools.
                token = None
                for pool_name in ModelService.pool_candidates_for_model(
                    model_info.model_id
                ):
                    token = token_mgr.get_token(pool_name)
                    if token:
                        break

                if not token:
                    # No capacity right now; report and retry after a delay.
                    await _send(
                        {
                            "type": "error",
                            "message": "No available tokens. Please try again later.",
                            "code": "rate_limit_exceeded",
                        }
                    )
                    await asyncio.sleep(2)
                    continue

                result = await ImageGenerationService().generate(
                    token_mgr=token_mgr,
                    token=token,
                    model_info=model_info,
                    prompt=prompt,
                    n=6,
                    response_format="b64_json",
                    size="1024x1024",
                    aspect_ratio=aspect_ratio,
                    stream=True,
                    enable_nsfw=nsfw,
                )
                if result.stream:
                    # Forward each parsed SSE payload, tagging the run id.
                    async for chunk in result.data:
                        payload = _parse_sse_chunk(chunk)
                        if not payload:
                            continue
                        if isinstance(payload, dict):
                            payload.setdefault("run_id", run_id)
                        await _send(payload)
                else:
                    # Non-stream fallback: push completed images one by one.
                    images = [img for img in result.data if img and img != "error"]
                    if images:
                        for img_b64 in images:
                            await _send(
                                {
                                    "type": "image",
                                    "b64_json": img_b64,
                                    "created_at": int(time.time() * 1000),
                                    "aspect_ratio": aspect_ratio,
                                    "run_id": run_id,
                                }
                            )
                    else:
                        await _send(
                            {
                                "type": "error",
                                "message": "Image generation returned empty data.",
                                "code": "empty_image",
                            }
                        )

            except asyncio.CancelledError:
                break
            except Exception as e:
                # Report the failure and back off briefly before retrying.
                logger.warning(f"Imagine stream error: {e}")
                await _send(
                    {
                        "type": "error",
                        "message": str(e),
                        "code": "internal_error",
                    }
                )
                await asyncio.sleep(1.5)

        await _send({"type": "status", "status": "stopped", "run_id": run_id})

    try:
        # Control loop: react to client start/stop commands.
        while True:
            try:
                raw = await websocket.receive_text()
            except (RuntimeError, WebSocketDisconnect):
                break

            try:
                payload = orjson.loads(raw)
            except Exception:
                await _send(
                    {
                        "type": "error",
                        "message": "Invalid message format.",
                        "code": "invalid_payload",
                    }
                )
                continue

            action = payload.get("type")
            if action == "start":
                prompt = str(payload.get("prompt") or "").strip()
                if not prompt:
                    await _send(
                        {
                            "type": "error",
                            "message": "Prompt cannot be empty.",
                            "code": "invalid_prompt",
                        }
                    )
                    continue
                aspect_ratio = resolve_aspect_ratio(
                    str(payload.get("aspect_ratio") or "2:3").strip() or "2:3"
                )
                nsfw = payload.get("nsfw")
                if nsfw is not None:
                    nsfw = bool(nsfw)
                # Restart semantics: a new start cancels any previous run.
                await _stop_run()
                run_task = asyncio.create_task(_run(prompt, aspect_ratio, nsfw))
            elif action == "stop":
                await _stop_run()
            else:
                await _send(
                    {
                        "type": "error",
                        "message": "Unknown action.",
                        "code": "invalid_action",
                    }
                )

    except WebSocketDisconnect:
        logger.debug("WebSocket disconnected by client")
    except Exception as e:
        logger.warning(f"WebSocket error: {e}")
    finally:
        # Always cancel any in-flight generation before tearing down.
        await _stop_run()

        try:
            from starlette.websockets import WebSocketState
            if websocket.client_state == WebSocketState.CONNECTED:
                await websocket.close(code=1000, reason="Server closing connection")
        except Exception as e:
            logger.debug(f"WebSocket close ignored: {e}")
        # Session capabilities are single-use: drop it once the socket ends.
        if session_id:
            await _drop_session(session_id)
329
+
330
+
331
@router.get("/imagine/sse")
async def public_imagine_sse(
    request: Request,
    task_id: str = Query(""),
    prompt: str = Query(""),
    aspect_ratio: str = Query("2:3"),
):
    """Imagine image waterfall (SSE fallback).

    Either a valid ``task_id`` session authorizes the stream (and supplies
    prompt/ratio/nsfw), or direct query parameters are used under
    public-key auth. Generates batches of 6 images in a loop until the
    client disconnects or the session expires.
    """
    session = None
    if task_id:
        session = await _get_session(task_id)
        if not session:
            raise HTTPException(status_code=404, detail="Task not found")
    else:
        # No session capability: enforce public-key / public-access rules.
        public_key = get_public_api_key()
        public_enabled = is_public_enabled()
        if not public_key:
            if not public_enabled:
                raise HTTPException(status_code=401, detail="Public access is disabled")
        else:
            key = request.query_params.get("public_key")
            if key != public_key:
                raise HTTPException(status_code=401, detail="Invalid authentication token")

    if session:
        # Session parameters win over query parameters.
        prompt = str(session.get("prompt") or "").strip()
        ratio = str(session.get("aspect_ratio") or "2:3").strip() or "2:3"
        nsfw = session.get("nsfw")
    else:
        prompt = (prompt or "").strip()
        if not prompt:
            raise HTTPException(status_code=400, detail="Prompt cannot be empty")
        ratio = str(aspect_ratio or "2:3").strip() or "2:3"
        ratio = resolve_aspect_ratio(ratio)
        # Query-string booleans: accept common truthy spellings.
        nsfw = request.query_params.get("nsfw")
        if nsfw is not None:
            nsfw = str(nsfw).lower() in ("1", "true", "yes", "on")

    async def event_stream():
        try:
            model_id = "grok-imagine-1.0"
            model_info = ModelService.get(model_id)
            if not model_info or not model_info.is_image:
                yield (
                    f"data: {orjson.dumps({'type': 'error', 'message': 'Image model is not available.', 'code': 'model_not_supported'}).decode()}\n\n"
                )
                return

            token_mgr = await get_token_manager()
            sequence = 0
            run_id = uuid.uuid4().hex

            yield (
                f"data: {orjson.dumps({'type': 'status', 'status': 'running', 'prompt': prompt, 'aspect_ratio': ratio, 'run_id': run_id}).decode()}\n\n"
            )

            while True:
                # Stop when the client goes away or the session expires.
                if await request.is_disconnected():
                    break
                if task_id:
                    session_alive = await _get_session(task_id)
                    if not session_alive:
                        break

                try:
                    await token_mgr.reload_if_stale()
                    # Pick the first available token across candidate pools.
                    token = None
                    for pool_name in ModelService.pool_candidates_for_model(
                        model_info.model_id
                    ):
                        token = token_mgr.get_token(pool_name)
                        if token:
                            break

                    if not token:
                        # No capacity right now; report and retry after a delay.
                        yield (
                            f"data: {orjson.dumps({'type': 'error', 'message': 'No available tokens. Please try again later.', 'code': 'rate_limit_exceeded'}).decode()}\n\n"
                        )
                        await asyncio.sleep(2)
                        continue

                    result = await ImageGenerationService().generate(
                        token_mgr=token_mgr,
                        token=token,
                        model_info=model_info,
                        prompt=prompt,
                        n=6,
                        response_format="b64_json",
                        size="1024x1024",
                        aspect_ratio=ratio,
                        stream=True,
                        enable_nsfw=nsfw,
                    )
                    if result.stream:
                        # Re-emit each parsed payload as a data-only SSE event.
                        async for chunk in result.data:
                            payload = _parse_sse_chunk(chunk)
                            if not payload:
                                continue
                            if isinstance(payload, dict):
                                payload.setdefault("run_id", run_id)
                            yield f"data: {orjson.dumps(payload).decode()}\n\n"
                    else:
                        # Non-stream fallback: emit completed images in order.
                        images = [img for img in result.data if img and img != "error"]
                        if images:
                            for img_b64 in images:
                                sequence += 1
                                payload = {
                                    "type": "image",
                                    "b64_json": img_b64,
                                    "sequence": sequence,
                                    "created_at": int(time.time() * 1000),
                                    "aspect_ratio": ratio,
                                    "run_id": run_id,
                                }
                                yield f"data: {orjson.dumps(payload).decode()}\n\n"
                        else:
                            yield (
                                f"data: {orjson.dumps({'type': 'error', 'message': 'Image generation returned empty data.', 'code': 'empty_image'}).decode()}\n\n"
                            )
                except asyncio.CancelledError:
                    break
                except Exception as e:
                    # Report the failure and back off briefly before retrying.
                    logger.warning(f"Imagine SSE error: {e}")
                    yield (
                        f"data: {orjson.dumps({'type': 'error', 'message': str(e), 'code': 'internal_error'}).decode()}\n\n"
                    )
                    await asyncio.sleep(1.5)

            yield (
                f"data: {orjson.dumps({'type': 'status', 'status': 'stopped', 'run_id': run_id}).decode()}\n\n"
            )
        finally:
            # Session capabilities are single-use: drop on stream teardown.
            if task_id:
                await _drop_session(task_id)

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
    )
471
+
472
+
473
@router.get("/imagine/config")
async def public_imagine_config():
    """Expose imagine-related settings (size thresholds and the NSFW flag)."""
    final_min = int(get_config("image.final_min_bytes") or 0)
    medium_min = int(get_config("image.medium_min_bytes") or 0)
    nsfw_enabled = bool(get_config("image.nsfw"))
    return {
        "final_min_bytes": final_min,
        "medium_min_bytes": medium_min,
        "nsfw": nsfw_enabled,
    }
480
+
481
+
482
class ImagineStartRequest(BaseModel):
    """Payload for starting an imagine waterfall session."""

    # Text prompt for image generation (must be non-empty after stripping).
    prompt: str
    # Desired aspect ratio; normalized via resolve_aspect_ratio, default "2:3".
    aspect_ratio: Optional[str] = "2:3"
    # Optional NSFW override; None presumably leaves the service default in
    # effect — confirm against ImageGenerationService.generate.
    nsfw: Optional[bool] = None
486
+
487
+
488
@router.post("/imagine/start", dependencies=[Depends(verify_public_key)])
async def public_imagine_start(data: ImagineStartRequest):
    """Create an imagine session and return the task id for WS/SSE streaming."""
    prompt = (data.prompt or "").strip()
    if not prompt:
        raise HTTPException(status_code=400, detail="Prompt cannot be empty")
    requested_ratio = str(data.aspect_ratio or "2:3").strip() or "2:3"
    ratio = resolve_aspect_ratio(requested_ratio)
    task_id = await _new_session(prompt, ratio, data.nsfw)
    return {"task_id": task_id, "aspect_ratio": ratio}
496
+
497
+
498
class ImagineStopRequest(BaseModel):
    """Payload for stopping one or more imagine sessions."""

    # Session ids previously returned by /imagine/start.
    task_ids: List[str]
500
+
501
+
502
@router.post("/imagine/stop", dependencies=[Depends(verify_public_key)])
async def public_imagine_stop(data: ImagineStopRequest):
    """Tear down the listed imagine sessions, reporting how many were removed."""
    target_ids = data.task_ids or []
    removed = await _drop_sessions(target_ids)
    return {"status": "success", "removed": removed}
app/api/v1/public_api/video.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import time
3
+ import uuid
4
+ from typing import Optional, List, Dict, Any
5
+
6
+ import orjson
7
+ from fastapi import APIRouter, Depends, HTTPException, Query, Request
8
+ from fastapi.responses import StreamingResponse
9
+ from pydantic import BaseModel
10
+
11
+ from app.core.auth import verify_public_key
12
+ from app.core.logger import logger
13
+ from app.services.grok.services.video import VideoService
14
+ from app.services.grok.services.model import ModelService
15
+
16
+ router = APIRouter()
17
+
18
+ VIDEO_SESSION_TTL = 600
19
+ _VIDEO_SESSIONS: dict[str, dict] = {}
20
+ _VIDEO_SESSIONS_LOCK = asyncio.Lock()
21
+
22
+ _VIDEO_RATIO_MAP = {
23
+ "1280x720": "16:9",
24
+ "720x1280": "9:16",
25
+ "1792x1024": "3:2",
26
+ "1024x1792": "2:3",
27
+ "1024x1024": "1:1",
28
+ "16:9": "16:9",
29
+ "9:16": "9:16",
30
+ "3:2": "3:2",
31
+ "2:3": "2:3",
32
+ "1:1": "1:1",
33
+ }
34
+
35
+
36
async def _clean_sessions(now: float) -> None:
    """Evict video sessions older than VIDEO_SESSION_TTL.

    Caller is expected to hold _VIDEO_SESSIONS_LOCK.
    """
    stale = [
        key
        for key, info in _VIDEO_SESSIONS.items()
        if now - float(info.get("created_at") or 0) > VIDEO_SESSION_TTL
    ]
    for key in stale:
        _VIDEO_SESSIONS.pop(key, None)
44
+
45
+
46
async def _new_session(
    prompt: str,
    aspect_ratio: str,
    video_length: int,
    resolution_name: str,
    preset: str,
    image_url: Optional[str],
    reasoning_effort: Optional[str],
) -> str:
    """Register a one-shot video-generation session and return its task id."""
    session_id = uuid.uuid4().hex
    created = time.time()
    record = {
        "prompt": prompt,
        "aspect_ratio": aspect_ratio,
        "video_length": video_length,
        "resolution_name": resolution_name,
        "preset": preset,
        "image_url": image_url,
        "reasoning_effort": reasoning_effort,
        "created_at": created,
    }
    async with _VIDEO_SESSIONS_LOCK:
        # Opportunistically purge expired sessions while we hold the lock.
        await _clean_sessions(created)
        _VIDEO_SESSIONS[session_id] = record
    return session_id
70
+
71
+
72
async def _get_session(task_id: str) -> Optional[dict]:
    """Return a copy of the stored session for task_id, or None if missing/expired."""
    if not task_id:
        return None
    now = time.time()
    async with _VIDEO_SESSIONS_LOCK:
        await _clean_sessions(now)
        info = _VIDEO_SESSIONS.get(task_id)
        if not info:
            return None
        # Re-check the TTL for this specific record before handing it out.
        if now - float(info.get("created_at") or 0) > VIDEO_SESSION_TTL:
            _VIDEO_SESSIONS.pop(task_id, None)
            return None
        return dict(info)
86
+
87
+
88
async def _drop_session(task_id: str) -> None:
    """Remove a single video session, silently ignoring unknown/empty ids."""
    if not task_id:
        return
    async with _VIDEO_SESSIONS_LOCK:
        _VIDEO_SESSIONS.pop(task_id, None)
93
+
94
+
95
async def _drop_sessions(task_ids: List[str]) -> int:
    """Remove every listed video session and return how many were deleted."""
    if not task_ids:
        return 0
    count = 0
    async with _VIDEO_SESSIONS_LOCK:
        for tid in task_ids:
            if not tid or tid not in _VIDEO_SESSIONS:
                continue
            del _VIDEO_SESSIONS[tid]
            count += 1
    return count
105
+
106
+
107
def _normalize_ratio(value: Optional[str]) -> str:
    """Map a resolution or ratio string to a canonical ratio ('' when unknown)."""
    key = "" if value is None else value.strip()
    return _VIDEO_RATIO_MAP.get(key, "")
110
+
111
+
112
def _validate_image_url(image_url: str) -> None:
    """Raise a 400 unless image_url is empty, an http(s) URL, or a data URI."""
    candidate = (image_url or "").strip()
    if not candidate:
        return
    if candidate.startswith(("data:", "http://", "https://")):
        return
    raise HTTPException(
        status_code=400,
        detail="image_url must be a URL or data URI (data:<mime>;base64,...)",
    )
124
+
125
+
126
class VideoStartRequest(BaseModel):
    """Payload for starting a public video-generation session."""

    prompt: str
    aspect_ratio: Optional[str] = "3:2"
    # Seconds; server only accepts 6, 10, or 15.
    video_length: Optional[int] = 6
    resolution_name: Optional[str] = "480p"
    preset: Optional[str] = "normal"
    # Optional reference image (URL or data URI) for image-to-video.
    image_url: Optional[str] = None
    reasoning_effort: Optional[str] = None
134
+
135
+
136
@router.post("/video/start", dependencies=[Depends(verify_public_key)])
async def public_video_start(data: VideoStartRequest):
    """Validate a video request, register a session, and return its task id."""

    def _bad_request(detail: str) -> None:
        # All validation failures surface as HTTP 400.
        raise HTTPException(status_code=400, detail=detail)

    prompt = (data.prompt or "").strip()
    if not prompt:
        _bad_request("Prompt cannot be empty")

    aspect_ratio = _normalize_ratio(data.aspect_ratio)
    if not aspect_ratio:
        _bad_request("aspect_ratio must be one of ['16:9','9:16','3:2','2:3','1:1']")

    video_length = int(data.video_length or 6)
    if video_length not in (6, 10, 15):
        _bad_request("video_length must be 6, 10, or 15 seconds")

    resolution_name = str(data.resolution_name or "480p")
    if resolution_name not in ("480p", "720p"):
        _bad_request("resolution_name must be one of ['480p','720p']")

    preset = str(data.preset or "normal")
    if preset not in ("fun", "normal", "spicy", "custom"):
        _bad_request("preset must be one of ['fun','normal','spicy','custom']")

    image_url = (data.image_url or "").strip() or None
    if image_url:
        _validate_image_url(image_url)

    reasoning_effort = (data.reasoning_effort or "").strip() or None
    if reasoning_effort:
        allowed = {"none", "minimal", "low", "medium", "high", "xhigh"}
        if reasoning_effort not in allowed:
            _bad_request(f"reasoning_effort must be one of {sorted(allowed)}")

    task_id = await _new_session(
        prompt,
        aspect_ratio,
        video_length,
        resolution_name,
        preset,
        image_url,
        reasoning_effort,
    )
    return {"task_id": task_id, "aspect_ratio": aspect_ratio}
192
+
193
+
194
@router.get("/video/sse")
async def public_video_sse(request: Request, task_id: str = Query("")):
    """Stream the video-generation result for a previously started session.

    Looks up the one-shot session created by /video/start, forwards the
    prompt (and optional reference image) to VideoService as an SSE stream,
    and always drops the session afterwards (sessions are single-use).
    """
    session = await _get_session(task_id)
    if not session:
        raise HTTPException(status_code=404, detail="Task not found")

    # Pull generation parameters from the stored session, with the same
    # defaults /video/start would have applied.
    prompt = str(session.get("prompt") or "").strip()
    aspect_ratio = str(session.get("aspect_ratio") or "3:2")
    video_length = int(session.get("video_length") or 6)
    resolution_name = str(session.get("resolution_name") or "480p")
    preset = str(session.get("preset") or "normal")
    image_url = session.get("image_url")
    reasoning_effort = session.get("reasoning_effort")

    async def event_stream():
        try:
            # The public endpoint is pinned to one video model; bail out with
            # an SSE error event if it is missing or not flagged as video.
            model_id = "grok-imagine-1.0-video"
            model_info = ModelService.get(model_id)
            if not model_info or not model_info.is_video:
                payload = {
                    "error": "Video model is not available.",
                    "code": "model_not_supported",
                }
                yield f"data: {orjson.dumps(payload).decode()}\n\n"
                yield "data: [DONE]\n\n"
                return

            # Image-to-video requests use the multimodal content-list shape;
            # text-only requests send a plain string content.
            if image_url:
                messages: List[Dict[str, Any]] = [
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": prompt},
                            {"type": "image_url", "image_url": {"url": image_url}},
                        ],
                    }
                ]
            else:
                messages = [{"role": "user", "content": prompt}]

            stream = await VideoService.completions(
                model_id,
                messages,
                stream=True,
                reasoning_effort=reasoning_effort,
                aspect_ratio=aspect_ratio,
                video_length=video_length,
                resolution=resolution_name,
                preset=preset,
            )

            # Relay upstream SSE chunks verbatim; stop early if the client
            # disconnects.
            async for chunk in stream:
                if await request.is_disconnected():
                    break
                yield chunk
        except Exception as e:
            # Surface any failure as a terminal SSE error event.
            logger.warning(f"Public video SSE error: {e}")
            payload = {"error": str(e), "code": "internal_error"}
            yield f"data: {orjson.dumps(payload).decode()}\n\n"
            yield "data: [DONE]\n\n"
        finally:
            # Sessions are single-use: always clean up, even on error.
            await _drop_session(task_id)

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
    )
262
+
263
+
264
class VideoStopRequest(BaseModel):
    """Payload listing video session ids to tear down."""

    task_ids: List[str]
266
+
267
+
268
@router.post("/video/stop", dependencies=[Depends(verify_public_key)])
async def public_video_stop(data: VideoStopRequest):
    """Drop the listed video sessions; reports how many were actually removed."""
    removed = await _drop_sessions(data.task_ids or [])
    return {"status": "success", "removed": removed}
272
+
273
+
274
+ __all__ = ["router"]
app/api/v1/public_api/voice.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter, Depends
2
+ from pydantic import BaseModel
3
+
4
+ from app.core.auth import verify_public_key
5
+ from app.core.exceptions import AppException
6
+ from app.services.grok.services.voice import VoiceService
7
+ from app.services.token.manager import get_token_manager
8
+
9
+ router = APIRouter()
10
+
11
+
12
class VoiceTokenResponse(BaseModel):
    """LiveKit credentials returned by the /voice/token endpoint."""

    # Upstream LiveKit access token.
    token: str
    # LiveKit websocket endpoint the client should connect to.
    url: str
    participant_name: str = ""
    room_name: str = ""
17
+
18
+
19
@router.get(
    "/voice/token",
    dependencies=[Depends(verify_public_key)],
    response_model=VoiceTokenResponse,
)
async def public_voice_token(
    voice: str = "ara",
    personality: str = "assistant",
    speed: float = 1.0,
):
    """Fetch a Grok Voice Mode (LiveKit) token.

    Picks the first available SSO token from the basic/super pools, asks the
    upstream voice service for a LiveKit token, and returns it together with
    the fixed LiveKit endpoint URL.

    Raises:
        AppException: 503 when no SSO token is available, 502 when upstream
            returns no token, 500 for any other upstream failure.
    """
    token_mgr = await get_token_manager()
    # Prefer the basic pool; fall back to the super pool.
    sso_token = None
    for pool_name in ("ssoBasic", "ssoSuper"):
        sso_token = token_mgr.get_token(pool_name)
        if sso_token:
            break

    if not sso_token:
        raise AppException(
            "No available tokens for voice mode",
            code="no_token",
            status_code=503,
        )

    service = VoiceService()
    try:
        data = await service.get_token(
            token=sso_token,
            voice=voice,
            personality=personality,
            speed=speed,
        )
        token = data.get("token")
        if not token:
            raise AppException(
                "Upstream returned no voice token",
                code="upstream_error",
                status_code=502,
            )

        return VoiceTokenResponse(
            token=token,
            url="wss://livekit.grok.com",
            participant_name="",
            room_name="",
        )

    except AppException:
        # Domain errors already carry the right status/code; pass them through.
        raise
    except Exception as e:
        # Chain the cause so upstream failures stay debuggable in logs.
        raise AppException(
            f"Voice token error: {str(e)}",
            code="voice_error",
            status_code=500,
        ) from e
75
+
76
+
77
@router.get("/verify", dependencies=[Depends(verify_public_key)])
async def public_verify_api():
    """Verify the public key; reaching this handler means auth already passed."""
    return {"status": "success"}
app/api/v1/response.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Responses API 路由 (OpenAI compatible).
3
+ """
4
+
5
+ from typing import Any, Dict, List, Optional, Union
6
+
7
+ from fastapi import APIRouter
8
+ from fastapi.responses import JSONResponse, StreamingResponse
9
+ from pydantic import BaseModel, Field
10
+
11
+ from app.core.exceptions import ValidationException
12
+ from app.services.grok.services.responses import ResponsesService
13
+
14
+
15
+ router = APIRouter(tags=["Responses"])
16
+
17
+
18
class ResponseCreateRequest(BaseModel):
    """Request body for the OpenAI-compatible /responses endpoint.

    Unknown fields are accepted (extra = "allow") so newer client options do
    not cause validation failures.
    """

    model: str = Field(..., description="Model name")
    input: Optional[Any] = Field(None, description="Input content")
    instructions: Optional[str] = Field(None, description="System instructions")
    stream: Optional[bool] = Field(False, description="Stream response")
    max_output_tokens: Optional[int] = Field(None, description="Max output tokens")
    temperature: Optional[float] = Field(None, description="Sampling temperature")
    top_p: Optional[float] = Field(None, description="Nucleus sampling")
    tools: Optional[List[Dict[str, Any]]] = Field(None, description="Tool definitions")
    tool_choice: Optional[Union[str, Dict[str, Any]]] = Field(None, description="Tool choice")
    parallel_tool_calls: Optional[bool] = Field(True, description="Allow parallel tool calls")
    reasoning: Optional[Dict[str, Any]] = Field(None, description="Reasoning options")
    metadata: Optional[Dict[str, Any]] = Field(None, description="Metadata")
    user: Optional[str] = Field(None, description="User identifier")
    store: Optional[bool] = Field(None, description="Store response")
    previous_response_id: Optional[str] = Field(None, description="Previous response id")
    truncation: Optional[str] = Field(None, description="Truncation behavior")

    class Config:
        extra = "allow"
38
+
39
+
40
@router.post("/responses")
async def create_response(request: ResponseCreateRequest):
    """Create a response via the OpenAI-compatible Responses API.

    Returns a plain JSON body, or an SSE stream when `stream` is true.
    """
    if not request.model:
        raise ValidationException(message="model is required", param="model", code="invalid_request_error")

    if request.input is None:
        raise ValidationException(message="input is required", param="input", code="invalid_request_error")

    # Accept either {"effort": ...} or {"reasoning_effort": ...}.
    reasoning_opts = request.reasoning
    reasoning_effort = None
    if isinstance(reasoning_opts, dict):
        reasoning_effort = reasoning_opts.get("effort") or reasoning_opts.get("reasoning_effort")

    result = await ResponsesService.create(
        model=request.model,
        input_value=request.input,
        instructions=request.instructions,
        stream=bool(request.stream),
        temperature=request.temperature,
        top_p=request.top_p,
        tools=request.tools,
        tool_choice=request.tool_choice,
        parallel_tool_calls=request.parallel_tool_calls,
        reasoning_effort=reasoning_effort,
        max_output_tokens=request.max_output_tokens,
        metadata=request.metadata,
        user=request.user,
        store=request.store,
        previous_response_id=request.previous_response_id,
        truncation=request.truncation,
    )

    if not request.stream:
        return JSONResponse(content=result)

    return StreamingResponse(
        result,
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
    )
79
+
80
+
81
+ __all__ = ["router"]
app/api/v1/video.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """
2
+ TODO:Video Generation API 路由
3
+ """
app/core/auth.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ API 认证模块
3
+ """
4
+
5
+ import hashlib
6
+ from typing import Optional, Iterable
7
+ from fastapi import HTTPException, status, Security
8
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
9
+
10
+ from app.core.config import get_config
11
+
12
+ DEFAULT_API_KEY = ""
13
+ DEFAULT_APP_KEY = "grok2api"
14
+ DEFAULT_PUBLIC_KEY = ""
15
+ DEFAULT_PUBLIC_ENABLED = False
16
+
17
+ # 定义 Bearer Scheme
18
+ security = HTTPBearer(
19
+ auto_error=False,
20
+ scheme_name="API Key",
21
+ description="Enter your API Key in the format: Bearer <key>",
22
+ )
23
+
24
+
25
+ def get_admin_api_key() -> str:
26
+ """
27
+ 获取后台 API Key。
28
+
29
+ 为空时表示不启用后台接口认证。
30
+ """
31
+ api_key = get_config("app.api_key", DEFAULT_API_KEY)
32
+ return api_key or ""
33
+
34
+
35
+ def _normalize_api_keys(value: Optional[object]) -> list[str]:
36
+ if not value:
37
+ return []
38
+ if isinstance(value, str):
39
+ raw = value.strip()
40
+ if not raw:
41
+ return []
42
+ return [part.strip() for part in raw.split(",") if part.strip()]
43
+ if isinstance(value, Iterable):
44
+ keys: list[str] = []
45
+ for item in value:
46
+ if not item:
47
+ continue
48
+ if isinstance(item, str):
49
+ stripped = item.strip()
50
+ if stripped:
51
+ keys.append(stripped)
52
+ return keys
53
+ return []
54
+
55
def get_app_key() -> str:
    """Return the App Key (admin console password)."""
    return get_config("app.app_key", DEFAULT_APP_KEY) or ""
61
+
62
def get_public_api_key() -> str:
    """Return the Public API Key; an empty string disables public-endpoint auth."""
    return get_config("app.public_key", DEFAULT_PUBLIC_KEY) or ""
70
+
71
def is_public_enabled() -> bool:
    """Whether the public feature entry points are switched on."""
    return bool(get_config("app.public_enabled", DEFAULT_PUBLIC_ENABLED))
76
+
77
+
78
def _hash_public_key(key: str) -> str:
    """SHA-256 of the namespaced public key; mirrors the frontend hashPublicKey."""
    message = "grok2api-public:" + key
    return hashlib.sha256(message.encode("utf-8")).hexdigest()
81
+
82
+
83
def _match_public_key(credentials: str, public_key: str) -> bool:
    """Check credentials against public_key (raw value or public-<sha256> hash)."""
    normalized = (public_key or "").strip()
    if not normalized:
        # An unset/blank key never matches anything.
        return False
    if credentials == normalized:
        return True
    # Hashed form: "public-" + sha256 of the namespaced key.
    return credentials.startswith("public-") and credentials == f"public-{_hash_public_key(normalized)}"
97
+
98
+
99
async def verify_api_key(
    auth: Optional[HTTPAuthorizationCredentials] = Security(security),
) -> Optional[str]:
    """Validate the Bearer token against the configured admin API key(s).

    When no api_key is configured in config.toml, authentication is disabled
    and None is returned.
    """
    valid_keys = _normalize_api_keys(get_admin_api_key())
    if not valid_keys:
        return None

    # Standard api_key check: any configured key is accepted.
    if auth is not None and auth.credentials in valid_keys:
        return auth.credentials

    detail = "Invalid authentication token" if auth else "Missing authentication token"
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"},
    )
128
+
129
+
130
async def verify_app_key(
    auth: Optional[HTTPAuthorizationCredentials] = Security(security),
) -> Optional[str]:
    """Validate the admin login key (app_key).

    app_key must be configured, otherwise login is rejected outright.
    """

    def _deny(detail: str) -> None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=detail,
            headers={"WWW-Authenticate": "Bearer"},
        )

    app_key = get_app_key()
    if not app_key:
        _deny("App key is not configured")

    if auth is None:
        _deny("Missing authentication token")

    if auth.credentials != app_key:
        _deny("Invalid authentication token")

    return auth.credentials
162
+
163
+
164
async def verify_public_key(
    auth: Optional[HTTPAuthorizationCredentials] = Security(security),
) -> Optional[str]:
    """Validate the Public Key used by the public endpoints.

    With no public_key configured, access is allowed only when public_enabled
    is on (returns None); otherwise a Bearer credential matching the key
    (raw value or public-<sha256> form) is required.
    """

    def _deny(detail: str) -> None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=detail,
            headers={"WWW-Authenticate": "Bearer"},
        )

    public_key = get_public_api_key()
    if not public_key:
        if is_public_enabled():
            return None
        _deny("Public access is disabled")

    if auth is None:
        _deny("Missing authentication token")

    if not _match_public_key(auth.credentials, public_key):
        _deny("Invalid authentication token")

    return auth.credentials
app/core/batch.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Batch utilities.
3
+
4
+ - run_batch: generic batch concurrency runner
5
+ - BatchTask: SSE task manager for admin batch operations
6
+ """
7
+
8
+ import asyncio
9
+ import time
10
+ import uuid
11
+ from typing import Any, Awaitable, Callable, Dict, List, Optional, TypeVar
12
+
13
+ from app.core.logger import logger
14
+
15
+ T = TypeVar("T")
16
+
17
+
18
async def run_batch(
    items: List[str],
    worker: Callable[[str], Awaitable[T]],
    *,
    batch_size: int = 50,
    task: Optional["BatchTask"] = None,
    on_item: Optional[Callable[[str, Dict[str, Any]], Awaitable[None]]] = None,
    should_cancel: Optional[Callable[[], bool]] = None,
) -> Dict[str, Dict[str, Any]]:
    """Run `worker` over `items` in fixed-size concurrent batches.

    A single failing item never aborts the run; each item's outcome is
    recorded independently.

    Args:
        items: Items to process.
        worker: Async callable applied to each item.
        batch_size: Number of items processed concurrently per batch.
        task: Optional BatchTask receiving per-item progress records.
        on_item: Optional async callback invoked with (item, result).
        should_cancel: Optional predicate polled between items/batches.

    Returns:
        {item: {"ok": bool, "data": ..., "error": ...}}
    """
    try:
        size = max(1, int(batch_size))
    except Exception:
        size = 50

    def _cancelled() -> bool:
        if should_cancel and should_cancel():
            return True
        return bool(task and task.cancelled)

    async def _notify(item: str, outcome: dict) -> None:
        # Callback errors are deliberately swallowed: progress reporting must
        # never break the batch itself.
        if on_item is None:
            return
        try:
            await on_item(item, outcome)
        except Exception:
            pass

    async def _process(item: str) -> tuple[str, dict]:
        if _cancelled():
            return item, {"ok": False, "error": "cancelled", "cancelled": True}
        try:
            outcome = {"ok": True, "data": await worker(item)}
            if task:
                task.record(True)
            await _notify(item, outcome)
            return item, outcome
        except Exception as exc:
            logger.warning(f"Batch item failed: {item[:16]}... - {exc}")
            outcome = {"ok": False, "error": str(exc)}
            if task:
                task.record(False, error=str(exc))
            await _notify(item, outcome)
            return item, outcome

    results: Dict[str, dict] = {}

    # Process in batches so we never create all tasks at once.
    for start in range(0, len(items), size):
        if _cancelled():
            break
        chunk = items[start : start + size]
        for key, outcome in await asyncio.gather(*(_process(x) for x in chunk)):
            results[key] = outcome

    return results
82
+
83
+
84
class BatchTask:
    """Progress tracker for one admin batch operation, fanned out over SSE.

    Lifecycle: created as "running"; ends as "done", "error", or "cancelled".
    Subscribers attach a bounded queue and receive JSON-serializable events.
    NOTE(review): counters and queues are mutated without locks — presumably
    all access happens on a single asyncio event loop; confirm with callers.
    """

    def __init__(self, total: int):
        # Opaque id clients use to poll / subscribe.
        self.id = uuid.uuid4().hex
        self.total = int(total)
        # Running counters updated by record().
        self.processed = 0
        self.ok = 0
        self.fail = 0
        # One of: "running", "done", "error", "cancelled".
        self.status = "running"
        self.warning: Optional[str] = None
        self.result: Optional[Dict[str, Any]] = None
        self.error: Optional[str] = None
        self.created_at = time.time()
        # One bounded queue per attached SSE subscriber.
        self._queues: List[asyncio.Queue] = []
        # Terminal event, kept so late subscribers can replay it via final_event().
        self._final_event: Optional[Dict[str, Any]] = None
        self.cancelled = False

    def snapshot(self) -> Dict[str, Any]:
        """Return the current progress state as a JSON-serializable dict."""
        return {
            "task_id": self.id,
            "status": self.status,
            "total": self.total,
            "processed": self.processed,
            "ok": self.ok,
            "fail": self.fail,
            "warning": self.warning,
        }

    def attach(self) -> asyncio.Queue:
        """Register and return a new subscriber queue (max 200 pending events)."""
        q: asyncio.Queue = asyncio.Queue(maxsize=200)
        self._queues.append(q)
        return q

    def detach(self, q: asyncio.Queue) -> None:
        """Unregister a subscriber queue previously returned by attach()."""
        if q in self._queues:
            self._queues.remove(q)

    def _publish(self, event: Dict[str, Any]) -> None:
        """Best-effort fan-out of an event to every subscriber queue."""
        for q in list(self._queues):
            try:
                q.put_nowait(event)
            except Exception:
                # Drop if queue is full or closed
                pass

    def record(
        self, ok: bool, *, item: Any = None, detail: Any = None, error: str = ""
    ) -> None:
        """Record one item's outcome and publish a "progress" event."""
        self.processed += 1
        if ok:
            self.ok += 1
        else:
            self.fail += 1
        event: Dict[str, Any] = {
            "type": "progress",
            "task_id": self.id,
            "total": self.total,
            "processed": self.processed,
            "ok": self.ok,
            "fail": self.fail,
        }
        # Optional extras are only included when provided.
        if item is not None:
            event["item"] = item
        if detail is not None:
            event["detail"] = detail
        if error:
            event["error"] = error
        self._publish(event)

    def finish(self, result: Dict[str, Any], *, warning: Optional[str] = None) -> None:
        """Mark the task done and publish the terminal "done" event."""
        self.status = "done"
        self.result = result
        self.warning = warning
        event = {
            "type": "done",
            "task_id": self.id,
            "total": self.total,
            "processed": self.processed,
            "ok": self.ok,
            "fail": self.fail,
            "warning": self.warning,
            "result": result,
        }
        self._final_event = event
        self._publish(event)

    def fail_task(self, error: str) -> None:
        """Mark the whole task as failed and publish the terminal "error" event."""
        self.status = "error"
        self.error = error
        event = {
            "type": "error",
            "task_id": self.id,
            "total": self.total,
            "processed": self.processed,
            "ok": self.ok,
            "fail": self.fail,
            "error": error,
        }
        self._final_event = event
        self._publish(event)

    def cancel(self) -> None:
        """Request cancellation; workers poll the `cancelled` flag."""
        self.cancelled = True

    def finish_cancelled(self) -> None:
        """Mark the task cancelled and publish the terminal "cancelled" event."""
        self.status = "cancelled"
        event = {
            "type": "cancelled",
            "task_id": self.id,
            "total": self.total,
            "processed": self.processed,
            "ok": self.ok,
            "fail": self.fail,
        }
        self._final_event = event
        self._publish(event)

    def final_event(self) -> Optional[Dict[str, Any]]:
        """Return the terminal event, or None while the task is still running."""
        return self._final_event
202
+
203
+
204
+ _TASKS: Dict[str, BatchTask] = {}
205
+
206
+
207
def create_task(total: int) -> BatchTask:
    """Create a BatchTask and register it in the in-memory registry."""
    new_task = BatchTask(total)
    _TASKS[new_task.id] = new_task
    return new_task
211
+
212
+
213
def get_task(task_id: str) -> Optional[BatchTask]:
    """Look up a registered task by id (None when unknown)."""
    return _TASKS.get(task_id)
215
+
216
+
217
def delete_task(task_id: str) -> None:
    """Remove a task from the registry, silently ignoring unknown ids."""
    _TASKS.pop(task_id, None)
219
+
220
+
221
async def expire_task(task_id: str, delay: int = 300) -> None:
    """Delete a task after `delay` seconds (fire-and-forget cleanup)."""
    await asyncio.sleep(delay)
    delete_task(task_id)
224
+
225
+
226
+ __all__ = [
227
+ "run_batch",
228
+ "BatchTask",
229
+ "create_task",
230
+ "get_task",
231
+ "delete_task",
232
+ "expire_task",
233
+ ]
app/core/config.py ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 配置管理
3
+
4
+ - config.toml: 运行时配置
5
+ - config.defaults.toml: 默认配置基线
6
+ """
7
+
8
+ from copy import deepcopy
9
+ from pathlib import Path
10
+ from typing import Any, Dict
11
+ import tomllib
12
+
13
+ from app.core.logger import logger
14
+
15
+ DEFAULT_CONFIG_FILE = Path(__file__).parent.parent.parent / "config.defaults.toml"
16
+
17
+
18
def _deep_merge(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
    """Recursively merge two mappings, with `override` taking precedence.

    Non-dict inputs degrade gracefully: a non-dict base is returned (copied),
    unless the override is a dict, in which case a copy of the override wins.
    """
    if not isinstance(base, dict):
        if isinstance(override, dict):
            return deepcopy(override)
        return deepcopy(base)

    merged = deepcopy(base)
    if not isinstance(override, dict):
        return merged

    for key, incoming in override.items():
        current = merged.get(key)
        if isinstance(incoming, dict) and isinstance(current, dict):
            merged[key] = _deep_merge(current, incoming)
        else:
            merged[key] = incoming
    return merged
33
+
34
+
35
def _migrate_deprecated_config(
    config: Dict[str, Any], valid_sections: set
) -> tuple[Dict[str, Any], set]:
    """Migrate deprecated/legacy config keys into the current layout.

    Keys listed in MIGRATION_MAP are copied to their new location(s) unless
    the new key is already set (explicit values always win), and the old key
    is removed from retained sections. Whole sections not in `valid_sections`
    are dropped from the result.

    Args:
        config: Raw parsed config.
        valid_sections: Section names that are still part of the schema.

    Returns:
        (migrated config, set of deprecated section names found)
    """
    # old dotted path -> new dotted path (or list of paths when one legacy key
    # fans out into several new keys). Keys must be unique: duplicate literals
    # in a dict silently override each other (the original listed several
    # network.*/security.* entries twice).
    MIGRATION_MAP = {
        # grok.* -> new sections
        "grok.temporary": "app.temporary",
        "grok.disable_memory": "app.disable_memory",
        "grok.stream": "app.stream",
        "grok.thinking": "app.thinking",
        "grok.dynamic_statsig": "app.dynamic_statsig",
        "grok.filter_tags": "app.filter_tags",
        "grok.timeout": "voice.timeout",
        "grok.base_proxy_url": "proxy.base_proxy_url",
        "grok.asset_proxy_url": "proxy.asset_proxy_url",
        "grok.cf_clearance": "proxy.cf_clearance",
        "grok.browser": "proxy.browser",
        "grok.user_agent": "proxy.user_agent",
        "grok.max_retry": "retry.max_retry",
        "grok.retry_status_codes": "retry.retry_status_codes",
        "grok.retry_backoff_base": "retry.retry_backoff_base",
        "grok.retry_backoff_factor": "retry.retry_backoff_factor",
        "grok.retry_backoff_max": "retry.retry_backoff_max",
        "grok.retry_budget": "retry.retry_budget",
        "grok.video_idle_timeout": "video.stream_timeout",
        "grok.image_ws_nsfw": "image.nsfw",
        "grok.image_ws_blocked_seconds": "image.final_timeout",
        "grok.image_ws_final_min_bytes": "image.final_min_bytes",
        "grok.image_ws_medium_min_bytes": "image.medium_min_bytes",
        # legacy network.* section
        "network.base_proxy_url": "proxy.base_proxy_url",
        "network.asset_proxy_url": "proxy.asset_proxy_url",
        "network.timeout": [
            "chat.timeout",
            "image.timeout",
            "video.timeout",
            "voice.timeout",
        ],
        # legacy security.* section
        "security.cf_clearance": "proxy.cf_clearance",
        "security.browser": "proxy.browser",
        "security.user_agent": "proxy.user_agent",
        # legacy timeout.* section
        "timeout.stream_idle_timeout": [
            "chat.stream_timeout",
            "image.stream_timeout",
            "video.stream_timeout",
        ],
        "timeout.video_idle_timeout": "video.stream_timeout",
        # old image.image_ws_* spellings
        "image.image_ws_nsfw": "image.nsfw",
        "image.image_ws_blocked_seconds": "image.final_timeout",
        "image.image_ws_final_min_bytes": "image.final_min_bytes",
        "image.image_ws_medium_min_bytes": "image.medium_min_bytes",
        # legacy performance.* section
        "performance.assets_max_concurrent": [
            "asset.upload_concurrent",
            "asset.download_concurrent",
            "asset.list_concurrent",
            "asset.delete_concurrent",
        ],
        "performance.assets_delete_batch_size": "asset.delete_batch_size",
        "performance.assets_batch_size": "asset.list_batch_size",
        "performance.media_max_concurrent": ["chat.concurrent", "video.concurrent"],
        "performance.usage_max_concurrent": "usage.concurrent",
        "performance.usage_batch_size": "usage.batch_size",
        "performance.nsfw_max_concurrent": "nsfw.concurrent",
        "performance.nsfw_batch_size": "nsfw.batch_size",
    }

    deprecated_sections = set(config.keys()) - valid_sections
    if not deprecated_sections:
        # Nothing legacy present; return the config untouched.
        return config, set()

    # Keep only schema sections; mapped legacy values are copied over below.
    result = {k: deepcopy(v) for k, v in config.items() if k in valid_sections}
    migrated_count = 0

    # Copy each mapped legacy key into its new home (existing values win).
    for old_section, old_values in config.items():
        if not isinstance(old_values, dict):
            continue
        for old_key, old_value in old_values.items():
            old_path = f"{old_section}.{old_key}"
            new_paths = MIGRATION_MAP.get(old_path)
            if not new_paths:
                continue
            if isinstance(new_paths, str):
                new_paths = [new_paths]
            for new_path in new_paths:
                try:
                    new_section, new_key = new_path.split(".", 1)
                    if new_section not in result:
                        result[new_section] = {}
                    if new_key not in result[new_section]:
                        result[new_section][new_key] = old_value
                        migrated_count += 1
                        logger.debug(
                            f"Migrated config: {old_path} -> {new_path} = {old_value}"
                        )
                except Exception as e:
                    logger.warning(
                        f"Skip config migration for {old_path}: {e}"
                    )
                    continue
            # Drop the legacy key from retained sections so it doesn't linger.
            if isinstance(result.get(old_section), dict):
                result[old_section].pop(old_key, None)

    # Legacy chat.* toggles now live under app.*.
    legacy_chat_map = {
        "temporary": "temporary",
        "disable_memory": "disable_memory",
        "stream": "stream",
        "thinking": "thinking",
        "dynamic_statsig": "dynamic_statsig",
        "filter_tags": "filter_tags",
    }
    chat_section = config.get("chat")
    if isinstance(chat_section, dict):
        app_section = result.setdefault("app", {})
        for old_key, new_key in legacy_chat_map.items():
            if old_key in chat_section and new_key not in app_section:
                app_section[new_key] = chat_section[old_key]
                if isinstance(result.get("chat"), dict):
                    result["chat"].pop(old_key, None)
                migrated_count += 1
                logger.debug(
                    f"Migrated config: chat.{old_key} -> app.{new_key} = {chat_section[old_key]}"
                )

    if migrated_count > 0:
        logger.info(
            f"Migrated {migrated_count} config items from deprecated/legacy sections"
        )

    return result, deprecated_sections
177
+
178
+
179
def _load_defaults() -> Dict[str, Any]:
    """Load the bundled default configuration file.

    Returns an empty dict when the file is missing or cannot be parsed.
    """
    if DEFAULT_CONFIG_FILE.exists():
        try:
            with DEFAULT_CONFIG_FILE.open("rb") as fh:
                return tomllib.load(fh)
        except Exception as e:
            logger.warning(f"Failed to load defaults from {DEFAULT_CONFIG_FILE}: {e}")
    return {}
189
+
190
+
191
class Config:
    """Configuration manager.

    Merges code-registered defaults, file defaults and backend-stored
    configuration, migrating deprecated sections on load.
    """

    # NOTE(review): looks like leftover singleton scaffolding; the module
    # exposes a single shared instance (``config``) below — confirm.
    _instance = None
    _config = {}

    def __init__(self):
        self._config = {}  # effective (merged) configuration
        self._defaults = {}  # file defaults merged over code defaults
        self._code_defaults = {}  # defaults registered from code
        self._defaults_loaded = False

    def register_defaults(self, defaults: Dict[str, Any]):
        """Register default values declared in code (lowest precedence)."""
        self._code_defaults = _deep_merge(self._code_defaults, defaults)

    def _ensure_defaults(self):
        # Lazily build the combined defaults once.
        if self._defaults_loaded:
            return
        file_defaults = _load_defaults()
        # Merge file defaults over code defaults (code defaults lose).
        self._defaults = _deep_merge(self._code_defaults, file_defaults)
        self._defaults_loaded = True

    async def load(self):
        """Explicitly load configuration from the storage backend.

        Falls back to the local data/config.toml when the backend is empty,
        migrates deprecated sections, and writes the merged result back when
        it is safe to do so. On any error the runtime config is reset to {}.
        """
        try:
            from app.core.storage import get_storage, LocalStorage

            self._ensure_defaults()

            storage = get_storage()
            config_data = await storage.load_config()
            from_remote = True

            # Seed the backend from local data/config.toml when it is empty.
            if config_data is None:
                local_storage = LocalStorage()
                from_remote = False
                try:
                    # Try reading the local config file.
                    config_data = await local_storage.load_config()
                except Exception as e:
                    logger.info(f"Failed to auto-init config from local: {e}")
                    config_data = {}

            config_data = config_data or {}

            # Detect and migrate deprecated config sections.
            valid_sections = set(self._defaults.keys())
            config_data, deprecated_sections = _migrate_deprecated_config(
                config_data, valid_sections
            )
            if deprecated_sections:
                logger.info(
                    f"Cleaned deprecated config sections: {deprecated_sections}"
                )

            merged = _deep_merge(self._defaults, config_data)

            # Backfill missing config into storage, or persist after a
            # migration. Guard: when the remote returned None and there is
            # no local seed either, do NOT overwrite the remote config to
            # avoid accidentally resetting it.
            has_local_seed = bool(config_data)
            allow_bootstrap_empty_remote = (
                (not from_remote) and has_local_seed
            )
            should_persist = (
                allow_bootstrap_empty_remote
                or (merged != config_data and bool(config_data))
                or deprecated_sections
            )
            if should_persist:
                async with storage.acquire_lock("config_save", timeout=10):
                    await storage.save_config(merged)
                if not from_remote and has_local_seed:
                    logger.info(
                        f"Initialized remote storage ({storage.__class__.__name__}) with config baseline."
                    )
                if deprecated_sections:
                    logger.info("Configuration automatically migrated and cleaned.")
            elif not from_remote and not has_local_seed:
                logger.warning(
                    "Skip persisting defaults: empty config source detected, keep runtime merged config only."
                )

            self._config = merged
        except Exception as e:
            logger.error(f"Error loading config: {e}")
            self._config = {}

    def get(self, key: str, default: Any = None) -> Any:
        """
        Get a configuration value.

        Args:
            key: Config key, either "section.key" (dotted) or a bare
                top-level key.
            default: Value returned when the key is absent.
        """
        if "." in key:
            try:
                section, attr = key.split(".", 1)
                return self._config.get(section, {}).get(attr, default)
            except (ValueError, AttributeError):
                return default

        return self._config.get(key, default)

    async def update(self, new_config: dict):
        """Merge *new_config* over the current state and persist under a lock."""
        from app.core.storage import get_storage

        storage = get_storage()
        async with storage.acquire_lock("config_save", timeout=10):
            self._ensure_defaults()
            base = _deep_merge(self._defaults, self._config or {})
            merged = _deep_merge(base, new_config or {})
            await storage.save_config(merged)
            self._config = merged
310
+
311
+
312
# Global configuration instance shared across the application.
config = Config()
314
+
315
+
316
def get_config(key: str, default: Any = None) -> Any:
    """Convenience wrapper: read a value from the global ``config``."""
    return config.get(key, default)
319
+
320
+
321
def register_defaults(defaults: Dict[str, Any]):
    """Convenience wrapper: register code defaults on the global ``config``."""
    config.register_defaults(defaults)


__all__ = ["Config", "config", "get_config", "register_defaults"]
app/core/exceptions.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 全局异常处理 - OpenAI 兼容错误格式
3
+ """
4
+
5
+ from typing import Any
6
+ from enum import Enum
7
+ from fastapi import Request, HTTPException
8
+ from fastapi.responses import JSONResponse
9
+ from fastapi.exceptions import RequestValidationError
10
+
11
+ from app.core.logger import logger
12
+
13
+
14
+ # ============= 错误类型 =============
15
+
16
+
17
class ErrorType(str, Enum):
    """OpenAI-compatible error type identifiers.

    Values mirror the ``error.type`` strings of the OpenAI error format.
    """

    INVALID_REQUEST = "invalid_request_error"
    AUTHENTICATION = "authentication_error"
    PERMISSION = "permission_error"
    NOT_FOUND = "not_found_error"
    RATE_LIMIT = "rate_limit_error"
    SERVER = "server_error"
    SERVICE_UNAVAILABLE = "service_unavailable_error"
27
+
28
+
29
+ # ============= 辅助函数 =============
30
+
31
+
32
def error_response(
    message: str,
    error_type: str = ErrorType.INVALID_REQUEST.value,
    param: str | None = None,
    code: str | None = None,
) -> dict:
    """Build an OpenAI-compatible error payload.

    Fix: ``param``/``code`` were annotated ``str`` with a ``None`` default
    (implicit Optional, discouraged by PEP 484); now explicitly optional.

    Args:
        message: Human-readable error description.
        error_type: One of the ``ErrorType`` values.
        param: Offending request parameter, if known.
        code: Machine-readable error code, if any.

    Returns:
        ``{"error": {...}}`` dict matching the OpenAI error envelope.
    """
    return {
        "error": {"message": message, "type": error_type, "param": param, "code": code}
    }
42
+
43
+
44
+ # ============= 异常类 =============
45
+
46
+
47
class AppException(Exception):
    """Base application exception carrying OpenAI-style error metadata."""

    def __init__(
        self,
        message: str,
        error_type: str = ErrorType.SERVER.value,
        code: str = None,
        param: str = None,
        status_code: int = 500,
    ):
        super().__init__(message)
        self.message = message
        self.error_type = error_type
        self.code = code
        self.param = param
        self.status_code = status_code
64
+
65
+
66
class ValidationException(AppException):
    """Raised for request validation failures (HTTP 400)."""

    def __init__(self, message: str, param: str = None, code: str = None):
        resolved_code = code or "invalid_value"
        super().__init__(
            message=message,
            error_type=ErrorType.INVALID_REQUEST.value,
            code=resolved_code,
            param=param,
            status_code=400,
        )
77
+
78
+
79
class AuthenticationException(AppException):
    """Raised when API-key authentication fails (HTTP 401)."""

    def __init__(self, message: str = "Invalid API key"):
        super().__init__(
            message=message,
            status_code=401,
            error_type=ErrorType.AUTHENTICATION.value,
            code="invalid_api_key",
        )
89
+
90
+
91
class UpstreamException(AppException):
    """Raised when an upstream service call fails (HTTP 502)."""

    def __init__(self, message: str, details: Any = None):
        self.details = details
        super().__init__(
            message=message,
            error_type=ErrorType.SERVER.value,
            code="upstream_error",
            status_code=502,
        )
102
+
103
+
104
class StreamIdleTimeoutError(Exception):
    """Raised when a streaming response produces no data for too long."""

    def __init__(self, idle_seconds: float):
        super().__init__(f"Stream idle timeout after {idle_seconds}s")
        self.idle_seconds = idle_seconds
110
+
111
+
112
+ # ============= 异常处理器 =============
113
+
114
+
115
async def app_exception_handler(request: Request, exc: AppException) -> JSONResponse:
    """Translate an ``AppException`` into an OpenAI-style JSON error."""
    logger.warning(f"AppException: {exc.error_type} - {exc.message}")

    payload = error_response(
        message=exc.message,
        error_type=exc.error_type,
        param=exc.param,
        code=exc.code,
    )
    return JSONResponse(status_code=exc.status_code, content=payload)
128
+
129
+
130
async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
    """Translate a FastAPI ``HTTPException`` into an OpenAI-style error.

    Maps the HTTP status code to the closest OpenAI error type and, where
    one exists, a conventional machine-readable error code.
    """
    type_map = {
        400: ErrorType.INVALID_REQUEST.value,
        401: ErrorType.AUTHENTICATION.value,
        403: ErrorType.PERMISSION.value,
        404: ErrorType.NOT_FOUND.value,
        429: ErrorType.RATE_LIMIT.value,
        # Consistency fix: ErrorType.SERVICE_UNAVAILABLE was defined but
        # never used — 503 previously fell through to server_error.
        503: ErrorType.SERVICE_UNAVAILABLE.value,
    }
    error_type = type_map.get(exc.status_code, ErrorType.SERVER.value)

    # Default machine-readable codes per status.
    code_map = {
        401: "invalid_api_key",
        403: "insufficient_quota",
        404: "model_not_found",
        429: "rate_limit_exceeded",
    }
    code = code_map.get(exc.status_code, None)

    logger.warning(f"HTTPException: {exc.status_code} - {exc.detail}")

    return JSONResponse(
        status_code=exc.status_code,
        content=error_response(
            message=str(exc.detail), error_type=error_type, code=code
        ),
    )
158
+
159
+
160
async def validation_exception_handler(
    request: Request, exc: RequestValidationError
) -> JSONResponse:
    """Translate request-validation errors into the OpenAI error format.

    Only the first reported error is surfaced to the client.
    """
    errors = exc.errors()

    param, message, code = None, "Invalid request", "invalid_value"
    if errors:
        first = errors[0]
        loc = first.get("loc", [])
        msg = first.get("msg", "Invalid request")
        code = first.get("type", "invalid_value")

        if code == "json_invalid" or "JSON" in msg:
            # Malformed request body: point the caller at the body itself.
            message = "Invalid JSON in request body. Please check for trailing commas or syntax errors."
            param = "body"
        else:
            # Drop numeric indices from the location to form a dotted path.
            parts = [
                str(x) for x in loc if not (isinstance(x, int) or str(x).isdigit())
            ]
            param = ".".join(parts) if parts else None
            message = msg

    logger.warning(f"ValidationError: {param} - {message}")

    return JSONResponse(
        status_code=400,
        content=error_response(
            message=message,
            error_type=ErrorType.INVALID_REQUEST.value,
            param=param,
            code=code,
        ),
    )
196
+
197
+
198
async def generic_exception_handler(request: Request, exc: Exception) -> JSONResponse:
    """Last-resort handler: log the traceback, return a generic 500."""
    logger.exception(f"Unhandled: {type(exc).__name__}: {str(exc)}")

    body = error_response(
        message="Internal server error",
        error_type=ErrorType.SERVER.value,
        code="internal_error",
    )
    return JSONResponse(status_code=500, content=body)
210
+
211
+
212
+ # ============= 注册 =============
213
+
214
+
215
def register_exception_handlers(app):
    """Attach all OpenAI-compatible exception handlers to a FastAPI app."""
    # Registration order matches handler specificity.
    handlers = (
        (AppException, app_exception_handler),
        (HTTPException, http_exception_handler),
        (RequestValidationError, validation_exception_handler),
        (Exception, generic_exception_handler),
    )
    for exc_type, handler in handlers:
        app.add_exception_handler(exc_type, handler)


__all__ = [
    "ErrorType",
    "AppException",
    "ValidationException",
    "AuthenticationException",
    "UpstreamException",
    "StreamIdleTimeoutError",
    "error_response",
    "register_exception_handlers",
]
app/core/logger.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 结构化 JSON 日志 - 极简格式
3
+ """
4
+
5
+ import sys
6
+ import os
7
+ import json
8
+ import traceback
9
+ from pathlib import Path
10
+ from loguru import logger
11
+
12
+ # Provide logging.Logger compatibility for legacy calls
13
+ if not hasattr(logger, "isEnabledFor"):
14
+ logger.isEnabledFor = lambda _level: True
15
+
16
+ # 日志目录
17
+ DEFAULT_LOG_DIR = Path(__file__).parent.parent.parent / "logs"
18
+ LOG_DIR = Path(os.getenv("LOG_DIR", str(DEFAULT_LOG_DIR)))
19
+ _LOG_DIR_READY = False
20
+
21
+
22
+ def _prepare_log_dir() -> bool:
23
+ """确保日志目录可用"""
24
+ global LOG_DIR, _LOG_DIR_READY
25
+ if _LOG_DIR_READY:
26
+ return True
27
+ try:
28
+ LOG_DIR.mkdir(parents=True, exist_ok=True)
29
+ _LOG_DIR_READY = True
30
+ return True
31
+ except Exception:
32
+ _LOG_DIR_READY = False
33
+ return False
34
+
35
+
36
+ def _format_json(record) -> str:
37
+ """格式化日志"""
38
+ # ISO8601 时间
39
+ time_str = record["time"].strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
40
+ tz = record["time"].strftime("%z")
41
+ if tz:
42
+ time_str += tz[:3] + ":" + tz[3:]
43
+
44
+ log_entry = {
45
+ "time": time_str,
46
+ "level": record["level"].name.lower(),
47
+ "msg": record["message"],
48
+ "caller": f"{record['file'].name}:{record['line']}",
49
+ }
50
+
51
+ # trace 上下文
52
+ extra = record["extra"]
53
+ if extra.get("traceID"):
54
+ log_entry["traceID"] = extra["traceID"]
55
+ if extra.get("spanID"):
56
+ log_entry["spanID"] = extra["spanID"]
57
+
58
+ # 其他 extra 字段
59
+ for key, value in extra.items():
60
+ if key not in ("traceID", "spanID") and not key.startswith("_"):
61
+ log_entry[key] = value
62
+
63
+ # 错误及以上级别添加堆栈跟踪
64
+ if record["level"].no >= 40 and record["exception"]:
65
+ log_entry["stacktrace"] = "".join(
66
+ traceback.format_exception(
67
+ record["exception"].type,
68
+ record["exception"].value,
69
+ record["exception"].traceback,
70
+ )
71
+ )
72
+
73
+ return json.dumps(log_entry, ensure_ascii=False)
74
+
75
+ def _env_flag(name: str, default: bool) -> bool:
76
+ raw = os.getenv(name)
77
+ if raw is None:
78
+ return default
79
+ return raw.strip().lower() in ("1", "true", "yes", "on", "y")
80
+
81
+
82
+ def _make_json_sink(output):
83
+ """创建 JSON sink"""
84
+
85
+ def sink(message):
86
+ json_str = _format_json(message.record)
87
+ print(json_str, file=output, flush=True)
88
+
89
+ return sink
90
+
91
+
92
+ def _file_json_sink(message):
93
+ """写入日志文件"""
94
+ record = message.record
95
+ json_str = _format_json(record)
96
+ log_file = LOG_DIR / f"app_{record['time'].strftime('%Y-%m-%d')}.log"
97
+ with open(log_file, "a", encoding="utf-8") as f:
98
+ f.write(json_str + "\n")
99
+
100
+
101
+ def setup_logging(
102
+ level: str = "DEBUG",
103
+ json_console: bool = True,
104
+ file_logging: bool = True,
105
+ ):
106
+ """设置日志配置"""
107
+ logger.remove()
108
+ file_logging = _env_flag("LOG_FILE_ENABLED", file_logging)
109
+
110
+ # 控制台输出
111
+ if json_console:
112
+ logger.add(
113
+ _make_json_sink(sys.stdout),
114
+ level=level,
115
+ format="{message}",
116
+ colorize=False,
117
+ )
118
+ else:
119
+ logger.add(
120
+ sys.stdout,
121
+ level=level,
122
+ format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{file.name}:{line}</cyan> - <level>{message}</level>",
123
+ colorize=True,
124
+ )
125
+
126
+ # 文件输出
127
+ if file_logging:
128
+ if _prepare_log_dir():
129
+ logger.add(
130
+ _file_json_sink,
131
+ level=level,
132
+ format="{message}",
133
+ enqueue=True,
134
+ )
135
+ else:
136
+ logger.warning("File logging disabled: no writable log directory.")
137
+
138
+ return logger
139
+
140
+
141
+ def get_logger(trace_id: str = "", span_id: str = ""):
142
+ """获取绑定了 trace 上下文的 logger"""
143
+ bound = {}
144
+ if trace_id:
145
+ bound["traceID"] = trace_id
146
+ if span_id:
147
+ bound["spanID"] = span_id
148
+ return logger.bind(**bound) if bound else logger
149
+
150
+
151
+ __all__ = ["logger", "setup_logging", "get_logger", "LOG_DIR"]
app/core/response_middleware.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 响应中间件
3
+ Response Middleware
4
+
5
+ 用于记录请求日志、生成 TraceID 和计算请求耗时
6
+ """
7
+
8
+ import time
9
+ import uuid
10
+ from starlette.middleware.base import BaseHTTPMiddleware
11
+ from starlette.requests import Request
12
+
13
+ from app.core.logger import logger
14
+
15
+
16
class ResponseLoggerMiddleware(BaseHTTPMiddleware):
    """
    Request logging and response tracking middleware.

    Assigns a UUID trace ID to every request (exposed as
    ``request.state.trace_id``), logs request/response lines with timing,
    and skips logging for static assets and known HTML page routes.
    """

    # Exact paths that bypass logging; /static/* is also skipped.
    # Hoisted to a class-level frozenset (was a tuple rebuilt per request).
    _SKIP_PATHS = frozenset(
        (
            "/",
            "/login",
            "/imagine",
            "/voice",
            "/admin",
            "/admin/login",
            "/admin/config",
            "/admin/cache",
            "/admin/token",
        )
    )

    async def dispatch(self, request: Request, call_next):
        # Generate a per-request trace ID for downstream handlers/log lines.
        trace_id = str(uuid.uuid4())
        request.state.trace_id = trace_id

        start_time = time.time()
        path = request.url.path

        # Skip noisy logging for static assets and page routes.
        if path.startswith("/static/") or path in self._SKIP_PATHS:
            return await call_next(request)

        # Log the inbound request.
        logger.info(
            f"Request: {request.method} {request.url.path}",
            extra={
                "traceID": trace_id,
                "method": request.method,
                "path": request.url.path,
            },
        )

        try:
            response = await call_next(request)
        except Exception as e:
            duration = (time.time() - start_time) * 1000
            logger.error(
                f"Response Error: {request.method} {request.url.path} - {str(e)} ({duration:.2f}ms)",
                extra={
                    "traceID": trace_id,
                    "method": request.method,
                    "path": request.url.path,
                    "duration_ms": round(duration, 2),
                    "error": str(e),
                },
            )
            # Fix: bare `raise` (was `raise e`) re-raises without adding an
            # extra re-raise frame to the traceback.
            raise

        # Log the response with elapsed time.
        duration = (time.time() - start_time) * 1000
        logger.info(
            f"Response: {request.method} {request.url.path} - {response.status_code} ({duration:.2f}ms)",
            extra={
                "traceID": trace_id,
                "method": request.method,
                "path": request.url.path,
                "status": response.status_code,
                "duration_ms": round(duration, 2),
            },
        )
        return response
app/core/storage.py ADDED
@@ -0,0 +1,1478 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 统一存储服务 (Professional Storage Service)
3
+ 支持 Local (TOML), Redis, MySQL, PostgreSQL
4
+
5
+ 特性:
6
+ - 全异步 I/O (Async I/O)
7
+ - 连接池管理 (Connection Pooling)
8
+ - 分布式/本地锁 (Distributed/Local Locking)
9
+ - 内存优化 (序列化性能优化)
10
+ """
11
+
12
+ import abc
13
+ import os
14
+ import asyncio
15
+ import hashlib
16
+ import time
17
+ import tomllib
18
+ from typing import Any, ClassVar, Dict, Optional
19
+ from pathlib import Path
20
+ from enum import Enum
21
+
22
+ try:
23
+ import fcntl
24
+ except ImportError: # pragma: no cover - non-posix platforms
25
+ fcntl = None
26
+ from contextlib import asynccontextmanager
27
+
28
+ import orjson
29
+ import aiofiles
30
+ from app.core.logger import logger
31
+
32
# Data directory (overridable via the DATA_DIR environment variable).
DEFAULT_DATA_DIR = Path(__file__).parent.parent.parent / "data"
DATA_DIR = Path(os.getenv("DATA_DIR", str(DEFAULT_DATA_DIR))).expanduser()

# Well-known file paths inside the data directory.
CONFIG_FILE = DATA_DIR / "config.toml"
TOKEN_FILE = DATA_DIR / "token.json"
LOCK_DIR = DATA_DIR / ".locks"
40
+
41
+
42
# JSON serialization helpers (orjson-backed for performance).
def json_dumps(obj: Any) -> str:
    """Serialize *obj* to a compact JSON string."""
    return orjson.dumps(obj).decode("utf-8")


def json_loads(obj: str | bytes) -> Any:
    """Parse JSON from str or bytes."""
    return orjson.loads(obj)


def json_dumps_sorted(obj: Any) -> str:
    """Serialize *obj* to JSON with deterministic (sorted) key order."""
    return orjson.dumps(obj, option=orjson.OPT_SORT_KEYS).decode("utf-8")
53
+
54
+
55
class StorageError(Exception):
    """Base exception for storage-backend failures."""
59
+
60
+
61
class BaseStorage(abc.ABC):
    """Abstract storage backend: config and token persistence plus locking."""

    @abc.abstractmethod
    async def load_config(self) -> Dict[str, Any]:
        """Load configuration (backends may return None when empty)."""
        pass

    @abc.abstractmethod
    async def save_config(self, data: Dict[str, Any]):
        """Persist the full configuration mapping."""
        pass

    @abc.abstractmethod
    async def load_tokens(self) -> Dict[str, Any]:
        """Load all tokens, keyed by pool name."""
        pass

    @abc.abstractmethod
    async def save_tokens(self, data: Dict[str, Any]):
        """Persist the full token mapping."""
        pass

    async def save_tokens_delta(
        self, updated: list[Dict[str, Any]], deleted: Optional[list[str]] = None
    ):
        """Incrementally save tokens (default falls back to a full save).

        Args:
            updated: Token dicts carrying ``pool_name`` and ``token``; each
                replaces the matching entry in its pool or is appended.
            deleted: Token strings to remove from every pool.
        """
        existing = await self.load_tokens() or {}

        # First remove deleted tokens from every pool.
        deleted_set = set(deleted or [])
        if deleted_set:
            for pool_name, tokens in list(existing.items()):
                if not isinstance(tokens, list):
                    continue
                filtered = []
                for item in tokens:
                    # Pool entries may be bare strings or dicts.
                    if isinstance(item, str):
                        token_str = item
                    elif isinstance(item, dict):
                        token_str = item.get("token")
                    else:
                        token_str = None
                    if token_str and token_str in deleted_set:
                        continue
                    filtered.append(item)
                existing[pool_name] = filtered

        # Then upsert the updated entries.
        for item in updated or []:
            if not isinstance(item, dict):
                continue
            pool_name = item.get("pool_name")
            token_str = item.get("token")
            if not pool_name or not token_str:
                continue
            pool_list = existing.setdefault(pool_name, [])
            # Strip routing metadata before persisting.
            normalized = {
                k: v
                for k, v in item.items()
                if k not in ("pool_name", "_update_kind")
            }
            replaced = False
            for idx, current in enumerate(pool_list):
                if isinstance(current, str):
                    if current == token_str:
                        pool_list[idx] = normalized
                        replaced = True
                        break
                elif isinstance(current, dict) and current.get("token") == token_str:
                    pool_list[idx] = normalized
                    replaced = True
                    break
            if not replaced:
                pool_list.append(normalized)

        await self.save_tokens(existing)

    @abc.abstractmethod
    async def close(self):
        """Release backend resources (connections, pools)."""
        pass

    @asynccontextmanager
    async def acquire_lock(self, name: str, timeout: int = 10):
        """
        Acquire a mutex guarding read/write critical sections.

        Args:
            name: Lock name.
            timeout: Timeout in seconds.
        """
        # Default no-op implementation used as a fallback.
        yield

    async def verify_connection(self) -> bool:
        """Health check; the base implementation always reports healthy."""
        return True
158
+
159
+
160
def _toml_escape(val: str) -> str:
    """Quote *val* as a TOML basic string.

    Bug fix: the previous writer escaped only double quotes, so values
    containing backslashes or newlines produced invalid/corrupted TOML.
    """
    escaped = (
        val.replace("\\", "\\\\")
        .replace('"', '\\"')
        .replace("\n", "\\n")
        .replace("\r", "\\r")
        .replace("\t", "\\t")
    )
    return f'"{escaped}"'


class LocalStorage(BaseStorage):
    """
    Local file storage backend.

    - Async file I/O via aiofiles.
    - In-process mutual exclusion via asyncio.Lock.
    - Cross-process safety via fcntl file locks where available.
    """

    def __init__(self):
        # Serializes lock acquisition within this process.
        self._lock = asyncio.Lock()

    @asynccontextmanager
    async def acquire_lock(self, name: str, timeout: int = 10):
        """Acquire a named mutex; raises StorageError on timeout.

        Uses a non-blocking fcntl file lock polled on the event loop for
        multi-process safety on POSIX; falls back to the in-process
        asyncio.Lock where fcntl is unavailable (e.g. Windows).
        """
        if fcntl is None:
            # No fcntl: process-local lock only.
            try:
                async with asyncio.timeout(timeout):
                    async with self._lock:
                        yield
            except asyncio.TimeoutError:
                logger.warning(f"LocalStorage: 获取锁 '{name}' 超时 ({timeout}s)")
                raise StorageError(f"无法获取锁 '{name}'")
            return

        lock_path = LOCK_DIR / f"{name}.lock"
        lock_path.parent.mkdir(parents=True, exist_ok=True)
        fd = None
        locked = False
        start = time.monotonic()

        async with self._lock:
            try:
                fd = open(lock_path, "a+")
                # Poll the non-blocking flock so the event loop stays live.
                while True:
                    try:
                        fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
                        locked = True
                        break
                    except BlockingIOError:
                        if time.monotonic() - start >= timeout:
                            raise StorageError(f"无法获取锁 '{name}'")
                        await asyncio.sleep(0.05)
                yield
            except StorageError:
                logger.warning(f"LocalStorage: 获取锁 '{name}' 超时 ({timeout}s)")
                raise
            finally:
                if fd:
                    if locked:
                        try:
                            fcntl.flock(fd, fcntl.LOCK_UN)
                        except Exception:
                            pass
                    try:
                        fd.close()
                    except Exception:
                        pass

    async def load_config(self) -> Dict[str, Any]:
        """Read and parse data/config.toml; {} when missing or invalid."""
        if not CONFIG_FILE.exists():
            return {}
        try:
            async with aiofiles.open(CONFIG_FILE, "rb") as f:
                content = await f.read()
            return tomllib.loads(content.decode("utf-8"))
        except Exception as e:
            logger.error(f"LocalStorage: 加载配置失败: {e}")
            return {}

    async def save_config(self, data: Dict[str, Any]):
        """Serialize *data* to TOML (one [section] level) and write it.

        Fixes: strings are now fully escaped via _toml_escape, and the
        write is atomic (temp file + os.replace), matching save_tokens.
        """
        try:
            lines = []
            for section, items in data.items():
                if not isinstance(items, dict):
                    continue
                lines.append(f"[{section}]")
                for key, val in items.items():
                    # bool must be checked before int (bool subclasses int).
                    if isinstance(val, bool):
                        val_str = "true" if val else "false"
                    elif isinstance(val, str):
                        val_str = _toml_escape(val)
                    elif isinstance(val, (int, float)):
                        val_str = str(val)
                    elif isinstance(val, (list, dict)):
                        val_str = json_dumps(val)
                    else:
                        val_str = _toml_escape(str(val))
                    lines.append(f"{key} = {val_str}")
                lines.append("")

            content = "\n".join(lines)

            CONFIG_FILE.parent.mkdir(parents=True, exist_ok=True)
            # Atomic write: temp file then os.replace.
            temp_path = CONFIG_FILE.with_suffix(".tmp")
            async with aiofiles.open(temp_path, "w", encoding="utf-8") as f:
                await f.write(content)
            os.replace(temp_path, CONFIG_FILE)
        except Exception as e:
            logger.error(f"LocalStorage: 保存配置失败: {e}")
            raise StorageError(f"保存配置失败: {e}")

    async def load_tokens(self) -> Dict[str, Any]:
        """Read data/token.json; {} when missing or invalid."""
        if not TOKEN_FILE.exists():
            return {}
        try:
            async with aiofiles.open(TOKEN_FILE, "rb") as f:
                content = await f.read()
            return json_loads(content)
        except Exception as e:
            logger.error(f"LocalStorage: 加载 Token 失败: {e}")
            return {}

    async def save_tokens(self, data: Dict[str, Any]):
        """Write the token mapping atomically (temp file + os.replace)."""
        try:
            TOKEN_FILE.parent.mkdir(parents=True, exist_ok=True)
            temp_path = TOKEN_FILE.with_suffix(".tmp")

            # Atomic write: temp file -> rename.
            async with aiofiles.open(temp_path, "wb") as f:
                await f.write(orjson.dumps(data, option=orjson.OPT_INDENT_2))

            os.replace(temp_path, TOKEN_FILE)

        except Exception as e:
            logger.error(f"LocalStorage: 保存 Token 失败: {e}")
            raise StorageError(f"保存 Token 失败: {e}")

    async def close(self):
        """No pooled resources to release for the local backend."""
        pass
288
+
289
+
290
+ class RedisStorage(BaseStorage):
291
+ """
292
+ Redis 存储
293
+ - 使用 redis-py 异步客户端 (自带连接池)
294
+ - 支持分布式锁 (redis.lock)
295
+ - 扁平化数据结构优化性能
296
+ """
297
+
298
    def __init__(self, url: str):
        """Create the async Redis client and define key-space constants.

        Raises:
            ImportError: when the ``redis`` package is not installed.
        """
        try:
            from redis import asyncio as aioredis
        except ImportError:
            raise ImportError("需要安装 redis 包: pip install redis")

        # Explicit pool configuration. decode_responses=True simplifies
        # string handling; complex objects use the orjson helpers.
        self.redis = aioredis.from_url(
            url, decode_responses=True, health_check_interval=30
        )
        self.config_key = "grok2api:config"  # Hash: section.key -> value_json
        self.key_pools = "grok2api:pools"  # Set: pool_names
        self.prefix_pool_set = "grok2api:pool:"  # Set: pool -> token_ids
        self.prefix_token_hash = "grok2api:token:"  # Hash: token_id -> token_data
        self.lock_prefix = "grok2api:lock:"
314
+
315
    @asynccontextmanager
    async def acquire_lock(self, name: str, timeout: int = 10):
        """Distributed mutex backed by a Redis lock.

        Args:
            name: Lock name (namespaced under the lock prefix).
            timeout: Lock auto-expiry in seconds; acquisition itself blocks
                for at most 5 seconds (blocking_timeout).

        Raises:
            StorageError: if the lock cannot be acquired in time.
        """
        # Redis distributed lock
        lock_key = f"{self.lock_prefix}{name}"
        lock = self.redis.lock(lock_key, timeout=timeout, blocking_timeout=5)
        acquired = False
        try:
            acquired = await lock.acquire()
            if not acquired:
                raise StorageError(f"RedisStorage: 无法获取锁 '{name}'")
            yield
        finally:
            if acquired:
                try:
                    await lock.release()
                except Exception:
                    # The lock may have expired or been released elsewhere.
                    pass
333
+
334
+ async def verify_connection(self) -> bool:
335
+ try:
336
+ return await self.redis.ping()
337
+ except Exception:
338
+ return False
339
+
340
    async def load_config(self) -> Optional[Dict[str, Any]]:
        """Load configuration from the Redis hash.

        Returns None when the hash is missing/empty or on error; the caller
        treats None as "backend not initialized". (Annotation fixed to
        Optional to reflect the None returns.)
        """
        try:
            raw_data = await self.redis.hgetall(self.config_key)
            if not raw_data:
                return None

            config = {}
            for composite_key, val_str in raw_data.items():
                # Keys are flattened as "section.key"; skip malformed ones.
                if "." not in composite_key:
                    continue
                section, key = composite_key.split(".", 1)

                if section not in config:
                    config[section] = {}

                # Values are stored as JSON; fall back to the raw string.
                try:
                    val = json_loads(val_str)
                except Exception:
                    val = val_str
                config[section][key] = val
            return config
        except Exception as e:
            logger.error(f"RedisStorage: 加载配置失败: {e}")
            return None
365
+
366
+ async def save_config(self, data: Dict[str, Any]):
367
+ """保存配置到 Redis Hash"""
368
+ try:
369
+ mapping = {}
370
+ for section, items in data.items():
371
+ if not isinstance(items, dict):
372
+ continue
373
+ for key, val in items.items():
374
+ composite_key = f"{section}.{key}"
375
+ mapping[composite_key] = json_dumps(val)
376
+
377
+ await self.redis.delete(self.config_key)
378
+ if mapping:
379
+ await self.redis.hset(self.config_key, mapping=mapping)
380
+ except Exception as e:
381
+ logger.error(f"RedisStorage: 保存配置失败: {e}")
382
+ raise
383
+
384
+ async def load_tokens(self) -> Dict[str, Any]:
385
+ """加载所有 Token"""
386
+ try:
387
+ pool_names = await self.redis.smembers(self.key_pools)
388
+ if not pool_names:
389
+ return None
390
+
391
+ pools = {}
392
+ async with self.redis.pipeline() as pipe:
393
+ for pool_name in pool_names:
394
+ # 获取该池下所有 Token ID
395
+ pipe.smembers(f"{self.prefix_pool_set}{pool_name}")
396
+ pool_tokens_res = await pipe.execute()
397
+
398
+ # 收集所有 Token ID 以便批量查询
399
+ all_token_ids = []
400
+ pool_map = {} # pool_name -> list[token_id]
401
+
402
+ for i, pool_name in enumerate(pool_names):
403
+ tids = list(pool_tokens_res[i])
404
+ pool_map[pool_name] = tids
405
+ all_token_ids.extend(tids)
406
+
407
+ if not all_token_ids:
408
+ return {name: [] for name in pool_names}
409
+
410
+ # 批量获取 Token 详情 (Hash)
411
+ async with self.redis.pipeline() as pipe:
412
+ for tid in all_token_ids:
413
+ pipe.hgetall(f"{self.prefix_token_hash}{tid}")
414
+ token_data_list = await pipe.execute()
415
+
416
+ # 重组数据结构
417
+ token_lookup = {}
418
+ for i, tid in enumerate(all_token_ids):
419
+ t_data = token_data_list[i]
420
+ if not t_data:
421
+ continue
422
+
423
+ # 恢复 tags (JSON -> List)
424
+ if "tags" in t_data:
425
+ try:
426
+ t_data["tags"] = json_loads(t_data["tags"])
427
+ except Exception:
428
+ t_data["tags"] = []
429
+
430
+ # 类型转换 (Redis 返回全 string)
431
+ for int_field in [
432
+ "quota",
433
+ "created_at",
434
+ "use_count",
435
+ "fail_count",
436
+ "last_used_at",
437
+ "last_fail_at",
438
+ "last_sync_at",
439
+ ]:
440
+ if t_data.get(int_field) and t_data[int_field] != "None":
441
+ try:
442
+ t_data[int_field] = int(t_data[int_field])
443
+ except Exception:
444
+ pass
445
+
446
+ token_lookup[tid] = t_data
447
+
448
+ # 按 Pool 分组返回
449
+ for pool_name in pool_names:
450
+ pools[pool_name] = []
451
+ for tid in pool_map[pool_name]:
452
+ if tid in token_lookup:
453
+ pools[pool_name].append(token_lookup[tid])
454
+
455
+ return pools
456
+
457
+ except Exception as e:
458
+ logger.error(f"RedisStorage: 加载 Token 失败: {e}")
459
+ return None
460
+
461
+ async def save_tokens(self, data: Dict[str, Any]):
462
+ """保存所有 Token"""
463
+ if data is None:
464
+ return
465
+ try:
466
+ new_pools = set(data.keys()) if isinstance(data, dict) else set()
467
+ pool_tokens_map = {}
468
+ new_token_ids = set()
469
+
470
+ for pool_name, tokens in (data or {}).items():
471
+ tids_in_pool = []
472
+ for t in tokens:
473
+ token_str = t.get("token")
474
+ if not token_str:
475
+ continue
476
+ tids_in_pool.append(token_str)
477
+ new_token_ids.add(token_str)
478
+ pool_tokens_map[pool_name] = tids_in_pool
479
+
480
+ existing_pools = await self.redis.smembers(self.key_pools)
481
+ existing_pools = set(existing_pools) if existing_pools else set()
482
+
483
+ existing_token_ids = set()
484
+ if existing_pools:
485
+ async with self.redis.pipeline() as pipe:
486
+ for pool_name in existing_pools:
487
+ pipe.smembers(f"{self.prefix_pool_set}{pool_name}")
488
+ pool_tokens_res = await pipe.execute()
489
+ for tokens in pool_tokens_res:
490
+ existing_token_ids.update(list(tokens or []))
491
+
492
+ tokens_to_delete = existing_token_ids - new_token_ids
493
+ all_pools = existing_pools.union(new_pools)
494
+
495
+ async with self.redis.pipeline() as pipe:
496
+ # Reset pool index
497
+ pipe.delete(self.key_pools)
498
+ if new_pools:
499
+ pipe.sadd(self.key_pools, *new_pools)
500
+
501
+ # Reset pool sets
502
+ for pool_name in all_pools:
503
+ pipe.delete(f"{self.prefix_pool_set}{pool_name}")
504
+ for pool_name, tids_in_pool in pool_tokens_map.items():
505
+ if tids_in_pool:
506
+ pipe.sadd(f"{self.prefix_pool_set}{pool_name}", *tids_in_pool)
507
+
508
+ # Remove deleted token hashes
509
+ for token_str in tokens_to_delete:
510
+ pipe.delete(f"{self.prefix_token_hash}{token_str}")
511
+
512
+ # Upsert token hashes
513
+ for pool_name, tokens in (data or {}).items():
514
+ for t in tokens:
515
+ token_str = t.get("token")
516
+ if not token_str:
517
+ continue
518
+ t_flat = t.copy()
519
+ if "tags" in t_flat:
520
+ t_flat["tags"] = json_dumps(t_flat["tags"])
521
+ status = t_flat.get("status")
522
+ if isinstance(status, str) and status.startswith(
523
+ "TokenStatus."
524
+ ):
525
+ t_flat["status"] = status.split(".", 1)[1].lower()
526
+ elif isinstance(status, Enum):
527
+ t_flat["status"] = status.value
528
+ t_flat = {k: str(v) for k, v in t_flat.items() if v is not None}
529
+ pipe.hset(
530
+ f"{self.prefix_token_hash}{token_str}", mapping=t_flat
531
+ )
532
+
533
+ await pipe.execute()
534
+
535
+ except Exception as e:
536
+ logger.error(f"RedisStorage: 保存 Token 失败: {e}")
537
+ raise
538
+
539
+ async def close(self):
540
+ try:
541
+ await self.redis.close()
542
+ except (RuntimeError, asyncio.CancelledError, Exception):
543
+ # 忽略关闭时的 Event loop is closed 错误
544
+ pass
545
+
546
+
547
class SQLStorage(BaseStorage):
    """
    SQL database storage (MySQL / PostgreSQL).

    - SQLAlchemy async engine.
    - Automatic schema initialization and legacy-column backfill.
    - Built-in connection pooling (QueuePool).
    """

    def __init__(self, url: str, connect_args: dict | None = None):
        try:
            from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
        except ImportError:
            raise ImportError(
                "需要安装 sqlalchemy 和 async 驱动: pip install sqlalchemy[asyncio]"
            )

        # "mysql+aiomysql://..." -> "mysql"; used for dialect-specific SQL.
        self.dialect = url.split(":", 1)[0].split("+", 1)[0].lower()

        # Robust connection-pool settings (recycle + pre-ping guard against
        # stale connections).
        self.engine = create_async_engine(
            url,
            echo=False,
            pool_size=20,
            max_overflow=10,
            pool_recycle=3600,
            pool_pre_ping=True,
            **({"connect_args": connect_args} if connect_args else {}),
        )
        self.async_session = async_sessionmaker(self.engine, expire_on_commit=False)
        self._initialized = False

    async def _ensure_schema(self):
        """Ensure database tables exist (idempotent; runs once per instance)."""
        if self._initialized:
            return
        try:
            async with self.engine.begin() as conn:
                from sqlalchemy import text

                # tokens table (portable SQL)
                await conn.execute(
                    text("""
                    CREATE TABLE IF NOT EXISTS tokens (
                        token VARCHAR(512) PRIMARY KEY,
                        pool_name VARCHAR(64) NOT NULL,
                        status VARCHAR(16),
                        quota INT,
                        created_at BIGINT,
                        last_used_at BIGINT,
                        use_count INT,
                        fail_count INT,
                        last_fail_at BIGINT,
                        last_fail_reason TEXT,
                        last_sync_at BIGINT,
                        tags TEXT,
                        note TEXT,
                        last_asset_clear_at BIGINT,
                        data TEXT,
                        data_hash CHAR(64),
                        updated_at BIGINT
                    )
                    """)
                )

                # config table
                await conn.execute(
                    text("""
                    CREATE TABLE IF NOT EXISTS app_config (
                        section VARCHAR(64) NOT NULL,
                        key_name VARCHAR(64) NOT NULL,
                        value TEXT,
                        PRIMARY KEY (section, key_name)
                    )
                    """)
                )

                # Index on pool_name. Only PostgreSQL supports
                # CREATE INDEX IF NOT EXISTS portably here; for other
                # dialects a duplicate-index error is swallowed.
                if self.dialect in ("postgres", "postgresql", "pgsql"):
                    await conn.execute(
                        text(
                            "CREATE INDEX IF NOT EXISTS idx_tokens_pool ON tokens (pool_name)"
                        )
                    )
                else:
                    try:
                        await conn.execute(
                            text("CREATE INDEX idx_tokens_pool ON tokens (pool_name)")
                        )
                    except Exception:
                        pass

                # Backfill columns missing from legacy tables.
                columns = [
                    ("status", "VARCHAR(16)"),
                    ("quota", "INT"),
                    ("created_at", "BIGINT"),
                    ("last_used_at", "BIGINT"),
                    ("use_count", "INT"),
                    ("fail_count", "INT"),
                    ("last_fail_at", "BIGINT"),
                    ("last_fail_reason", "TEXT"),
                    ("last_sync_at", "BIGINT"),
                    ("tags", "TEXT"),
                    ("note", "TEXT"),
                    ("last_asset_clear_at", "BIGINT"),
                    ("data", "TEXT"),
                    ("data_hash", "CHAR(64)"),
                    ("updated_at", "BIGINT"),
                ]
                if self.dialect in ("postgres", "postgresql", "pgsql"):
                    for col_name, col_type in columns:
                        await conn.execute(
                            text(
                                f"ALTER TABLE tokens ADD COLUMN IF NOT EXISTS {col_name} {col_type}"
                            )
                        )
                else:
                    # No ADD COLUMN IF NOT EXISTS — tolerate duplicates.
                    for col_name, col_type in columns:
                        try:
                            await conn.execute(
                                text(
                                    f"ALTER TABLE tokens ADD COLUMN {col_name} {col_type}"
                                )
                            )
                        except Exception:
                            pass

                # Best-effort compatibility tweaks for legacy column types.
                try:
                    if self.dialect in ("mysql", "mariadb"):
                        await conn.execute(
                            text("ALTER TABLE tokens MODIFY token VARCHAR(512)")
                        )
                        await conn.execute(text("ALTER TABLE tokens MODIFY data TEXT"))
                    elif self.dialect in ("postgres", "postgresql", "pgsql"):
                        await conn.execute(
                            text(
                                "ALTER TABLE tokens ALTER COLUMN token TYPE VARCHAR(512)"
                            )
                        )
                        await conn.execute(
                            text("ALTER TABLE tokens ALTER COLUMN data TYPE TEXT")
                        )
                except Exception:
                    pass

            await self._migrate_legacy_tokens()
            self._initialized = True
        except Exception as e:
            logger.error(f"SQLStorage: Schema 初始化失败: {e}")
            raise

    def _normalize_status(self, status: Any) -> Any:
        # "TokenStatus.ACTIVE" (enum-repr leakage) -> "active"; Enum -> value.
        if isinstance(status, str) and status.startswith("TokenStatus."):
            return status.split(".", 1)[1].lower()
        if isinstance(status, Enum):
            return status.value
        return status

    def _normalize_tags(self, tags: Any) -> Optional[str]:
        """Normalize tags to a JSON-array string for the TEXT column."""
        if tags is None:
            return None
        if isinstance(tags, str):
            # Already a JSON list -> store as-is; otherwise wrap the single
            # string value in a list.
            try:
                parsed = json_loads(tags)
                if isinstance(parsed, list):
                    return tags
            except Exception:
                pass
            return json_dumps([tags])
        return json_dumps(tags)

    def _parse_tags(self, tags: Any) -> Optional[list]:
        """Parse the stored tags column back into a list ([] on bad data)."""
        if tags is None:
            return None
        if isinstance(tags, str):
            try:
                parsed = json_loads(tags)
                if isinstance(parsed, list):
                    return parsed
            except Exception:
                return []
        if isinstance(tags, list):
            return tags
        return []

    def _token_to_row(self, token_data: Dict[str, Any], pool_name: str) -> Dict[str, Any]:
        """Flatten a token dict into bind parameters for the tokens table.

        Also keeps the full JSON blob (`data`) plus its SHA-256 (`data_hash`)
        alongside the flattened columns.
        """
        token_str = token_data.get("token")
        # Strip a leaked "sso=" cookie prefix so the primary key is stable.
        if isinstance(token_str, str) and token_str.startswith("sso="):
            token_str = token_str[4:]

        status = self._normalize_status(token_data.get("status"))
        tags_json = self._normalize_tags(token_data.get("tags"))
        data_json = json_dumps_sorted(token_data)
        data_hash = hashlib.sha256(data_json.encode("utf-8")).hexdigest()
        note = token_data.get("note")
        if note is None:
            note = ""

        return {
            "token": token_str,
            "pool_name": pool_name,
            "status": status,
            "quota": token_data.get("quota"),
            "created_at": token_data.get("created_at"),
            "last_used_at": token_data.get("last_used_at"),
            "use_count": token_data.get("use_count"),
            "fail_count": token_data.get("fail_count"),
            "last_fail_at": token_data.get("last_fail_at"),
            "last_fail_reason": token_data.get("last_fail_reason"),
            "last_sync_at": token_data.get("last_sync_at"),
            "tags": tags_json,
            "note": note,
            "last_asset_clear_at": token_data.get("last_asset_clear_at"),
            "data": data_json,
            "data_hash": data_hash,
            # Always written as 0 here; NOTE(review): confirm whether a real
            # timestamp is expected to be maintained by a caller.
            "updated_at": 0,
        }

    async def _migrate_legacy_tokens(self):
        """Backfill flattened columns from the legacy `data` JSON blob."""
        from sqlalchemy import text

        try:
            async with self.async_session() as session:
                # Cheap probe: is there at least one legacy row to migrate?
                try:
                    res = await session.execute(
                        text(
                            "SELECT token FROM tokens "
                            "WHERE data IS NOT NULL AND "
                            "(status IS NULL OR quota IS NULL OR created_at IS NULL) "
                            "LIMIT 1"
                        )
                    )
                    if not res.first():
                        return
                except Exception as e:
                    # Columns may not exist yet on very old schemas.
                    msg = str(e).lower()
                    if "undefinedcolumn" in msg or "undefined column" in msg:
                        return
                    raise

                res = await session.execute(
                    text(
                        "SELECT token, pool_name, data FROM tokens "
                        "WHERE data IS NOT NULL AND "
                        "(status IS NULL OR quota IS NULL OR created_at IS NULL)"
                    )
                )
                rows = res.fetchall()
                if not rows:
                    return

                params = []
                for token_str, pool_name, data_json in rows:
                    if not data_json:
                        continue
                    try:
                        if isinstance(data_json, str):
                            t_data = json_loads(data_json)
                        else:
                            t_data = data_json
                        if not isinstance(t_data, dict):
                            continue
                        t_data = dict(t_data)
                        t_data["token"] = token_str
                        row = self._token_to_row(t_data, pool_name)
                        params.append(row)
                    except Exception:
                        # Skip rows whose blob cannot be parsed.
                        continue

                if not params:
                    return

                await session.execute(
                    text(
                        "UPDATE tokens SET "
                        "pool_name=:pool_name, "
                        "status=:status, "
                        "quota=:quota, "
                        "created_at=:created_at, "
                        "last_used_at=:last_used_at, "
                        "use_count=:use_count, "
                        "fail_count=:fail_count, "
                        "last_fail_at=:last_fail_at, "
                        "last_fail_reason=:last_fail_reason, "
                        "last_sync_at=:last_sync_at, "
                        "tags=:tags, "
                        "note=:note, "
                        "last_asset_clear_at=:last_asset_clear_at, "
                        "data=:data, "
                        "data_hash=:data_hash, "
                        "updated_at=:updated_at "
                        "WHERE token=:token"
                    ),
                    params,
                )
                await session.commit()
        except Exception as e:
            logger.warning(f"SQLStorage: 旧数据回填失败: {e}")

    @asynccontextmanager
    async def acquire_lock(self, name: str, timeout: int = 10):
        """SQL distributed lock: MySQL GET_LOCK / PostgreSQL advisory lock.

        Other dialects get no locking (yields immediately).
        """
        from sqlalchemy import text

        # Hash the name so it fits MySQL's lock-name limits.
        lock_name = f"g2a:{hashlib.sha1(name.encode('utf-8')).hexdigest()[:24]}"
        if self.dialect in ("mysql", "mariadb"):
            async with self.async_session() as session:
                res = await session.execute(
                    text("SELECT GET_LOCK(:name, :timeout)"),
                    {"name": lock_name, "timeout": timeout},
                )
                got = res.scalar()
                if got != 1:
                    raise StorageError(f"SQLStorage: 无法获取锁 '{name}'")
                try:
                    yield
                finally:
                    try:
                        await session.execute(
                            text("SELECT RELEASE_LOCK(:name)"), {"name": lock_name}
                        )
                        await session.commit()
                    except Exception:
                        pass
        elif self.dialect in ("postgres", "postgresql", "pgsql"):
            # Advisory locks key on a signed 64-bit int derived from the name.
            lock_key = int.from_bytes(
                hashlib.sha256(name.encode("utf-8")).digest()[:8], "big", signed=True
            )
            async with self.async_session() as session:
                # Poll pg_try_advisory_lock until acquired or timed out.
                start = time.monotonic()
                while True:
                    res = await session.execute(
                        text("SELECT pg_try_advisory_lock(:key)"), {"key": lock_key}
                    )
                    if res.scalar():
                        break
                    if time.monotonic() - start >= timeout:
                        raise StorageError(f"SQLStorage: 无法获取锁 '{name}'")
                    await asyncio.sleep(0.1)
                try:
                    yield
                finally:
                    try:
                        await session.execute(
                            text("SELECT pg_advisory_unlock(:key)"), {"key": lock_key}
                        )
                        await session.commit()
                    except Exception:
                        pass
        else:
            yield

    async def load_config(self) -> Optional[Dict[str, Any]]:
        """Load configuration rows into {section: {key: value}}; None if empty."""
        await self._ensure_schema()
        from sqlalchemy import text

        try:
            async with self.async_session() as session:
                res = await session.execute(
                    text("SELECT section, key_name, value FROM app_config")
                )
                rows = res.fetchall()
                if not rows:
                    return None

                config = {}
                for section, key, val_str in rows:
                    if section not in config:
                        config[section] = {}
                    try:
                        val = json_loads(val_str)
                    except Exception:
                        # Not JSON — keep the raw string value.
                        val = val_str
                    config[section][key] = val
                return config
        except Exception as e:
            logger.error(f"SQLStorage: 加载配置失败: {e}")
            return None

    async def save_config(self, data: Dict[str, Any]):
        """Persist configuration (delete-all + bulk insert, one transaction)."""
        await self._ensure_schema()
        from sqlalchemy import text

        try:
            async with self.async_session() as session:
                await session.execute(text("DELETE FROM app_config"))

                params = []
                for section, items in data.items():
                    if not isinstance(items, dict):
                        continue
                    for key, val in items.items():
                        params.append(
                            {
                                "s": section,
                                "k": key,
                                "v": json_dumps(val),
                            }
                        )

                if params:
                    await session.execute(
                        text(
                            "INSERT INTO app_config (section, key_name, value) VALUES (:s, :k, :v)"
                        ),
                        params,
                    )
                await session.commit()
        except Exception as e:
            logger.error(f"SQLStorage: 保存配置失败: {e}")
            raise

    async def load_tokens(self) -> Optional[Dict[str, Any]]:
        """Load all tokens, grouped by pool name.

        Flattened columns take precedence; missing fields fall back to the
        legacy `data` JSON blob. Returns None when the table is empty or on
        error.
        """
        await self._ensure_schema()
        from sqlalchemy import text

        try:
            async with self.async_session() as session:
                res = await session.execute(
                    text(
                        "SELECT token, pool_name, status, quota, created_at, "
                        "last_used_at, use_count, fail_count, last_fail_at, "
                        "last_fail_reason, last_sync_at, tags, note, "
                        "last_asset_clear_at, data "
                        "FROM tokens"
                    )
                )
                rows = res.fetchall()
                if not rows:
                    return None

                pools = {}
                for (
                    token_str,
                    pool_name,
                    status,
                    quota,
                    created_at,
                    last_used_at,
                    use_count,
                    fail_count,
                    last_fail_at,
                    last_fail_reason,
                    last_sync_at,
                    tags,
                    note,
                    last_asset_clear_at,
                    data_json,
                ) in rows:
                    if pool_name not in pools:
                        pools[pool_name] = []

                    try:
                        token_data = {}
                        if token_str:
                            token_data["token"] = token_str
                        if status is not None:
                            token_data["status"] = self._normalize_status(status)
                        if quota is not None:
                            token_data["quota"] = int(quota)
                        if created_at is not None:
                            token_data["created_at"] = int(created_at)
                        if last_used_at is not None:
                            token_data["last_used_at"] = int(last_used_at)
                        if use_count is not None:
                            token_data["use_count"] = int(use_count)
                        if fail_count is not None:
                            token_data["fail_count"] = int(fail_count)
                        if last_fail_at is not None:
                            token_data["last_fail_at"] = int(last_fail_at)
                        if last_fail_reason is not None:
                            token_data["last_fail_reason"] = last_fail_reason
                        if last_sync_at is not None:
                            token_data["last_sync_at"] = int(last_sync_at)
                        if tags is not None:
                            token_data["tags"] = self._parse_tags(tags)
                        if note is not None:
                            token_data["note"] = note
                        if last_asset_clear_at is not None:
                            token_data["last_asset_clear_at"] = int(
                                last_asset_clear_at
                            )

                        # Legacy fallback: fill any fields still missing from
                        # the old `data` JSON blob.
                        legacy_data = None
                        if data_json:
                            if isinstance(data_json, str):
                                legacy_data = json_loads(data_json)
                            else:
                                legacy_data = data_json
                        if isinstance(legacy_data, dict):
                            for key, val in legacy_data.items():
                                if key not in token_data or token_data[key] is None:
                                    token_data[key] = val

                        pools[pool_name].append(token_data)
                    except Exception:
                        # Skip malformed rows rather than failing the load.
                        pass
                return pools
        except Exception as e:
            logger.error(f"SQLStorage: 加载 Token 失败: {e}")
            return None

    async def save_tokens(self, data: Dict[str, Any]):
        """Save all tokens (full sync) by delegating to save_tokens_delta.

        Computes the deletion set as (existing tokens) - (incoming tokens).
        """
        await self._ensure_schema()
        from sqlalchemy import text

        if data is None:
            return

        updates = []
        new_tokens = set()
        for pool_name, tokens in (data or {}).items():
            for t in tokens:
                if isinstance(t, dict):
                    token_data = dict(t)
                elif isinstance(t, str):
                    # Bare token string -> minimal dict.
                    token_data = {"token": t}
                else:
                    continue
                token_str = token_data.get("token")
                if not token_str:
                    continue
                # Strip a leaked "sso=" cookie prefix (matches _token_to_row).
                if token_str.startswith("sso="):
                    token_str = token_str[4:]
                token_data["token"] = token_str
                token_data["pool_name"] = pool_name
                token_data["_update_kind"] = "state"
                updates.append(token_data)
                new_tokens.add(token_str)

        try:
            existing_tokens = set()
            async with self.async_session() as session:
                res = await session.execute(text("SELECT token FROM tokens"))
                rows = res.fetchall()
                existing_tokens = {row[0] for row in rows}
            tokens_to_delete = list(existing_tokens - new_tokens)
            await self.save_tokens_delta(updates, tokens_to_delete)
        except Exception as e:
            logger.error(f"SQLStorage: 保存 Token 失败: {e}")
            raise

    async def save_tokens_delta(
        self, updated: list[Dict[str, Any]], deleted: Optional[list[str]] = None
    ):
        """Incrementally upsert/delete tokens in one transaction.

        Each item in `updated` carries `pool_name` plus an optional
        `_update_kind`: "usage" rows update only runtime counters on
        conflict, while "state" (default) rows overwrite every column.
        """
        await self._ensure_schema()
        from sqlalchemy import bindparam, text

        try:
            async with self.async_session() as session:
                # Delete first, chunked to keep the IN () list bounded.
                deleted_set = set(deleted or [])
                if deleted_set:
                    delete_stmt = text(
                        "DELETE FROM tokens WHERE token IN :tokens"
                    ).bindparams(bindparam("tokens", expanding=True))
                    chunk_size = 500
                    deleted_list = list(deleted_set)
                    for i in range(0, len(deleted_list), chunk_size):
                        chunk = deleted_list[i : i + chunk_size]
                        await session.execute(delete_stmt, {"tokens": chunk})

                updates = []
                usage_updates = []

                for item in updated or []:
                    if not isinstance(item, dict):
                        continue
                    pool_name = item.get("pool_name")
                    token_str = item.get("token")
                    if not pool_name or not token_str:
                        continue
                    if token_str in deleted_set:
                        # Deletion wins over a concurrent update.
                        continue
                    update_kind = item.get("_update_kind", "state")
                    token_data = {
                        k: v
                        for k, v in item.items()
                        if k not in ("pool_name", "_update_kind")
                    }
                    row = self._token_to_row(token_data, pool_name)
                    if update_kind == "usage":
                        usage_updates.append(row)
                    else:
                        updates.append(row)

                # Full-state upsert: every column is overwritten on conflict.
                if updates:
                    if self.dialect in ("mysql", "mariadb"):
                        upsert_stmt = text(
                            "INSERT INTO tokens (token, pool_name, status, quota, created_at, "
                            "last_used_at, use_count, fail_count, last_fail_at, "
                            "last_fail_reason, last_sync_at, tags, note, "
                            "last_asset_clear_at, data, data_hash, updated_at) "
                            "VALUES (:token, :pool_name, :status, :quota, :created_at, "
                            ":last_used_at, :use_count, :fail_count, :last_fail_at, "
                            ":last_fail_reason, :last_sync_at, :tags, :note, "
                            ":last_asset_clear_at, :data, :data_hash, :updated_at) "
                            "ON DUPLICATE KEY UPDATE "
                            "pool_name=VALUES(pool_name), "
                            "status=VALUES(status), "
                            "quota=VALUES(quota), "
                            "created_at=VALUES(created_at), "
                            "last_used_at=VALUES(last_used_at), "
                            "use_count=VALUES(use_count), "
                            "fail_count=VALUES(fail_count), "
                            "last_fail_at=VALUES(last_fail_at), "
                            "last_fail_reason=VALUES(last_fail_reason), "
                            "last_sync_at=VALUES(last_sync_at), "
                            "tags=VALUES(tags), "
                            "note=VALUES(note), "
                            "last_asset_clear_at=VALUES(last_asset_clear_at), "
                            "data=VALUES(data), "
                            "data_hash=VALUES(data_hash), "
                            "updated_at=VALUES(updated_at)"
                        )
                    elif self.dialect in ("postgres", "postgresql", "pgsql"):
                        upsert_stmt = text(
                            "INSERT INTO tokens (token, pool_name, status, quota, created_at, "
                            "last_used_at, use_count, fail_count, last_fail_at, "
                            "last_fail_reason, last_sync_at, tags, note, "
                            "last_asset_clear_at, data, data_hash, updated_at) "
                            "VALUES (:token, :pool_name, :status, :quota, :created_at, "
                            ":last_used_at, :use_count, :fail_count, :last_fail_at, "
                            ":last_fail_reason, :last_sync_at, :tags, :note, "
                            ":last_asset_clear_at, :data, :data_hash, :updated_at) "
                            "ON CONFLICT (token) DO UPDATE SET "
                            "pool_name=EXCLUDED.pool_name, "
                            "status=EXCLUDED.status, "
                            "quota=EXCLUDED.quota, "
                            "created_at=EXCLUDED.created_at, "
                            "last_used_at=EXCLUDED.last_used_at, "
                            "use_count=EXCLUDED.use_count, "
                            "fail_count=EXCLUDED.fail_count, "
                            "last_fail_at=EXCLUDED.last_fail_at, "
                            "last_fail_reason=EXCLUDED.last_fail_reason, "
                            "last_sync_at=EXCLUDED.last_sync_at, "
                            "tags=EXCLUDED.tags, "
                            "note=EXCLUDED.note, "
                            "last_asset_clear_at=EXCLUDED.last_asset_clear_at, "
                            "data=EXCLUDED.data, "
                            "data_hash=EXCLUDED.data_hash, "
                            "updated_at=EXCLUDED.updated_at"
                        )
                    else:
                        # Unknown dialect: plain INSERT (no upsert support).
                        upsert_stmt = text(
                            "INSERT INTO tokens (token, pool_name, status, quota, created_at, "
                            "last_used_at, use_count, fail_count, last_fail_at, "
                            "last_fail_reason, last_sync_at, tags, note, "
                            "last_asset_clear_at, data, data_hash, updated_at) "
                            "VALUES (:token, :pool_name, :status, :quota, :created_at, "
                            ":last_used_at, :use_count, :fail_count, :last_fail_at, "
                            ":last_fail_reason, :last_sync_at, :tags, :note, "
                            ":last_asset_clear_at, :data, :data_hash, :updated_at)"
                        )
                    await session.execute(upsert_stmt, updates)

                # Usage upsert: on conflict only runtime counters are updated
                # (tags/note/data etc. are left untouched).
                if usage_updates:
                    if self.dialect in ("mysql", "mariadb"):
                        usage_stmt = text(
                            "INSERT INTO tokens (token, pool_name, status, quota, created_at, "
                            "last_used_at, use_count, fail_count, last_fail_at, "
                            "last_fail_reason, last_sync_at, tags, note, "
                            "last_asset_clear_at, data, data_hash, updated_at) "
                            "VALUES (:token, :pool_name, :status, :quota, :created_at, "
                            ":last_used_at, :use_count, :fail_count, :last_fail_at, "
                            ":last_fail_reason, :last_sync_at, :tags, :note, "
                            ":last_asset_clear_at, :data, :data_hash, :updated_at) "
                            "ON DUPLICATE KEY UPDATE "
                            "pool_name=VALUES(pool_name), "
                            "status=VALUES(status), "
                            "quota=VALUES(quota), "
                            "last_used_at=VALUES(last_used_at), "
                            "use_count=VALUES(use_count), "
                            "fail_count=VALUES(fail_count), "
                            "last_fail_at=VALUES(last_fail_at), "
                            "last_fail_reason=VALUES(last_fail_reason), "
                            "last_sync_at=VALUES(last_sync_at), "
                            "updated_at=VALUES(updated_at)"
                        )
                    elif self.dialect in ("postgres", "postgresql", "pgsql"):
                        usage_stmt = text(
                            "INSERT INTO tokens (token, pool_name, status, quota, created_at, "
                            "last_used_at, use_count, fail_count, last_fail_at, "
                            "last_fail_reason, last_sync_at, tags, note, "
                            "last_asset_clear_at, data, data_hash, updated_at) "
                            "VALUES (:token, :pool_name, :status, :quota, :created_at, "
                            ":last_used_at, :use_count, :fail_count, :last_fail_at, "
                            ":last_fail_reason, :last_sync_at, :tags, :note, "
                            ":last_asset_clear_at, :data, :data_hash, :updated_at) "
                            "ON CONFLICT (token) DO UPDATE SET "
                            "pool_name=EXCLUDED.pool_name, "
                            "status=EXCLUDED.status, "
                            "quota=EXCLUDED.quota, "
                            "last_used_at=EXCLUDED.last_used_at, "
                            "use_count=EXCLUDED.use_count, "
                            "fail_count=EXCLUDED.fail_count, "
                            "last_fail_at=EXCLUDED.last_fail_at, "
                            "last_fail_reason=EXCLUDED.last_fail_reason, "
                            "last_sync_at=EXCLUDED.last_sync_at, "
                            "updated_at=EXCLUDED.updated_at"
                        )
                    else:
                        usage_stmt = text(
                            "INSERT INTO tokens (token, pool_name, status, quota, created_at, "
                            "last_used_at, use_count, fail_count, last_fail_at, "
                            "last_fail_reason, last_sync_at, tags, note, "
                            "last_asset_clear_at, data, data_hash, updated_at) "
                            "VALUES (:token, :pool_name, :status, :quota, :created_at, "
                            ":last_used_at, :use_count, :fail_count, :last_fail_at, "
                            ":last_fail_reason, :last_sync_at, :tags, :note, "
                            ":last_asset_clear_at, :data, :data_hash, :updated_at)"
                        )
                    await session.execute(usage_stmt, usage_updates)

                await session.commit()
        except Exception as e:
            logger.error(f"SQLStorage: 增量保存 Token 失败: {e}")
            raise

    async def close(self):
        # Dispose the engine: closes all pooled connections.
        await self.engine.dispose()
1269
+
1270
+
1271
class StorageFactory:
    """Storage backend factory.

    Holds the process-wide storage singleton plus the SSL-mode alias tables
    used to translate user-supplied SSL settings into driver-specific forms.
    """

    # Process-wide storage singleton.
    _instance: Optional[BaseStorage] = None

    # SSL-related query parameters that async drivers (asyncpg, aiomysql)
    # cannot accept via the URL and must be passed as connect_args instead.
    _SQL_SSL_PARAM_KEYS = ("sslmode", "ssl-mode", "ssl")

    # Canonical postgres ssl modes (asyncpg accepts libpq-style mode strings).
    _PG_SSL_MODE_ALIASES: ClassVar[dict[str, str]] = {
        "disable": "disable",
        "disabled": "disable",
        "false": "disable",
        "0": "disable",
        "no": "disable",
        "off": "disable",
        "prefer": "prefer",
        "preferred": "prefer",
        "allow": "allow",
        "require": "require",
        "required": "require",
        "true": "require",
        "1": "require",
        "yes": "require",
        "on": "require",
        "verify-ca": "verify-ca",
        "verify_ca": "verify-ca",
        "verify-full": "verify-full",
        "verify_full": "verify-full",
        "verify-identity": "verify-full",
        "verify_identity": "verify-full",
    }

    # Canonical mysql ssl modes (aiomysql accepts SSLContext, not mode strings).
    _MY_SSL_MODE_ALIASES: ClassVar[dict[str, str]] = {
        "disable": "disabled",
        "disabled": "disabled",
        "false": "disabled",
        "0": "disabled",
        "no": "disabled",
        "off": "disabled",
        "prefer": "preferred",
        "preferred": "preferred",
        "allow": "preferred",
        "require": "required",
        "required": "required",
        "true": "required",
        "1": "required",
        "yes": "required",
        "on": "required",
        "verify-ca": "verify_ca",
        "verify_ca": "verify_ca",
        "verify-full": "verify_identity",
        "verify_full": "verify_identity",
        "verify-identity": "verify_identity",
        "verify_identity": "verify_identity",
    }
1330
+ @classmethod
1331
+ def _normalize_ssl_mode(cls, storage_type: str, mode: str) -> str:
1332
+ """Normalize SSL mode aliases for the target storage backend."""
1333
+ if not mode:
1334
+ raise ValueError("SSL mode cannot be empty")
1335
+
1336
+ normalized = mode.strip().lower().replace(" ", "")
1337
+ if storage_type == "pgsql":
1338
+ canonical = cls._PG_SSL_MODE_ALIASES.get(normalized)
1339
+ elif storage_type == "mysql":
1340
+ canonical = cls._MY_SSL_MODE_ALIASES.get(normalized)
1341
+ else:
1342
+ canonical = None
1343
+
1344
+ if not canonical:
1345
+ raise ValueError(
1346
+ f"Unsupported SSL mode '{mode}' for storage type '{storage_type}'"
1347
+ )
1348
+ return canonical
1349
+
1350
+ @classmethod
1351
+ def _build_mysql_ssl_context(cls, mode: str):
1352
+ """Build SSLContext for aiomysql according to normalized mysql mode.
1353
+
1354
+ Note: aiomysql enforces SSL whenever an SSLContext is provided — there
1355
+ is no "try SSL, fall back to plaintext" behaviour. As a result the
1356
+ ``preferred`` mode is treated identically to ``required`` (encrypted,
1357
+ no cert verification). Connections to MySQL servers that do not
1358
+ support SSL will fail rather than degrade gracefully.
1359
+ """
1360
+ import ssl as _ssl
1361
+
1362
+ if mode == "disabled":
1363
+ return None
1364
+
1365
+ ctx = _ssl.create_default_context()
1366
+ if mode in ("preferred", "required"):
1367
+ ctx.check_hostname = False
1368
+ ctx.verify_mode = _ssl.CERT_NONE
1369
+ elif mode == "verify_ca":
1370
+ # verify CA, but do not enforce hostname match.
1371
+ ctx.check_hostname = False
1372
+ # verify_identity keeps defaults: verify cert + hostname.
1373
+ return ctx
1374
+
1375
+ @classmethod
1376
+ def _build_sql_connect_args(
1377
+ cls, storage_type: str, raw_ssl_mode: Optional[str]
1378
+ ) -> Optional[dict]:
1379
+ """Build SQLAlchemy connect_args for SQL SSL modes."""
1380
+ if not raw_ssl_mode:
1381
+ return None
1382
+
1383
+ mode = cls._normalize_ssl_mode(storage_type, raw_ssl_mode)
1384
+ if storage_type == "pgsql":
1385
+ # asyncpg accepts libpq-style ssl mode strings via ssl=...
1386
+ return {"ssl": mode}
1387
+ if storage_type == "mysql":
1388
+ ctx = cls._build_mysql_ssl_context(mode)
1389
+ if ctx is None:
1390
+ return None
1391
+ return {"ssl": ctx}
1392
+ return None
1393
+
1394
+ @classmethod
1395
+ def _normalize_sql_url(cls, storage_type: str, url: str) -> str:
1396
+ """Rewrite scheme prefix to the SQLAlchemy async dialect form."""
1397
+ if not url or "://" not in url:
1398
+ return url
1399
+ if storage_type == "mysql":
1400
+ if url.startswith("mysql://"):
1401
+ url = f"mysql+aiomysql://{url[len('mysql://') :]}"
1402
+ elif url.startswith("mariadb://"):
1403
+ # Use mysql+aiomysql for both MySQL and MariaDB endpoints.
1404
+ # The mariadb dialect enforces strict MariaDB server detection.
1405
+ url = f"mysql+aiomysql://{url[len('mariadb://') :]}"
1406
+ elif url.startswith("mariadb+aiomysql://"):
1407
+ url = f"mysql+aiomysql://{url[len('mariadb+aiomysql://') :]}"
1408
+ elif storage_type == "pgsql":
1409
+ if url.startswith("postgres://"):
1410
+ url = f"postgresql+asyncpg://{url[len('postgres://') :]}"
1411
+ elif url.startswith("postgresql://"):
1412
+ url = f"postgresql+asyncpg://{url[len('postgresql://') :]}"
1413
+ elif url.startswith("pgsql://"):
1414
+ url = f"postgresql+asyncpg://{url[len('pgsql://') :]}"
1415
+ return url
1416
+
1417
+ @classmethod
1418
+ def _prepare_sql_url_and_connect_args(
1419
+ cls, storage_type: str, url: str
1420
+ ) -> tuple[str, Optional[dict]]:
1421
+ """Normalize SQL URL and build connect_args from SSL query params."""
1422
+ from urllib.parse import urlparse, parse_qsl, urlencode, urlunparse
1423
+
1424
+ normalized_url = cls._normalize_sql_url(storage_type, url)
1425
+ if "://" not in normalized_url:
1426
+ return normalized_url, None
1427
+
1428
+ parsed = urlparse(normalized_url)
1429
+ ssl_mode: Optional[str] = None
1430
+ filtered_query_items = []
1431
+ ssl_param_keys = {k.lower() for k in cls._SQL_SSL_PARAM_KEYS}
1432
+ for key, value in parse_qsl(parsed.query, keep_blank_values=True):
1433
+ if key.lower() in ssl_param_keys:
1434
+ if ssl_mode is None and value:
1435
+ ssl_mode = value
1436
+ continue
1437
+ filtered_query_items.append((key, value))
1438
+
1439
+ cleaned_url = urlunparse(
1440
+ parsed._replace(query=urlencode(filtered_query_items, doseq=True))
1441
+ )
1442
+ connect_args = cls._build_sql_connect_args(storage_type, ssl_mode)
1443
+ return cleaned_url, connect_args
1444
+
1445
    @classmethod
    def get_storage(cls) -> BaseStorage:
        """Return the process-wide storage backend, creating it on first use (singleton).

        Backend selection comes from the SERVER_STORAGE_TYPE env var
        ("redis", "mysql", "pgsql"; anything else falls back to local
        storage) with SERVER_STORAGE_URL supplying the connection URL.

        Raises:
            ValueError: if a redis/SQL backend is selected without a URL.

        NOTE(review): no lock here — concurrent first calls could race and
        build two instances; confirm this is only called from single-threaded
        startup.
        """
        if cls._instance:
            return cls._instance

        storage_type = os.getenv("SERVER_STORAGE_TYPE", "local").lower()
        storage_url = os.getenv("SERVER_STORAGE_URL", "")

        logger.info(f"StorageFactory: 初始化存储后端: {storage_type}")

        if storage_type == "redis":
            if not storage_url:
                raise ValueError("Redis 存储需要设置 SERVER_STORAGE_URL")
            cls._instance = RedisStorage(storage_url)

        elif storage_type in ("mysql", "pgsql"):
            if not storage_url:
                raise ValueError("SQL 存储需要设置 SERVER_STORAGE_URL")
            # Drivers reject SSL query params in URL. Normalize URL and pass
            # backend-specific SSL handling through connect_args.
            storage_url, connect_args = cls._prepare_sql_url_and_connect_args(
                storage_type, storage_url
            )
            cls._instance = SQLStorage(storage_url, connect_args=connect_args)

        else:
            cls._instance = LocalStorage()

        return cls._instance
1475
+
1476
+
1477
def get_storage() -> BaseStorage:
    """Module-level convenience wrapper around ``StorageFactory.get_storage()``."""
    return StorageFactory.get_storage()
app/services/cf_refresh/README.md ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # cf_refresh - Cloudflare cf_clearance 自动刷新
2
+
3
+ 通过 [FlareSolverr](https://github.com/FlareSolverr/FlareSolverr) 自动获取 Cloudflare `cf_clearance` cookie 和 `user_agent`,并更新到 Grok2API 服务配置中。
4
+
5
+ 全自动、无需 GUI、服务器友好。
6
+
7
+ ## 工作原理
8
+
9
+ 1. FlareSolverr(独立 Docker 容器)内部运行 Chrome,自动通过 CF 挑战
10
+ 2. cf_refresh 作为 grok2api 的后台任务,调用 FlareSolverr HTTP API 获取 `cf_clearance` 和 `user_agent`
11
+ 3. 直接在进程内调用 `config.update()` 更新运行时配置并持久化到 `data/config.toml`
12
+ 4. 按设定间隔重复以上步骤
13
+
14
+ ## 配置方式
15
+
16
+ 所有配置均可在管理面板 `/admin/config` 的 **CF 自动刷新** 区域中设置,也可通过环境变量初始化:
17
+
18
+ | 配置项 | 环境变量 | 默认值 | 说明 |
19
+ |--------|----------|--------|------|
20
+ | 启用自动刷新 | `FLARESOLVERR_URL`(非空即启用) | `false` | 是否开启自动刷新 |
21
+ | FlareSolverr 地址 | `FLARESOLVERR_URL` | — | FlareSolverr 服务的 HTTP 地址 |
22
+ | 刷新间隔(秒) | `CF_REFRESH_INTERVAL` | `600` | 定期刷新间隔 |
23
+ | 挑战超时(秒) | `CF_TIMEOUT` | `60` | CF 挑战等待超时 |
24
+
25
+ > **代理**:自动使用「代理配置 → 基础代理 URL」,无需单独设置,保证出口 IP 一致。
26
+
27
+ ## 使用方式
28
+
29
+ ### Docker Compose 部署
30
+
31
+ 已集成在项目根目录 `docker-compose.yml` 中。只需在 grok2api 服务的环境变量中设置 `FLARESOLVERR_URL`,并添加 `flaresolverr` 服务即可:
32
+
33
+ ```yaml
34
+ services:
35
+ grok2api:
36
+ environment:
37
+ FLARESOLVERR_URL: http://flaresolverr:8191
38
+
39
+ flaresolverr:
40
+ image: ghcr.io/flaresolverr/flaresolverr:latest
41
+ restart: unless-stopped
42
+ ```
43
+
44
+ ## 注意事项
45
+
46
+ - `cf_clearance` 与请求来源 IP 绑定,FlareSolverr 自动使用代理配置中的基础代理 URL 保证出口 IP 一致
47
+ - 启用自动刷新后,代理配置中的 CF Clearance、浏览器指纹和 User-Agent 由系统自动管理(面板中变灰)
48
+ - 建议刷新间隔不低于 5 分钟,避免触发 Cloudflare 频率限制
49
+ - FlareSolverr 需要约 500MB 内存(内部运行 Chrome)
app/services/cf_refresh/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """cf_refresh - Cloudflare cf_clearance 自动刷新模块"""
2
+
3
+ from .scheduler import start, stop
4
+
5
+ __all__ = ["start", "stop"]
app/services/cf_refresh/config.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """配置管理 — 从 app config 的 proxy.* 读取,支持面板修改实时生效"""
2
+
3
+ GROK_URL = "https://grok.com"
4
+
5
+
6
def _get(key: str, default=None):
    """Look up a ``proxy.*`` entry from the live app config."""
    from app.core.config import get_config

    full_key = f"proxy.{key}"
    return get_config(full_key, default)
10
+
11
+
12
def get_flaresolverr_url() -> str:
    """Return the configured FlareSolverr base URL ("" when unset)."""
    configured = _get("flaresolverr_url", "")
    return configured or ""
14
+
15
+
16
def _get_int(key: str, default: int, min_value: int) -> int:
    """Read an integer proxy setting, clamped to at least ``min_value``.

    Unparsable values fall back to the default (also clamped).
    """
    raw = _get(key, default)
    try:
        parsed = int(raw)
    except (TypeError, ValueError):
        parsed = default
    return max(parsed, min_value)
25
+
26
+
27
def get_refresh_interval() -> int:
    """Seconds between automatic cf_clearance refreshes (minimum 60)."""
    return _get_int("refresh_interval", default=600, min_value=60)
29
+
30
+
31
def get_timeout() -> int:
    """Challenge-wait timeout in seconds (minimum 60)."""
    return _get_int("timeout", default=60, min_value=60)
33
+
34
+
35
def get_proxy() -> str:
    """Return the base proxy URL so the egress IP matches normal traffic."""
    configured = _get("base_proxy_url", "")
    return configured or ""
38
+
39
+
40
def is_enabled() -> bool:
    """True when automatic cf_clearance refreshing is switched on."""
    flag = _get("enabled", False)
    return bool(flag)
app/services/cf_refresh/scheduler.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """定时调度:周期性刷新 cf_clearance(集成到 grok2api 进程内)"""
2
+
3
+ import asyncio
4
+
5
+ from loguru import logger
6
+
7
+ from .config import get_refresh_interval, get_flaresolverr_url, is_enabled
8
+ from .solver import solve_cf_challenge
9
+
10
+ _task: asyncio.Task | None = None
11
+
12
+
13
async def _update_app_config(
    cf_cookies: str,
    user_agent: str = "",
    browser: str = "",
    cf_clearance: str = "",
) -> bool:
    """Push freshly solved CF values into grok2api's runtime config.

    ``cf_cookies`` is always written; the remaining fields are only written
    when non-empty. Returns True on success, False when the update raised.
    """
    try:
        from app.core.config import config

        updates = {"cf_cookies": cf_cookies}
        optional_fields = {
            "cf_clearance": cf_clearance,
            "user_agent": user_agent,
            "browser": browser,
        }
        for field, value in optional_fields.items():
            if value:
                updates[field] = value

        await config.update({"proxy": updates})

        logger.info(f"配置已更新: cf_cookies (长度 {len(cf_cookies)}), 指纹: {browser}")
        if user_agent:
            logger.info(f"配置已更新: user_agent = {user_agent}")
        return True
    except Exception as e:
        logger.error(f"更新配置失败: {e}")
        return False
40
+
41
+
42
async def refresh_once() -> bool:
    """Run one solve-and-persist refresh cycle; True on full success."""
    logger.info("=" * 50)
    logger.info("开始刷新 cf_clearance...")

    solved = await solve_cf_challenge()
    if not solved:
        logger.error("刷新失败:无法获取 cf_clearance")
        return False

    updated = await _update_app_config(
        cf_cookies=solved["cookies"],
        cf_clearance=solved.get("cf_clearance", ""),
        user_agent=solved.get("user_agent", ""),
        browser=solved.get("browser", ""),
    )

    if updated:
        logger.info("刷新完成")
    else:
        logger.error("刷新失败: 更新配置失败")

    return updated
65
+
66
+
67
async def _scheduler_loop():
    """Background loop: periodically refresh cf_clearance while enabled.

    Config is re-read every iteration so panel changes take effect without
    a restart. A failed cycle is logged and must never kill the long-lived
    task; cancellation is re-raised so ``stop()`` works.
    """
    logger.info(
        f"cf_refresh scheduler started (FlareSolverr: {get_flaresolverr_url()}, interval: {get_refresh_interval()}s)"
    )

    # 周期性刷新(每次循环重新读取配置,支持面板修改实时生效)
    while True:
        try:
            if is_enabled():
                await refresh_once()
            else:
                logger.debug("cf_refresh disabled, skip refresh")
        except asyncio.CancelledError:
            # Propagate so the task actually stops on cancel().
            raise
        except Exception as e:
            # An unexpected error in one cycle must not terminate the task.
            logger.error(f"cf_refresh cycle failed unexpectedly: {e}")
        interval = get_refresh_interval()
        await asyncio.sleep(interval)
81
+
82
+
83
def start():
    """Launch the background refresh task (idempotent; needs a running loop).

    Raises:
        RuntimeError: if called while no asyncio event loop is running.
    """
    global _task
    if _task is not None:
        return
    # asyncio.get_event_loop() is deprecated for this use since Python 3.10;
    # schedule the loop on the event loop that is actually running.
    _task = asyncio.get_running_loop().create_task(_scheduler_loop())
    logger.info("cf_refresh background task started")
90
+
91
+
92
def stop():
    """Cancel and forget the background refresh task, if one is running."""
    global _task
    if _task is None:
        return
    _task.cancel()
    _task = None
    logger.info("cf_refresh background task stopped")
app/services/cf_refresh/solver.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 通过 FlareSolverr 自动获取 cf_clearance
3
+
4
+ FlareSolverr 是一个 Docker 服务,内部运行 Chrome 浏览器,
5
+ 自动处理 Cloudflare 挑战(包括 Turnstile),无需 GUI。
6
+ """
7
+
8
+ import asyncio
9
+ import json
10
+ from typing import Optional, Dict
11
+ from urllib import request as urllib_request
12
+ from urllib.error import HTTPError, URLError
13
+
14
+ from loguru import logger
15
+
16
+ from .config import GROK_URL, get_timeout, get_proxy, get_flaresolverr_url
17
+
18
+
19
+ def _extract_all_cookies(cookies: list[dict]) -> str:
20
+ """将 FlareSolverr 返回 of cookie 列表转换为字符串格式"""
21
+ return "; ".join([f"{c.get('name')}={c.get('value')}" for c in cookies])
22
+
23
+
24
+ def _extract_cookie_value(cookies: list[dict], name: str) -> str:
25
+ for cookie in cookies:
26
+ if cookie.get("name") == name:
27
+ return cookie.get("value") or ""
28
+ return ""
29
+
30
+
31
+ def _extract_user_agent(solution: dict) -> str:
32
+ """从 FlareSolverr 的 solution 中提取 User-Agent"""
33
+ return solution.get("userAgent", "")
34
+
35
+
36
+ def _extract_browser_profile(user_agent: str) -> str:
37
+ """从 User-Agent 提取 chromeXXX 格式的指纹识别号"""
38
+ import re
39
+ match = re.search(r"Chrome/(\d+)", user_agent)
40
+ if match:
41
+ return f"chrome{match.group(1)}"
42
+ return "chrome120"
43
+
44
+
45
async def solve_cf_challenge() -> Optional[Dict[str, str]]:
    """Ask FlareSolverr to open grok.com and pass the Cloudflare challenge.

    Returns:
        On success a dict with "cookies" (full cookie header string),
        "cf_clearance", "user_agent" and "browser" (chromeNNN profile);
        ``None`` on any failure (unconfigured, upstream error, no cookies).
    """
    flaresolverr_url = get_flaresolverr_url()
    cf_timeout = get_timeout()
    proxy = get_proxy()

    if not flaresolverr_url:
        logger.error("FlareSolverr 地址未配置,无法刷新 cf_clearance")
        return None

    url = f"{flaresolverr_url.rstrip('/')}/v1"

    payload = {
        "cmd": "request.get",
        "url": GROK_URL,
        "maxTimeout": cf_timeout * 1000,  # FlareSolverr expects milliseconds
    }

    if proxy:
        # Route through the configured base proxy so the egress IP matches
        # the one the cf_clearance cookie will be bound to.
        payload["proxy"] = {"url": proxy}

    body = json.dumps(payload).encode("utf-8")
    headers = {"Content-Type": "application/json"}

    logger.info(f"正在通过 FlareSolverr 访问 {GROK_URL} ...")
    logger.debug(f"FlareSolverr 地址: {url}")

    req = urllib_request.Request(url, data=body, method="POST", headers=headers)

    try:

        def _post():
            # Blocking urllib call; give it headroom beyond the solver's own
            # maxTimeout so FlareSolverr times out first and reports cleanly.
            with urllib_request.urlopen(req, timeout=cf_timeout + 30) as resp:
                return json.loads(resp.read().decode("utf-8"))

        result = await asyncio.to_thread(_post)

        status = result.get("status", "")
        if status != "ok":
            message = result.get("message", "unknown error")
            logger.error(f"FlareSolverr 返回错误: {status} - {message}")
            return None

        solution = result.get("solution", {})
        cookies = solution.get("cookies", [])

        if not cookies:
            logger.error("FlareSolverr 成功访问但没有返回 cookies")
            return None

        cookie_str = _extract_all_cookies(cookies)
        clearance = _extract_cookie_value(cookies, "cf_clearance")
        ua = _extract_user_agent(solution)
        browser = _extract_browser_profile(ua)
        logger.info(f"成功获取 cookies (数量: {len(cookies)}), 指纹: {browser}")

        return {
            "cookies": cookie_str,
            "cf_clearance": clearance,
            "user_agent": ua,
            "browser": browser,
        }

    except HTTPError as e:
        # Cap the logged body so a huge error page doesn't flood the log.
        body_text = e.read().decode("utf-8", "replace")[:300]
        logger.error(f"FlareSolverr 请求失败: {e.code} - {body_text}")
        return None
    except URLError as e:
        logger.error(f"无法连接 FlareSolverr ({flaresolverr_url}): {e.reason}")
        logger.info("请确认 FlareSolverr 服务已启动: docker run -p 8191:8191 ghcr.io/flaresolverr/flaresolverr:latest")
        return None
    except Exception as e:
        logger.error(f"请求异常: {e}")
        return None
app/services/grok/batch_services/assets.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Batch assets service.
3
+ """
4
+
5
+ import asyncio
6
+ from typing import Dict, List, Optional
7
+
8
+ from app.core.config import get_config
9
+ from app.core.logger import logger
10
+ from app.services.reverse.assets_list import AssetsListReverse
11
+ from app.services.reverse.assets_delete import AssetsDeleteReverse
12
+ from app.services.reverse.utils.session import ResettableSession
13
+ from app.core.batch import run_batch
14
+
15
+
16
class BaseAssetsService:
    """Shared lazy-session plumbing for the assets list/delete services."""

    def __init__(self):
        self._session: Optional[ResettableSession] = None

    async def _get_session(self) -> ResettableSession:
        """Create the HTTP session on first use, honoring the configured fingerprint."""
        if self._session is None:
            browser = get_config("proxy.browser")
            session_kwargs = {"impersonate": browser} if browser else {}
            self._session = ResettableSession(**session_kwargs)
        return self._session

    async def close(self):
        """Dispose of the cached session, if one was created."""
        if self._session:
            await self._session.close()
            self._session = None
35
+
36
+
37
# Cached asyncio.Semaphore instances plus the config value they were built
# with; rebuilt by the getters below whenever the configured limit changes.
_LIST_SEMAPHORE = None
_LIST_SEM_VALUE = None
_DELETE_SEMAPHORE = None
_DELETE_SEM_VALUE = None
41
+
42
+
43
def _get_list_semaphore() -> asyncio.Semaphore:
    """Shared semaphore capping concurrent asset-list requests.

    Rebuilt whenever ``asset.list_concurrent`` changes (minimum 1).
    """
    global _LIST_SEMAPHORE, _LIST_SEM_VALUE
    limit = max(1, int(get_config("asset.list_concurrent")))
    if _LIST_SEMAPHORE is None or limit != _LIST_SEM_VALUE:
        _LIST_SEM_VALUE = limit
        _LIST_SEMAPHORE = asyncio.Semaphore(limit)
    return _LIST_SEMAPHORE
50
+
51
+
52
def _get_delete_semaphore() -> asyncio.Semaphore:
    """Shared semaphore capping concurrent asset-delete requests.

    Rebuilt whenever ``asset.delete_concurrent`` changes (minimum 1).
    """
    global _DELETE_SEMAPHORE, _DELETE_SEM_VALUE
    limit = max(1, int(get_config("asset.delete_concurrent")))
    if _DELETE_SEMAPHORE is None or limit != _DELETE_SEM_VALUE:
        _DELETE_SEM_VALUE = limit
        _DELETE_SEMAPHORE = asyncio.Semaphore(limit)
    return _DELETE_SEMAPHORE
59
+
60
+
61
class ListService(BaseAssetsService):
    """Enumerate asset IDs owned by a token via the paginated assets API."""

    async def list(self, token: str) -> Dict[str, List[str] | int]:
        """Walk every page of the assets listing and collect asset IDs.

        Returns:
            ``{"asset_ids": [...], "count": N}``.

        Pagination stops when the upstream returns no ``nextPageToken``, or
        defensively when a page token repeats (infinite-loop guard).
        """
        params = {
            "pageSize": 50,
            "orderBy": "ORDER_BY_LAST_USE_TIME",
            "source": "SOURCE_ANY",
            "isLatest": "true",
        }
        page_token = None
        seen_tokens = set()
        asset_ids: List[str] = []
        session = await self._get_session()
        while True:
            if page_token:
                if page_token in seen_tokens:
                    logger.warning("Pagination stopped: repeated page token")
                    break
                seen_tokens.add(page_token)
                params["pageToken"] = page_token
            else:
                # First page: make sure no stale pageToken is sent.
                params.pop("pageToken", None)

            # Throttled by the shared list semaphore.
            async with _get_list_semaphore():
                response = await AssetsListReverse.request(
                    session,
                    token,
                    params,
                )

            result = response.json()
            page_assets = result.get("assets", [])
            if page_assets:
                for asset in page_assets:
                    asset_id = asset.get("assetId")
                    if asset_id:
                        asset_ids.append(asset_id)

            page_token = result.get("nextPageToken")
            if not page_token:
                break

        logger.info(f"List success: {len(asset_ids)} files")
        return {"asset_ids": asset_ids, "count": len(asset_ids)}

    @staticmethod
    async def fetch_assets_details(
        tokens: List[str],
        account_map: dict,
        *,
        include_ok: bool = False,
        on_item=None,
        should_cancel=None,
    ) -> dict:
        """Batch fetch assets details for tokens.

        For each token a detail row is built (masked token and last clear
        time come from ``account_map`` when available); per-token errors are
        captured in the row's ``status`` instead of aborting the batch. When
        ``include_ok`` is True each result also carries an "ok" flag.
        """
        account_map = account_map or {}
        shared_service = ListService()
        batch_size = max(1, int(get_config("asset.list_batch_size")))

        async def _fetch_detail(token: str):
            # One token -> one detail row; exceptions become status strings.
            account = account_map.get(token)
            try:
                result = await shared_service.list(token)
                asset_ids = result.get("asset_ids", [])
                count = result.get("count", len(asset_ids))
                detail = {
                    "token": token,
                    "token_masked": account["token_masked"] if account else token,
                    "count": count,
                    "status": "ok",
                    "last_asset_clear_at": account["last_asset_clear_at"]
                    if account
                    else None,
                }
                if include_ok:
                    return {"ok": True, "detail": detail, "count": count}
                return {"detail": detail, "count": count}
            except Exception as e:
                detail = {
                    "token": token,
                    "token_masked": account["token_masked"] if account else token,
                    "count": 0,
                    "status": f"error: {str(e)}",
                    "last_asset_clear_at": account["last_asset_clear_at"]
                    if account
                    else None,
                }
                if include_ok:
                    return {"ok": False, "detail": detail, "count": 0}
                return {"detail": detail, "count": 0}

        try:
            return await run_batch(
                tokens,
                _fetch_detail,
                batch_size=batch_size,
                on_item=on_item,
                should_cancel=should_cancel,
            )
        finally:
            # Always release the shared HTTP session.
            await shared_service.close()
163
+
164
+
165
class DeleteService(BaseAssetsService):
    """Delete a token's assets, with concurrency capped by a shared semaphore."""

    async def delete(self, token: str, asset_ids: List[str]) -> Dict[str, int]:
        """Delete the given asset IDs concurrently.

        Returns:
            Counters ``{"total", "success", "failed"}``; an empty input
            short-circuits with ``"skipped": True``.
        """
        if not asset_ids:
            logger.info("No assets to delete")
            return {"total": 0, "success": 0, "failed": 0, "skipped": True}

        total = len(asset_ids)
        success = 0
        failed = 0
        session = await self._get_session()

        async def _delete_one(asset_id: str):
            # Each delete acquires the shared semaphore to cap concurrency.
            async with _get_delete_semaphore():
                await AssetsDeleteReverse.request(session, token, asset_id)

        # Falsy IDs are filtered out before scheduling.
        tasks = [_delete_one(asset_id) for asset_id in asset_ids if asset_id]
        results = await asyncio.gather(*tasks, return_exceptions=True)
        for res in results:
            if isinstance(res, Exception):
                failed += 1
            else:
                success += 1

        logger.info(f"Delete all: total={total}, success={success}, failed={failed}")
        return {"total": total, "success": success, "failed": failed}

    @staticmethod
    async def clear_assets(
        tokens: List[str],
        mgr,
        *,
        include_ok: bool = False,
        on_item=None,
        should_cancel=None,
    ) -> dict:
        """Batch clear assets for tokens.

        For each token: list its assets, delete them, then record the clear
        time via ``mgr.mark_asset_clear``. Per-token errors are reported in
        the result rather than raised.
        """
        delete_service = DeleteService()
        list_service = ListService()
        batch_size = max(1, int(get_config("asset.delete_batch_size")))

        async def _clear_one(token: str):
            try:
                result = await list_service.list(token)
                asset_ids = result.get("asset_ids", [])
                result = await delete_service.delete(token, asset_ids)
                await mgr.mark_asset_clear(token)
                if include_ok:
                    return {"ok": True, "result": result}
                return {"status": "success", "result": result}
            except Exception as e:
                if include_ok:
                    return {"ok": False, "error": str(e)}
                return {"status": "error", "error": str(e)}

        try:
            return await run_batch(
                tokens,
                _clear_one,
                batch_size=batch_size,
                on_item=on_item,
                should_cancel=should_cancel,
            )
        finally:
            # Release both shared HTTP sessions regardless of outcome.
            await delete_service.close()
            await list_service.close()
231
+ await list_service.close()
232
+
233
+
234
+ __all__ = ["ListService", "DeleteService"]
app/services/grok/batch_services/nsfw.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Batch NSFW service.
3
+ """
4
+
5
+ import asyncio
6
+ from typing import Callable, Awaitable, Dict, Any, Optional
7
+
8
+ from app.core.logger import logger
9
+ from app.core.config import get_config
10
+ from app.core.exceptions import UpstreamException
11
+ from app.services.reverse.accept_tos import AcceptTosReverse
12
+ from app.services.reverse.nsfw_mgmt import NsfwMgmtReverse
13
+ from app.services.reverse.set_birth import SetBirthReverse
14
+ from app.services.reverse.utils.session import ResettableSession
15
+ from app.core.batch import run_batch
16
+
17
+
18
+ _NSFW_SEMAPHORE = None
19
+ _NSFW_SEM_VALUE = None
20
+
21
+
22
def _get_nsfw_semaphore() -> asyncio.Semaphore:
    """Shared semaphore capping concurrent NSFW-related upstream calls.

    Rebuilt whenever ``nsfw.concurrent`` changes (minimum 1).
    """
    global _NSFW_SEMAPHORE, _NSFW_SEM_VALUE
    limit = max(1, int(get_config("nsfw.concurrent")))
    if _NSFW_SEMAPHORE is None or limit != _NSFW_SEM_VALUE:
        _NSFW_SEM_VALUE = limit
        _NSFW_SEMAPHORE = asyncio.Semaphore(limit)
    return _NSFW_SEMAPHORE
29
+
30
+
31
class NSFWService:
    """Enable Grok NSFW mode for accounts in batch."""

    @staticmethod
    async def batch(
        tokens: list[str],
        mgr,
        *,
        on_item: Optional[Callable[[str, Dict[str, Any]], Awaitable[None]]] = None,
        should_cancel: Optional[Callable[[], bool]] = None,
    ) -> Dict[str, Dict[str, Any]]:
        """Batch enable NSFW.

        For each token: accept the ToS, set a birth date, then toggle NSFW
        via the management endpoint. A 401 at any step is recorded against
        the token through ``mgr.record_fail``; success adds the "nsfw" tag.
        """
        batch_size = get_config("nsfw.batch_size")

        async def _enable(token: str):
            try:
                browser = get_config("proxy.browser")
                async with ResettableSession(impersonate=browser) as session:

                    async def _record_fail(err: UpstreamException, reason: str):
                        # Prefer the HTTP status captured in err.details;
                        # fall back to the exception's status_code attribute.
                        status = None
                        if err.details and "status" in err.details:
                            status = err.details["status"]
                        else:
                            status = getattr(err, "status_code", None)
                        # Only auth failures (401) are recorded on the token.
                        if status == 401:
                            await mgr.record_fail(token, status, reason)
                        return status or 0

                    try:
                        async with _get_nsfw_semaphore():
                            await AcceptTosReverse.request(session, token)
                    except UpstreamException as e:
                        status = await _record_fail(e, "tos_auth_failed")
                        return {
                            "success": False,
                            "http_status": status,
                            "error": f"Accept ToS failed: {str(e)}",
                        }

                    try:
                        async with _get_nsfw_semaphore():
                            await SetBirthReverse.request(session, token)
                    except UpstreamException as e:
                        status = await _record_fail(e, "set_birth_auth_failed")
                        return {
                            "success": False,
                            "http_status": status,
                            "error": f"Set birth date failed: {str(e)}",
                        }

                    try:
                        async with _get_nsfw_semaphore():
                            grpc_status = await NsfwMgmtReverse.request(session, token)
                        # Codes -1 and 0 are treated as success.
                        success = grpc_status.code in (-1, 0)
                    except UpstreamException as e:
                        status = await _record_fail(e, "nsfw_mgmt_auth_failed")
                        return {
                            "success": False,
                            "http_status": status,
                            "error": f"NSFW enable failed: {str(e)}",
                        }
                    if success:
                        await mgr.add_tag(token, "nsfw")
                    return {
                        "success": success,
                        "http_status": 200,
                        "grpc_status": grpc_status.code,
                        "grpc_message": grpc_status.message or None,
                        "error": None,
                    }
            except Exception as e:
                # Catch-all: session setup or any other unexpected failure.
                logger.error(f"NSFW enable failed: {e}")
                return {"success": False, "http_status": 0, "error": str(e)[:100]}

        return await run_batch(
            tokens,
            _enable,
            batch_size=batch_size,
            on_item=on_item,
            should_cancel=should_cancel,
        )
110
+
111
+
112
+ __all__ = ["NSFWService"]
app/services/grok/batch_services/usage.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Batch usage service.
3
+ """
4
+
5
+ import asyncio
6
+ from typing import Callable, Awaitable, Dict, Any, Optional, List
7
+
8
+ from app.core.logger import logger
9
+ from app.core.config import get_config
10
+ from app.services.reverse.rate_limits import RateLimitsReverse
11
+ from app.services.reverse.utils.session import ResettableSession
12
+ from app.core.batch import run_batch
13
+
14
+ _USAGE_SEMAPHORE = None
15
+ _USAGE_SEM_VALUE = None
16
+
17
+
18
def _get_usage_semaphore() -> asyncio.Semaphore:
    """Shared semaphore capping concurrent usage queries.

    Rebuilt whenever ``usage.concurrent`` changes (minimum 1).
    """
    global _USAGE_SEMAPHORE, _USAGE_SEM_VALUE
    limit = max(1, int(get_config("usage.concurrent")))
    if _USAGE_SEMAPHORE is None or limit != _USAGE_SEM_VALUE:
        _USAGE_SEM_VALUE = limit
        _USAGE_SEMAPHORE = asyncio.Semaphore(limit)
    return _USAGE_SEMAPHORE
25
+
26
+
27
class UsageService:
    """Query Grok rate-limit usage for tokens, individually or in batch."""

    async def get(self, token: str) -> Dict:
        """Fetch rate-limit info for a single token.

        Args:
            token: auth token.

        Returns:
            The upstream JSON payload, with ``remainingTokens`` backfilled
            from ``remainingQueries`` when absent.

        Raises:
            UpstreamException: when the upstream request ultimately fails
                (propagated unchanged; the final failure is already logged
                upstream).
        """
        async with _get_usage_semaphore():
            browser = get_config("proxy.browser")
            if browser:
                session_ctx = ResettableSession(impersonate=browser)
            else:
                session_ctx = ResettableSession()
            async with session_ctx as session:
                response = await RateLimitsReverse.request(session, token)
                data = response.json()
                # Older responses report remainingQueries; normalize the field.
                remaining = data.get("remainingTokens")
                if remaining is None:
                    remaining = data.get("remainingQueries")
                if remaining is not None:
                    data["remainingTokens"] = remaining
                logger.info(
                    f"Usage sync success: remaining={remaining}, token={token[:10]}..."
                )
                return data

    @staticmethod
    async def batch(
        tokens: List[str],
        mgr,
        *,
        on_item: Optional[Callable[[str, Dict[str, Any]], Awaitable[None]]] = None,
        should_cancel: Optional[Callable[[], bool]] = None,
    ) -> Dict[str, Dict[str, Any]]:
        """Refresh usage for many tokens via the manager's ``sync_usage``."""
        batch_size = get_config("usage.batch_size")

        async def _refresh_one(t: str):
            # Delegate to the token manager; failures do not consume quota.
            return await mgr.sync_usage(t, consume_on_fail=False, is_usage=False)

        return await run_batch(
            tokens,
            _refresh_one,
            batch_size=batch_size,
            on_item=on_item,
            should_cancel=should_cancel,
        )
87
+
88
+
89
+ __all__ = ["UsageService"]
app/services/grok/defaults.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Grok 服务默认配置
3
+
4
+ 此文件读取 config.defaults.toml,作为 Grok 服务的默认值来源。
5
+ """
6
+
7
+ from pathlib import Path
8
+ import tomllib
9
+
10
+ from app.core.logger import logger
11
+
12
+ DEFAULTS_FILE = Path(__file__).resolve().parent.parent.parent.parent / "config.defaults.toml"
13
+
14
+ # Grok 服务默认配置(运行时从 config.defaults.toml 读取并缓存)
15
+ GROK_DEFAULTS: dict = {}
16
+
17
+
18
def get_grok_defaults():
    """Return the Grok default config, loading config.defaults.toml once.

    A missing or unparsable file is logged and yields the (empty) cache
    unchanged, so the next call will retry the load.
    """
    global GROK_DEFAULTS
    if GROK_DEFAULTS:
        # Already loaded — serve the cached copy.
        return GROK_DEFAULTS
    if not DEFAULTS_FILE.exists():
        logger.warning(f"Defaults file not found: {DEFAULTS_FILE}")
        return GROK_DEFAULTS
    try:
        with DEFAULTS_FILE.open("rb") as handle:
            GROK_DEFAULTS = tomllib.load(handle)
    except Exception as e:
        logger.warning(f"Failed to load defaults from {DEFAULTS_FILE}: {e}")
    return GROK_DEFAULTS
32
+
33
+
34
+ __all__ = ["GROK_DEFAULTS", "get_grok_defaults"]
app/services/grok/services/chat.py ADDED
@@ -0,0 +1,1115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Grok Chat 服务
3
+ """
4
+
5
+ import asyncio
6
+ import re
7
+ import uuid
8
+ from typing import Dict, List, Any, AsyncGenerator, AsyncIterable
9
+
10
+ import orjson
11
+ from curl_cffi.requests.errors import RequestsError
12
+
13
+ from app.core.logger import logger
14
+ from app.core.config import get_config
15
+ from app.core.exceptions import (
16
+ AppException,
17
+ ValidationException,
18
+ ErrorType,
19
+ UpstreamException,
20
+ StreamIdleTimeoutError,
21
+ )
22
+ from app.services.grok.services.model import ModelService
23
+ from app.services.grok.utils.upload import UploadService
24
+ from app.services.grok.utils import process as proc_base
25
+ from app.services.grok.utils.retry import pick_token, rate_limited, transient_upstream
26
+ from app.services.reverse.app_chat import AppChatReverse
27
+ from app.services.reverse.utils.session import ResettableSession
28
+ from app.services.grok.utils.stream import wrap_stream_with_usage
29
+ from app.services.grok.utils.tool_call import (
30
+ build_tool_prompt,
31
+ parse_tool_calls,
32
+ parse_tool_call_block,
33
+ format_tool_history,
34
+ )
35
+ from app.services.token import get_token_manager, EffortType
36
+
37
+
38
# Lazily-created module-level semaphore bounding concurrent chat requests.
# (Re)built by _get_chat_semaphore() whenever the configured limit changes.
_CHAT_SEMAPHORE = None
# The "chat.concurrent" value the current semaphore was built with; used to
# detect config changes between calls.
_CHAT_SEM_VALUE = None
40
+
41
+
42
def extract_tool_text(raw: str, rollout_id: str = "") -> str:
    """Turn a raw ``xai:tool_usage_card`` XML fragment into a one-line summary.

    Known tools (web_search / search_images / chatroom_send) get a bracketed
    label, optionally prefixed with ``[rollout_id]``, followed by the most
    relevant argument. Unknown fragments fall back to the stripped tag-free
    text.
    """
    if not raw:
        return ""

    def _inner(tag: str) -> str:
        # Pull the tag body and unwrap any CDATA section it carries.
        found = re.search(rf"<xai:{tag}>(.*?)</xai:{tag}>", raw, flags=re.DOTALL)
        if not found:
            return ""
        body = found.group(1)
        if not body:
            return body
        return re.sub(r"<!\[CDATA\[(.*?)\]\]>", r"\1", body, flags=re.DOTALL).strip()

    name = _inner("tool_name")
    args = _inner("tool_args")

    # Best-effort JSON decode of the argument blob; non-JSON stays raw text.
    payload = None
    if args:
        try:
            payload = orjson.loads(args)
        except orjson.JSONDecodeError:
            payload = None

    prefix = f"[{rollout_id}]" if rollout_id else ""
    label, text = name, args

    if name == "web_search":
        label = f"{prefix}[WebSearch]"
        if isinstance(payload, dict):
            text = payload.get("query") or payload.get("q") or ""
    elif name == "search_images":
        label = f"{prefix}[SearchImage]"
        if isinstance(payload, dict):
            text = (
                payload.get("image_description")
                or payload.get("description")
                or payload.get("query")
                or ""
            )
    elif name == "chatroom_send":
        label = f"{prefix}[AgentThink]"
        if isinstance(payload, dict):
            text = payload.get("message") or ""

    if label and text:
        return f"{label} {text}".strip()
    if label or text:
        return label or text
    # Fallback: strip tags to keep any raw text.
    return re.sub(r"<[^>]+>", "", raw, flags=re.DOTALL).strip()
97
+
98
+
99
def _get_chat_semaphore() -> asyncio.Semaphore:
    """Return the shared chat-concurrency semaphore, rebuilding it when the
    configured ``chat.concurrent`` limit changes.

    NOTE(review): callers keep their own reference to the semaphore they
    acquired, so releases still hit the old object after a rebuild — in-flight
    requests are unaffected, but permits do not carry over to the new limit.
    """
    global _CHAT_SEMAPHORE, _CHAT_SEM_VALUE
    # Clamp to at least 1 so a zero/negative config value cannot deadlock.
    value = max(1, int(get_config("chat.concurrent")))
    if value != _CHAT_SEM_VALUE:
        _CHAT_SEM_VALUE = value
        _CHAT_SEMAPHORE = asyncio.Semaphore(value)
    return _CHAT_SEMAPHORE
106
+
107
+
108
class MessageExtractor:
    """Extracts prompt text and attachments from OpenAI-style chat messages."""

    @staticmethod
    def extract(
        messages: List[Dict[str, Any]],
        tools: List[Dict[str, Any]] = None,
        tool_choice: Any = None,
        parallel_tool_calls: bool = True,
    ) -> tuple[str, List[str], List[str]]:
        """Flatten OpenAI messages into a single prompt string.

        Args:
            messages: OpenAI chat messages; ``content`` may be a string, a
                single content-part dict, or a list of content parts.
            tools: Optional tool definitions; when given, tool history is
                rewritten to text and a tool system prompt is prepended.
            tool_choice: Forwarded to ``build_tool_prompt``.
            parallel_tool_calls: Forwarded to ``build_tool_prompt``.

        Returns:
            ``(text, file_attachments, image_attachments)`` where the
            attachment lists hold raw data/URLs still to be uploaded.
        """
        # Pre-process: convert tool-related messages to text format.
        if tools:
            messages = format_tool_history(messages)

        texts = []
        file_attachments: List[str] = []
        image_attachments: List[str] = []
        extracted = []

        for msg in messages:
            role = msg.get("role", "") or "user"
            content = msg.get("content", "")
            parts = []

            if isinstance(content, str):
                if content.strip():
                    parts.append(content)
            else:
                # Normalize a single content-part dict to a one-element list so
                # both shapes share one parsing loop (previously duplicated).
                items = [content] if isinstance(content, dict) else content
                if isinstance(items, list):
                    for item in items:
                        if not isinstance(item, dict):
                            continue
                        item_type = item.get("type", "")

                        if item_type == "text":
                            if text := item.get("text", "").strip():
                                parts.append(text)

                        elif item_type == "image_url":
                            image_data = item.get("image_url", {})
                            url = image_data.get("url", "")
                            if url:
                                image_attachments.append(url)

                        elif item_type == "input_audio":
                            audio_data = item.get("input_audio", {})
                            data = audio_data.get("data", "")
                            if data:
                                file_attachments.append(data)

                        elif item_type == "file":
                            file_data = item.get("file", {})
                            raw = file_data.get("file_data", "")
                            if raw:
                                file_attachments.append(raw)

            # Keep tool-call traces so multi-turn tool sessions preserve
            # context ordering for clients that drop them.
            tool_calls = msg.get("tool_calls")
            if role == "assistant" and not parts and isinstance(tool_calls, list):
                for call in tool_calls:
                    if not isinstance(call, dict):
                        continue
                    fn = call.get("function", {})
                    if not isinstance(fn, dict):
                        fn = {}
                    name = fn.get("name") or call.get("name") or "tool"
                    arguments = fn.get("arguments", "")
                    if isinstance(arguments, (dict, list)):
                        try:
                            arguments = orjson.dumps(arguments).decode()
                        except Exception:
                            arguments = str(arguments)
                    if not isinstance(arguments, str):
                        arguments = str(arguments)
                    arguments = arguments.strip()
                    parts.append(f"[tool_call] {name} {arguments}".strip())

            if parts:
                role_label = role
                if role == "tool":
                    # Tag tool results with their name / call id when present.
                    name = msg.get("name")
                    call_id = msg.get("tool_call_id")
                    if isinstance(name, str) and name.strip():
                        role_label = f"tool[{name.strip()}]"
                    if isinstance(call_id, str) and call_id.strip():
                        role_label = f"{role_label}#{call_id.strip()}"
                extracted.append({"role": role_label, "text": "\n".join(parts)})

        # Find the last user message; it is emitted without a role prefix.
        last_user_index = next(
            (
                i
                for i in range(len(extracted) - 1, -1, -1)
                if extracted[i]["role"] == "user"
            ),
            None,
        )

        for i, item in enumerate(extracted):
            role = item["role"] or "user"
            text = item["text"]
            texts.append(text if i == last_user_index else f"{role}: {text}")

        combined = "\n\n".join(texts)

        # If there are attachments but no text, inject a fallback prompt.
        if (not combined.strip()) and (file_attachments or image_attachments):
            combined = "Refer to the following content:"

        # Prepend tool system prompt if tools are provided.
        if tools:
            tool_prompt = build_tool_prompt(tools, tool_choice, parallel_tool_calls)
            if tool_prompt:
                combined = f"{tool_prompt}\n\n{combined}"

        return combined, file_attachments, image_attachments
250
+
251
+
252
class GrokChatService:
    """Grok API invocation service (low-level request layer)."""

    async def chat(
        self,
        token: str,
        message: str,
        model: str,
        mode: str = None,
        stream: bool = None,
        file_attachments: List[str] = None,
        tool_overrides: Dict[str, Any] = None,
        model_config_override: Dict[str, Any] = None,
    ):
        """Send a chat request upstream.

        Acquires the shared chat semaphore before opening the connection; the
        slot is released either on setup failure or when the returned stream
        is exhausted/abandoned.

        Args:
            token: Grok auth token for the upstream call.
            message: Fully assembled prompt text.
            model: Grok model identifier.
            mode: Optional model mode.
            stream: Streaming flag; falls back to the ``app.stream`` config.
            file_attachments: Uploaded attachment IDs.
            tool_overrides: Optional tool-override payload.
            model_config_override: Optional sampling/config overrides.

        Returns:
            An async generator yielding raw upstream stream lines.
        """
        if stream is None:
            stream = get_config("app.stream")

        logger.debug(
            f"Chat request: model={model}, mode={mode}, stream={stream}, attachments={len(file_attachments or [])}"
        )

        browser = get_config("proxy.browser")
        semaphore = _get_chat_semaphore()
        await semaphore.acquire()
        session = ResettableSession(impersonate=browser)
        try:
            stream_response = await AppChatReverse.request(
                session,
                token,
                message=message,
                model=model,
                mode=mode,
                file_attachments=file_attachments,
                tool_overrides=tool_overrides,
                model_config_override=model_config_override,
            )
            logger.info(f"Chat connected: model={model}, stream={stream}")
        except Exception:
            # Setup failed: close the session and free the concurrency slot
            # before propagating, since the generator below will never run.
            try:
                await session.close()
            except Exception:
                pass
            semaphore.release()
            raise

        async def _stream():
            # Release the concurrency slot only once the stream is drained or
            # the consumer abandons it (generator finalization runs finally).
            # NOTE(review): the session is not closed here — presumably the
            # reverse client / stream owns its lifecycle; confirm.
            try:
                async for line in stream_response:
                    yield line
            finally:
                semaphore.release()

        return _stream()

    async def chat_openai(
        self,
        token: str,
        model: str,
        messages: List[Dict[str, Any]],
        stream: bool = None,
        reasoning_effort: str | None = None,
        temperature: float = 0.8,
        top_p: float = 0.95,
        tools: List[Dict[str, Any]] = None,
        tool_choice: Any = None,
        parallel_tool_calls: bool = True,
    ):
        """OpenAI-compatible entry point.

        Resolves the model alias, extracts text/attachments from the OpenAI
        message list, uploads attachments, then delegates to :meth:`chat`.

        Returns:
            ``(response_stream, stream_flag, model_name)``.

        Raises:
            ValidationException: If ``model`` is unknown.
        """
        model_info = ModelService.get(model)
        if not model_info:
            raise ValidationException(f"Unknown model: {model}")

        grok_model = model_info.grok_model
        mode = model_info.model_mode
        # Extract message text and attachments.
        message, file_attachments, image_attachments = MessageExtractor.extract(
            messages, tools=tools, tool_choice=tool_choice, parallel_tool_calls=parallel_tool_calls
        )
        logger.debug(
            "Extracted message length=%s, files=%s, images=%s",
            len(message),
            len(file_attachments),
            len(image_attachments),
        )

        # Upload attachments (files and images go through the same uploader).
        file_ids: List[str] = []
        image_ids: List[str] = []
        if file_attachments or image_attachments:
            upload_service = UploadService()
            try:
                for attach_data in file_attachments:
                    file_id, _ = await upload_service.upload_file(attach_data, token)
                    file_ids.append(file_id)
                    logger.debug(f"Attachment uploaded: type=file, file_id={file_id}")
                for attach_data in image_attachments:
                    file_id, _ = await upload_service.upload_file(attach_data, token)
                    image_ids.append(file_id)
                    logger.debug(f"Attachment uploaded: type=image, file_id={file_id}")
            finally:
                await upload_service.close()

        all_attachments = file_ids + image_ids
        stream = stream if stream is not None else get_config("app.stream")

        model_config_override = {
            "temperature": temperature,
            "topP": top_p,
        }
        if reasoning_effort is not None:
            model_config_override["reasoningEffort"] = reasoning_effort

        response = await self.chat(
            token,
            message,
            grok_model,
            mode,
            stream,
            file_attachments=all_attachments,
            tool_overrides=None,
            model_config_override=model_config_override,
        )

        return response, stream, model
377
+
378
+
379
class ChatService:
    """Chat business service: token selection, cross-token retry, response processing."""

    @staticmethod
    async def completions(
        model: str,
        messages: List[Dict[str, Any]],
        stream: bool = None,
        reasoning_effort: str | None = None,
        temperature: float = 0.8,
        top_p: float = 0.95,
        tools: List[Dict[str, Any]] = None,
        tool_choice: Any = None,
        parallel_tool_calls: bool = True,
    ):
        """Chat Completions entry point.

        Picks a token, calls upstream, and retries with a different token on
        rate-limit (429) or transient upstream errors, up to
        ``retry.max_retry`` attempts.

        Returns:
            A usage-wrapped async stream (streaming) or a completed OpenAI
            chat.completion dict (non-streaming).

        Raises:
            AppException: With 429 when no token is available.
            UpstreamException: For non-retryable upstream failures.
        """
        # Acquire the token manager and refresh stale state.
        token_mgr = await get_token_manager()
        await token_mgr.reload_if_stale()

        # Resolve parameters (thinking visibility, stream flag).
        if reasoning_effort is None:
            show_think = get_config("app.thinking")
        else:
            show_think = reasoning_effort != "none"
        is_stream = stream if stream is not None else get_config("app.stream")

        # Cross-token retry loop.
        tried_tokens = set()
        max_token_retries = int(get_config("retry.max_retry") or 3)
        last_error = None

        for attempt in range(max_token_retries):
            # Pick a token not yet tried this request.
            token = await pick_token(token_mgr, model, tried_tokens)
            if not token:
                if last_error:
                    raise last_error
                raise AppException(
                    message="No available tokens. Please try again later.",
                    error_type=ErrorType.RATE_LIMIT.value,
                    code="rate_limit_exceeded",
                    status_code=429,
                )

            tried_tokens.add(token)

            try:
                # Call Grok upstream.
                service = GrokChatService()
                response, _, model_name = await service.chat_openai(
                    token,
                    model,
                    messages,
                    stream=is_stream,
                    reasoning_effort=reasoning_effort,
                    temperature=temperature,
                    top_p=top_p,
                    tools=tools,
                    tool_choice=tool_choice,
                    parallel_tool_calls=parallel_tool_calls,
                )

                # Process the response (streaming path).
                if is_stream:
                    logger.debug(f"Processing stream response: model={model}")
                    processor = StreamProcessor(model_name, token, show_think, tools=tools, tool_choice=tool_choice)
                    return wrap_stream_with_usage(
                        processor.process(response), token_mgr, token, model
                    )

                # Non-streaming path: collect, then record usage best-effort.
                logger.debug(f"Processing non-stream response: model={model}")
                result = await CollectProcessor(model_name, token, tools=tools, tool_choice=tool_choice).process(response)
                try:
                    model_info = ModelService.get(model)
                    effort = (
                        EffortType.HIGH
                        if (model_info and model_info.cost.value == "high")
                        else EffortType.LOW
                    )
                    await token_mgr.consume(token, effort)
                    logger.info(f"Chat completed: model={model}, effort={effort.value}")
                except Exception as e:
                    # Usage accounting must not fail the request.
                    logger.warning(f"Failed to record usage: {e}")
                return result

            except UpstreamException as e:
                last_error = e

                if rate_limited(e):
                    # Quota exhausted: mark the token as cooling and retry
                    # with the next token.
                    await token_mgr.mark_rate_limited(token)
                    logger.warning(
                        f"Token {token[:10]}... rate limited (429), "
                        f"trying next token (attempt {attempt + 1}/{max_token_retries})"
                    )
                    continue

                if transient_upstream(e):
                    # Only retry a transient error when another token exists;
                    # otherwise surface the original failure.
                    has_alternative_token = False
                    for pool_name in ModelService.pool_candidates_for_model(model):
                        if token_mgr.get_token(pool_name, exclude=tried_tokens):
                            has_alternative_token = True
                            break
                    if not has_alternative_token:
                        raise
                    logger.warning(
                        f"Transient upstream error for token {token[:10]}..., "
                        f"trying next token (attempt {attempt + 1}/{max_token_retries}): {e}"
                    )
                    continue

                # Non-429 error: do not rotate tokens, propagate immediately.
                raise

        # All tokens were rate limited: raise the last error seen.
        if last_error:
            raise last_error
        raise AppException(
            message="No available tokens. Please try again later.",
            error_type=ErrorType.RATE_LIMIT.value,
            code="rate_limit_exceeded",
            status_code=429,
        )
504
+
505
+
506
class StreamProcessor(proc_base.BaseProcessor):
    """Stream response processor.

    Converts Grok's upstream JSON-lines stream into OpenAI-style SSE chunks,
    handling <think> tagging, image progress, tag filtering, and incremental
    <tool_call> parsing across chunk boundaries.
    """

    def __init__(self, model: str, token: str = "", show_think: bool = None, tools: List[Dict[str, Any]] = None, tool_choice: Any = None):
        super().__init__(model, token)
        # Stream identity / bookkeeping.
        self.response_id: str = None
        self.fingerprint: str = ""
        self.rollout_id: str = ""
        # <think> block state.
        self.think_opened: bool = False
        self.image_think_active: bool = False
        self.role_sent: bool = False
        # Tag filtering: tool-usage cards get summarized instead of dropped.
        self.filter_tags = get_config("app.filter_tags")
        self.tool_usage_enabled = (
            "xai:tool_usage_card" in (self.filter_tags or [])
        )
        self._tool_usage_opened = False
        self._tool_usage_buffer = ""

        self.show_think = bool(show_think)
        self.tools = tools
        self.tool_choice = tool_choice
        # Incremental <tool_call> parser state (text vs inside-tool).
        self._tool_stream_enabled = bool(tools) and tool_choice != "none"
        self._tool_state = "text"
        self._tool_buffer = ""
        self._tool_partial = ""
        self._tool_calls_seen = False
        self._tool_call_index = 0

    def _with_tool_index(self, tool_call: Any) -> Any:
        """Assign a monotonically increasing ``index`` to a tool-call dict."""
        if not isinstance(tool_call, dict):
            return tool_call
        if tool_call.get("index") is None:
            # Copy before mutating so shared parser output stays untouched.
            tool_call = dict(tool_call)
            tool_call["index"] = self._tool_call_index
            self._tool_call_index += 1
        return tool_call

    def _filter_tool_card(self, token: str) -> str:
        """Replace <xai:tool_usage_card> blocks with one-line summaries.

        Cards may span multiple stream tokens; partial cards are buffered in
        ``_tool_usage_buffer`` until the closing tag arrives.
        """
        if not token or not self.tool_usage_enabled:
            return token

        output_parts: list[str] = []
        rest = token
        start_tag = "<xai:tool_usage_card"
        end_tag = "</xai:tool_usage_card>"

        while rest:
            if self._tool_usage_opened:
                # We are inside a card started in an earlier token.
                end_idx = rest.find(end_tag)
                if end_idx == -1:
                    self._tool_usage_buffer += rest
                    return "".join(output_parts)
                end_pos = end_idx + len(end_tag)
                self._tool_usage_buffer += rest[:end_pos]
                line = extract_tool_text(self._tool_usage_buffer, self.rollout_id)
                if line:
                    # Keep summaries on their own line.
                    if output_parts and not output_parts[-1].endswith("\n"):
                        output_parts[-1] += "\n"
                    output_parts.append(f"{line}\n")
                self._tool_usage_buffer = ""
                self._tool_usage_opened = False
                rest = rest[end_pos:]
                continue

            start_idx = rest.find(start_tag)
            if start_idx == -1:
                output_parts.append(rest)
                break

            if start_idx > 0:
                output_parts.append(rest[:start_idx])

            end_idx = rest.find(end_tag, start_idx)
            if end_idx == -1:
                # Card continues into the next token; buffer and wait.
                self._tool_usage_opened = True
                self._tool_usage_buffer = rest[start_idx:]
                break

            end_pos = end_idx + len(end_tag)
            raw_card = rest[start_idx:end_pos]
            line = extract_tool_text(raw_card, self.rollout_id)
            if line:
                if output_parts and not output_parts[-1].endswith("\n"):
                    output_parts[-1] += "\n"
                output_parts.append(f"{line}\n")
            rest = rest[end_pos:]

        return "".join(output_parts)

    def _filter_token(self, token: str) -> str:
        """Filter special tags in current token only."""
        if not token:
            return token

        if self.tool_usage_enabled:
            token = self._filter_tool_card(token)
            if not token:
                return ""

        if not self.filter_tags:
            return token

        # Drop any token that mentions a filtered tag (coarse per-token check;
        # tool_usage_card is handled above with proper summarization).
        for tag in self.filter_tags:
            if tag == "xai:tool_usage_card":
                continue
            if f"<{tag}" in token or f"</{tag}" in token:
                return ""

        return token

    def _suffix_prefix(self, text: str, tag: str) -> int:
        """Length of the longest proper prefix of ``tag`` that ends ``text``.

        Used to hold back characters that may be the start of a tag split
        across stream chunks.
        """
        if not text or not tag:
            return 0
        max_keep = min(len(text), len(tag) - 1)
        for keep in range(max_keep, 0, -1):
            if text.endswith(tag[:keep]):
                return keep
        return 0

    def _handle_tool_stream(self, chunk: str) -> list[tuple[str, Any]]:
        """Incrementally split a chunk into ("text", str) / ("tool", dict) events.

        Maintains a two-state machine (outside vs inside <tool_call>) plus a
        partial-tag holdback so tags split across chunks are handled.
        """
        events: list[tuple[str, Any]] = []
        if not chunk:
            return events

        start_tag = "<tool_call>"
        end_tag = "</tool_call>"
        data = f"{self._tool_partial}{chunk}"
        self._tool_partial = ""

        while data:
            if self._tool_state == "text":
                start_idx = data.find(start_tag)
                if start_idx == -1:
                    # Hold back any suffix that could begin a start tag.
                    keep = self._suffix_prefix(data, start_tag)
                    emit = data[:-keep] if keep else data
                    if emit:
                        events.append(("text", emit))
                    self._tool_partial = data[-keep:] if keep else ""
                    break

                before = data[:start_idx]
                if before:
                    events.append(("text", before))
                data = data[start_idx + len(start_tag) :]
                self._tool_state = "tool"
                continue

            end_idx = data.find(end_tag)
            if end_idx == -1:
                # Still inside the tool block; hold back a possible end-tag
                # prefix and buffer the rest.
                keep = self._suffix_prefix(data, end_tag)
                append = data[:-keep] if keep else data
                if append:
                    self._tool_buffer += append
                self._tool_partial = data[-keep:] if keep else ""
                break

            self._tool_buffer += data[:end_idx]
            data = data[end_idx + len(end_tag) :]
            tool_call = parse_tool_call_block(self._tool_buffer, self.tools)
            if tool_call:
                events.append(("tool", self._with_tool_index(tool_call)))
                self._tool_calls_seen = True
            self._tool_buffer = ""
            self._tool_state = "text"

        return events

    def _flush_tool_stream(self) -> list[tuple[str, Any]]:
        """Flush parser state at end of stream.

        An unterminated <tool_call> is parsed best-effort; if unparseable its
        raw text is re-emitted (with the opening tag restored) so nothing is
        silently dropped.
        """
        events: list[tuple[str, Any]] = []
        if self._tool_state == "text":
            if self._tool_partial:
                events.append(("text", self._tool_partial))
                self._tool_partial = ""
            return events

        raw = f"{self._tool_buffer}{self._tool_partial}"
        tool_call = parse_tool_call_block(raw, self.tools)
        if tool_call:
            events.append(("tool", self._with_tool_index(tool_call)))
            self._tool_calls_seen = True
        elif raw:
            events.append(("text", f"<tool_call>{raw}"))
        self._tool_buffer = ""
        self._tool_partial = ""
        self._tool_state = "text"
        return events

    def _sse(self, content: str = "", role: str = None, finish: str = None, tool_calls: list = None) -> str:
        """Build SSE response."""
        delta = {}
        if role:
            delta["role"] = role
            delta["content"] = ""
        elif tool_calls is not None:
            delta["tool_calls"] = tool_calls
        elif content:
            delta["content"] = content

        chunk = {
            "id": self.response_id or f"chatcmpl-{uuid.uuid4().hex[:24]}",
            "object": "chat.completion.chunk",
            "created": self.created,
            "model": self.model,
            "system_fingerprint": self.fingerprint,
            "choices": [
                {"index": 0, "delta": delta, "logprobs": None, "finish_reason": finish}
            ],
        }
        return f"data: {orjson.dumps(chunk).decode()}\n\n"

    async def process(self, response: AsyncIterable[bytes]) -> AsyncGenerator[str, None]:
        """Process stream response.

        Args:
            response: AsyncIterable[bytes], async iterable of bytes

        Returns:
            AsyncGenerator[str, None], async generator of strings (SSE lines)
        """
        idle_timeout = get_config("chat.stream_timeout")

        try:
            async for line in proc_base._with_idle_timeout(
                response, idle_timeout, self.model
            ):
                line = proc_base._normalize_line(line)
                if not line:
                    continue
                try:
                    data = orjson.loads(line)
                except orjson.JSONDecodeError:
                    # Skip non-JSON keepalive/noise lines.
                    continue

                resp = data.get("result", {}).get("response", {})
                is_thinking = bool(resp.get("isThinking"))
                # isThinking controls <think> tagging
                # when absent, treat as False

                if (llm := resp.get("llmInfo")) and not self.fingerprint:
                    self.fingerprint = llm.get("modelHash", "")
                if rid := resp.get("responseId"):
                    self.response_id = rid
                if rid := resp.get("rolloutId"):
                    self.rollout_id = str(rid)

                # Emit the assistant role delta exactly once, first.
                if not self.role_sent:
                    yield self._sse(role="assistant")
                    self.role_sent = True

                # Image-generation progress is surfaced inside a <think> block.
                if img := resp.get("streamingImageGenerationResponse"):
                    if not self.show_think:
                        continue
                    self.image_think_active = True
                    if not self.think_opened:
                        yield self._sse("<think>\n")
                        self.think_opened = True
                    idx = img.get("imageIndex", 0) + 1
                    progress = img.get("progress", 0)
                    yield self._sse(
                        f"正在生成第{idx}张图片中,当前进度{progress}%\n"
                    )
                    continue

                # Final model response: close image-think, render images.
                if mr := resp.get("modelResponse"):
                    if self.image_think_active and self.think_opened:
                        yield self._sse("\n</think>\n")
                        self.think_opened = False
                        self.image_think_active = False
                    for url in proc_base._collect_images(mr):
                        parts = url.split("/")
                        img_id = parts[-2] if len(parts) >= 2 else "image"
                        dl_service = self._get_dl()
                        rendered = await dl_service.render_image(
                            url, self.token, img_id
                        )
                        yield self._sse(f"{rendered}\n")

                    if (
                        (meta := mr.get("metadata", {}))
                        .get("llm_info", {})
                        .get("modelHash")
                    ):
                        self.fingerprint = meta["llm_info"]["modelHash"]
                    continue

                # Card attachments become inline markdown images.
                if card := resp.get("cardAttachment"):
                    json_data = card.get("jsonData")
                    if isinstance(json_data, str) and json_data.strip():
                        try:
                            card_data = orjson.loads(json_data)
                        except orjson.JSONDecodeError:
                            card_data = None
                        if isinstance(card_data, dict):
                            image = card_data.get("image") or {}
                            original = image.get("original")
                            title = image.get("title") or ""
                            if original:
                                title_safe = title.replace("\n", " ").strip()
                                if title_safe:
                                    yield self._sse(f"![{title_safe}]({original})\n")
                                else:
                                    yield self._sse(f"![image]({original})\n")
                    continue

                # Plain text tokens: filter tags, manage <think>, route tools.
                if (token := resp.get("token")) is not None:
                    if not token:
                        continue
                    filtered = self._filter_token(token)
                    if not filtered:
                        continue
                    in_think = is_thinking or self.image_think_active
                    if in_think:
                        if not self.show_think:
                            continue
                        if not self.think_opened:
                            yield self._sse("<think>\n")
                            self.think_opened = True
                    else:
                        if self.think_opened:
                            yield self._sse("\n</think>\n")
                            self.think_opened = False

                    if in_think:
                        yield self._sse(filtered)
                        continue

                    if self._tool_stream_enabled:
                        for kind, payload in self._handle_tool_stream(filtered):
                            if kind == "text":
                                yield self._sse(payload)
                            elif kind == "tool":
                                yield self._sse(tool_calls=[payload])
                        continue

                    yield self._sse(filtered)

            # Stream ended: close any open think block, flush tool parser,
            # then emit the finish chunk and [DONE].
            if self.think_opened:
                yield self._sse("</think>\n")

            if self._tool_stream_enabled:
                for kind, payload in self._flush_tool_stream():
                    if kind == "text":
                        yield self._sse(payload)
                    elif kind == "tool":
                        yield self._sse(tool_calls=[payload])
                finish_reason = "tool_calls" if self._tool_calls_seen else "stop"
                yield self._sse(finish=finish_reason)
            else:
                yield self._sse(finish="stop")

            yield "data: [DONE]\n\n"
        except asyncio.CancelledError:
            # NOTE(review): cancellation is swallowed (generator just ends)
            # while CollectProcessor re-raises — confirm this is intentional.
            logger.debug("Stream cancelled by client", extra={"model": self.model})
        except StreamIdleTimeoutError as e:
            raise UpstreamException(
                message=f"Stream idle timeout after {e.idle_seconds}s",
                status_code=504,
                details={
                    "error": str(e),
                    "type": "stream_idle_timeout",
                    "idle_seconds": e.idle_seconds,
                },
            )
        except RequestsError as e:
            if proc_base._is_http2_error(e):
                logger.warning(f"HTTP/2 stream error: {e}", extra={"model": self.model})
                raise UpstreamException(
                    message="Upstream connection closed unexpectedly",
                    status_code=502,
                    details={"error": str(e), "type": "http2_stream_error"},
                )
            logger.error(f"Stream request error: {e}", extra={"model": self.model})
            raise UpstreamException(
                message=f"Upstream request failed: {e}",
                status_code=502,
                details={"error": str(e)},
            )
        except Exception as e:
            logger.error(
                f"Stream processing error: {e}",
                extra={"model": self.model, "error_type": type(e).__name__},
            )
            raise
        finally:
            await self.close()
891
+
892
+
893
class CollectProcessor(proc_base.BaseProcessor):
    """Non-stream response processor.

    Drains the upstream stream, keeps the final modelResponse payload, renders
    card/image attachments into markdown, and returns a completed OpenAI
    chat.completion dict.
    """

    def __init__(self, model: str, token: str = "", tools: List[Dict[str, Any]] = None, tool_choice: Any = None):
        super().__init__(model, token)
        self.filter_tags = get_config("app.filter_tags")
        self.tools = tools
        self.tool_choice = tool_choice

    def _filter_content(self, content: str) -> str:
        """Filter special tags in content.

        Tool-usage cards are replaced with one-line summaries (see
        ``extract_tool_text``); all other configured tags are stripped.
        """
        if not content or not self.filter_tags:
            return content

        result = content
        if "xai:tool_usage_card" in self.filter_tags:
            # A rolloutId tag anywhere in the content prefixes the summaries.
            rollout_id = ""
            rollout_match = re.search(
                r"<rolloutId>(.*?)</rolloutId>", result, flags=re.DOTALL
            )
            if rollout_match:
                rollout_id = rollout_match.group(1).strip()

            result = re.sub(
                r"<xai:tool_usage_card[^>]*>.*?</xai:tool_usage_card>",
                lambda match: (
                    f"{extract_tool_text(match.group(0), rollout_id)}\n"
                    if extract_tool_text(match.group(0), rollout_id)
                    else ""
                ),
                result,
                flags=re.DOTALL,
            )

        for tag in self.filter_tags:
            if tag == "xai:tool_usage_card":
                continue
            # Matches both paired and self-closing forms of the tag.
            pattern = rf"<{re.escape(tag)}[^>]*>.*?</{re.escape(tag)}>|<{re.escape(tag)}[^>]*/>"
            result = re.sub(pattern, "", result, flags=re.DOTALL)

        return result

    async def process(self, response: AsyncIterable[bytes]) -> dict[str, Any]:
        """Process and collect full response.

        Returns:
            An OpenAI ``chat.completion`` dict; usage counters are zeroed
            (upstream does not report token counts here).
        """
        response_id = ""
        fingerprint = ""
        content = ""
        idle_timeout = get_config("chat.stream_timeout")

        try:
            async for line in proc_base._with_idle_timeout(
                response, idle_timeout, self.model
            ):
                line = proc_base._normalize_line(line)
                if not line:
                    continue
                try:
                    data = orjson.loads(line)
                except orjson.JSONDecodeError:
                    continue

                resp = data.get("result", {}).get("response", {})

                if (llm := resp.get("llmInfo")) and not fingerprint:
                    fingerprint = llm.get("modelHash", "")

                if mr := resp.get("modelResponse"):
                    response_id = mr.get("responseId", "")
                    content = mr.get("message", "")

                    # Index card attachments by id -> (title, image URL).
                    card_map: dict[str, tuple[str, str]] = {}
                    for raw in mr.get("cardAttachmentsJson") or []:
                        if not isinstance(raw, str) or not raw.strip():
                            continue
                        try:
                            card_data = orjson.loads(raw)
                        except orjson.JSONDecodeError:
                            continue
                        if not isinstance(card_data, dict):
                            continue
                        card_id = card_data.get("id")
                        image = card_data.get("image") or {}
                        original = image.get("original")
                        if not card_id or not original:
                            continue
                        title = image.get("title") or ""
                        card_map[card_id] = (title, original)

                    if content and card_map:
                        def _render_card(match: re.Match) -> str:
                            # Replace <grok:render card_id=...> with markdown.
                            # match.start() indexes into the pre-substitution
                            # `content`, which re.sub scans, so the preceding
                            # character check is valid.
                            card_id = match.group(1)
                            item = card_map.get(card_id)
                            if not item:
                                return ""
                            title, original = item
                            title_safe = title.replace("\n", " ").strip() or "image"
                            prefix = ""
                            if match.start() > 0:
                                prev = content[match.start() - 1]
                                if prev not in ("\n", "\r"):
                                    prefix = "\n"
                            return f"{prefix}![{title_safe}]({original})"

                        content = re.sub(
                            r'<grok:render[^>]*card_id="([^"]+)"[^>]*>.*?</grok:render>',
                            _render_card,
                            content,
                            flags=re.DOTALL,
                        )

                    # Append any generated images as rendered markdown.
                    if urls := proc_base._collect_images(mr):
                        content += "\n"
                        for url in urls:
                            parts = url.split("/")
                            img_id = parts[-2] if len(parts) >= 2 else "image"
                            dl_service = self._get_dl()
                            rendered = await dl_service.render_image(
                                url, self.token, img_id
                            )
                            content += f"{rendered}\n"

                    if (
                        (meta := mr.get("metadata", {}))
                        .get("llm_info", {})
                        .get("modelHash")
                    ):
                        fingerprint = meta["llm_info"]["modelHash"]

        except asyncio.CancelledError:
            logger.debug("Collect cancelled by client", extra={"model": self.model})
            raise
        except StreamIdleTimeoutError as e:
            logger.warning(f"Collect idle timeout: {e}", extra={"model": self.model})
            raise UpstreamException(
                message=f"Collect stream idle timeout after {e.idle_seconds}s",
                details={
                    "error": str(e),
                    "type": "stream_idle_timeout",
                    "idle_seconds": e.idle_seconds,
                    "status": 504,
                },
            )
        except RequestsError as e:
            if proc_base._is_http2_error(e):
                logger.warning(
                    f"HTTP/2 stream error in collect: {e}", extra={"model": self.model}
                )
                raise UpstreamException(
                    message="Upstream connection closed unexpectedly",
                    details={"error": str(e), "type": "http2_stream_error", "status": 502},
                )
            logger.error(f"Collect request error: {e}", extra={"model": self.model})
            raise UpstreamException(
                message=f"Upstream request failed: {e}",
                details={"error": str(e), "status": 502},
            )
        except Exception as e:
            logger.error(
                f"Collect processing error: {e}",
                extra={"model": self.model, "error_type": type(e).__name__},
            )
            raise
        finally:
            await self.close()

        content = self._filter_content(content)

        # Parse for tool calls if tools were provided.
        finish_reason = "stop"
        tool_calls_result = None
        if self.tools and self.tool_choice != "none":
            text_content, tool_calls_list = parse_tool_calls(content, self.tools)
            if tool_calls_list:
                tool_calls_result = tool_calls_list
                content = text_content  # May be None
                finish_reason = "tool_calls"

        message_obj = {
            "role": "assistant",
            "content": content,
            "refusal": None,
            "annotations": [],
        }
        if tool_calls_result:
            message_obj["tool_calls"] = tool_calls_result

        return {
            "id": response_id,
            "object": "chat.completion",
            "created": self.created,
            "model": self.model,
            "system_fingerprint": fingerprint,
            "choices": [
                {
                    "index": 0,
                    "message": message_obj,
                    "finish_reason": finish_reason,
                }
            ],
            "usage": {
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0,
                "prompt_tokens_details": {
                    "cached_tokens": 0,
                    "text_tokens": 0,
                    "audio_tokens": 0,
                    "image_tokens": 0,
                },
                "completion_tokens_details": {
                    "text_tokens": 0,
                    "audio_tokens": 0,
                    "reasoning_tokens": 0,
                },
            },
        }
1109
+
1110
+
1111
# Public API of this module.
__all__ = [
    "GrokChatService",
    "MessageExtractor",
    "ChatService",
]
app/services/grok/services/image.py ADDED
@@ -0,0 +1,794 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Grok image services.
3
+ """
4
+
5
+ import asyncio
6
+ import base64
7
+ import math
8
+ import time
9
+ from dataclasses import dataclass
10
+ from pathlib import Path
11
+ from typing import Any, AsyncGenerator, AsyncIterable, Dict, List, Optional, Union
12
+
13
+ import orjson
14
+
15
+ from app.core.config import get_config
16
+ from app.core.logger import logger
17
+ from app.core.storage import DATA_DIR
18
+ from app.core.exceptions import AppException, ErrorType, UpstreamException
19
+ from app.services.grok.utils.process import BaseProcessor
20
+ from app.services.grok.utils.retry import pick_token, rate_limited
21
+ from app.services.grok.utils.response import make_response_id, make_chat_chunk, wrap_image_content
22
+ from app.services.grok.utils.stream import wrap_stream_with_usage
23
+ from app.services.token import EffortType
24
+ from app.services.reverse.ws_imagine import ImagineWebSocketReverse
25
+
26
+
27
+ image_service = ImagineWebSocketReverse()
28
+
29
+
30
@dataclass
class ImageGenerationResult:
    """Outcome of an image-generation call.

    When ``stream`` is True, ``data`` is an async generator yielding SSE
    strings; otherwise it is the list of rendered images (URLs or base64
    payloads, depending on the requested response format).
    """

    stream: bool
    data: Union[AsyncGenerator[str, None], List[str]]
    # Optional usage block to substitute into the final response payload.
    usage_override: Optional[dict] = None
35
+
36
+
37
class ImageGenerationService:
    """Image generation orchestration service.

    Routes prompts to the WebSocket imagine backend, handling:
    token selection/rotation on 429s, optional NSFW tag routing,
    streaming vs. collected output, and recovery attempts when fewer
    final images arrive than requested.
    """

    async def generate(
        self,
        *,
        token_mgr: Any,
        token: str,
        model_info: Any,
        prompt: str,
        n: int,
        response_format: str,
        size: str,
        aspect_ratio: str,
        stream: bool,
        enable_nsfw: Optional[bool] = None,
        chat_format: bool = False,
    ) -> ImageGenerationResult:
        """Generate ``n`` images for ``prompt``, retrying across tokens on rate limits.

        Returns a streaming result (SSE generator) when ``stream`` is True,
        otherwise a collected list of images. Raises AppException (429) when
        no token is available and the last upstream error (if any) otherwise.
        """
        max_token_retries = int(get_config("retry.max_retry") or 3)
        tried_tokens: set[str] = set()
        last_error: Optional[Exception] = None

        # resolve nsfw once for routing and upstream
        if enable_nsfw is None:
            enable_nsfw = bool(get_config("image.nsfw"))
        prefer_tags = {"nsfw"} if enable_nsfw else None

        if stream:

            async def _stream_retry() -> AsyncGenerator[str, None]:
                nonlocal last_error
                for attempt in range(max_token_retries):
                    # Honor the caller-preferred token only on the first attempt
                    # and only when no tag routing is required.
                    preferred = token if (attempt == 0 and not prefer_tags) else None
                    current_token = await pick_token(
                        token_mgr,
                        model_info.model_id,
                        tried_tokens,
                        preferred=preferred,
                        prefer_tags=prefer_tags,
                    )
                    if not current_token:
                        if last_error:
                            raise last_error
                        raise AppException(
                            message="No available tokens. Please try again later.",
                            error_type=ErrorType.RATE_LIMIT.value,
                            code="rate_limit_exceeded",
                            status_code=429,
                        )

                    tried_tokens.add(current_token)
                    yielded = False
                    try:
                        result = await self._stream_ws(
                            token_mgr=token_mgr,
                            token=current_token,
                            model_info=model_info,
                            prompt=prompt,
                            n=n,
                            response_format=response_format,
                            size=size,
                            aspect_ratio=aspect_ratio,
                            enable_nsfw=enable_nsfw,
                            chat_format=chat_format,
                        )
                        async for chunk in result.data:
                            yielded = True
                            yield chunk
                        return
                    except UpstreamException as e:
                        last_error = e
                        if rate_limited(e):
                            # Once output has been yielded we cannot retry
                            # transparently -- the client already saw data.
                            if yielded:
                                raise
                            await token_mgr.mark_rate_limited(current_token)
                            logger.warning(
                                f"Token {current_token[:10]}... rate limited (429), "
                                f"trying next token (attempt {attempt + 1}/{max_token_retries})"
                            )
                            continue
                        raise

                if last_error:
                    raise last_error
                raise AppException(
                    message="No available tokens. Please try again later.",
                    error_type=ErrorType.RATE_LIMIT.value,
                    code="rate_limit_exceeded",
                    status_code=429,
                )

            return ImageGenerationResult(stream=True, data=_stream_retry())

        # Non-streaming path: same token-rotation loop, collected output.
        for attempt in range(max_token_retries):
            preferred = token if (attempt == 0 and not prefer_tags) else None
            current_token = await pick_token(
                token_mgr,
                model_info.model_id,
                tried_tokens,
                preferred=preferred,
                prefer_tags=prefer_tags,
            )
            if not current_token:
                if last_error:
                    raise last_error
                raise AppException(
                    message="No available tokens. Please try again later.",
                    error_type=ErrorType.RATE_LIMIT.value,
                    code="rate_limit_exceeded",
                    status_code=429,
                )

            tried_tokens.add(current_token)
            try:
                return await self._collect_ws(
                    token_mgr=token_mgr,
                    token=current_token,
                    model_info=model_info,
                    tried_tokens=tried_tokens,
                    prompt=prompt,
                    n=n,
                    response_format=response_format,
                    aspect_ratio=aspect_ratio,
                    enable_nsfw=enable_nsfw,
                )
            except UpstreamException as e:
                last_error = e
                if rate_limited(e):
                    await token_mgr.mark_rate_limited(current_token)
                    logger.warning(
                        f"Token {current_token[:10]}... rate limited (429), "
                        f"trying next token (attempt {attempt + 1}/{max_token_retries})"
                    )
                    continue
                raise

        if last_error:
            raise last_error
        raise AppException(
            message="No available tokens. Please try again later.",
            error_type=ErrorType.RATE_LIMIT.value,
            code="rate_limit_exceeded",
            status_code=429,
        )

    async def _stream_ws(
        self,
        *,
        token_mgr: Any,
        token: str,
        model_info: Any,
        prompt: str,
        n: int,
        response_format: str,
        size: str,
        aspect_ratio: str,
        enable_nsfw: Optional[bool] = None,
        chat_format: bool = False,
    ) -> ImageGenerationResult:
        """Open one upstream WS stream and wrap it with usage accounting."""
        if enable_nsfw is None:
            enable_nsfw = bool(get_config("image.nsfw"))
        # +1 so at least one attempt happens even when config is 0; cap at 10.
        stream_retries = int(get_config("image.blocked_parallel_attempts") or 5) + 1
        stream_retries = max(1, min(stream_retries, 10))
        upstream = image_service.stream(
            token=token,
            prompt=prompt,
            aspect_ratio=aspect_ratio,
            n=n,
            enable_nsfw=enable_nsfw,
            max_retries=stream_retries,
        )
        processor = ImageWSStreamProcessor(
            model_info.model_id,
            token,
            n=n,
            response_format=response_format,
            size=size,
            chat_format=chat_format,
        )
        stream = wrap_stream_with_usage(
            processor.process(upstream),
            token_mgr,
            token,
            model_info.model_id,
        )
        return ImageGenerationResult(stream=True, data=stream)

    async def _collect_ws(
        self,
        *,
        token_mgr: Any,
        token: str,
        model_info: Any,
        tried_tokens: set[str],
        prompt: str,
        n: int,
        response_format: str,
        aspect_ratio: str,
        enable_nsfw: Optional[bool] = None,
    ) -> ImageGenerationResult:
        """Collect ``n`` final images via batched WS calls, with recovery.

        Fans out parallel batches (upstream yields ~6 finals per call),
        de-duplicates results, and runs extra recovery attempts -- optionally
        on distinct tokens -- when fewer finals than requested arrive.
        """
        if enable_nsfw is None:
            enable_nsfw = bool(get_config("image.nsfw"))
        all_images: List[str] = []
        seen = set()
        # Upstream typically returns up to 6 final images per call.
        expected_per_call = 6
        calls_needed = max(1, int(math.ceil(n / expected_per_call)))
        calls_needed = min(calls_needed, n)

        async def _fetch_batch(call_target: int, call_token: str):
            stream_retries = int(get_config("image.blocked_parallel_attempts") or 5) + 1
            stream_retries = max(1, min(stream_retries, 10))
            upstream = image_service.stream(
                token=call_token,
                prompt=prompt,
                aspect_ratio=aspect_ratio,
                n=call_target,
                enable_nsfw=enable_nsfw,
                max_retries=stream_retries,
            )
            processor = ImageWSCollectProcessor(
                model_info.model_id,
                token,
                n=call_target,
                response_format=response_format,
            )
            return await processor.process(upstream)

        tasks = []
        for i in range(calls_needed):
            remaining = n - (i * expected_per_call)
            call_target = min(expected_per_call, remaining)
            tasks.append(_fetch_batch(call_target, token))

        results = await asyncio.gather(*tasks, return_exceptions=True)
        for batch in results:
            if isinstance(batch, Exception):
                logger.warning(f"WS batch failed: {batch}")
                continue
            for img in batch:
                if img not in seen:
                    seen.add(img)
                    all_images.append(img)
                if len(all_images) >= n:
                    break
            if len(all_images) >= n:
                break

        # If upstream likely blocked/reviewed some images, run extra parallel attempts
        # and only keep valid finals selected by ws_imagine classification.
        if len(all_images) < n:
            remaining = n - len(all_images)
            extra_attempts = int(get_config("image.blocked_parallel_attempts") or 5)
            extra_attempts = max(0, min(extra_attempts, 10))
            parallel_enabled = bool(get_config("image.blocked_parallel_enabled", True))
            if extra_attempts > 0:
                logger.warning(
                    f"Image finals insufficient ({len(all_images)}/{n}), running "
                    f"{extra_attempts} recovery attempts for remaining={remaining}, "
                    f"parallel_enabled={parallel_enabled}"
                )
                extra_tasks = []
                if parallel_enabled:
                    # Prefer distinct, not-yet-tried tokens for recovery runs.
                    recovery_tried = set(tried_tokens)
                    recovery_tokens: List[str] = []
                    for _ in range(extra_attempts):
                        recovery_token = await pick_token(
                            token_mgr,
                            model_info.model_id,
                            recovery_tried,
                        )
                        if not recovery_token:
                            break
                        recovery_tried.add(recovery_token)
                        recovery_tokens.append(recovery_token)

                    if recovery_tokens:
                        logger.info(
                            f"Recovery using {len(recovery_tokens)} distinct tokens"
                        )
                    for recovery_token in recovery_tokens:
                        extra_tasks.append(
                            _fetch_batch(min(expected_per_call, remaining), recovery_token)
                        )
                else:
                    # Sequential-token mode: reuse the current token.
                    extra_tasks = [
                        _fetch_batch(min(expected_per_call, remaining), token)
                        for _ in range(extra_attempts)
                    ]

                if not extra_tasks:
                    logger.warning("No tokens available for recovery attempts")
                    extra_results = []
                else:
                    extra_results = await asyncio.gather(*extra_tasks, return_exceptions=True)
                for batch in extra_results:
                    if isinstance(batch, Exception):
                        logger.warning(f"WS recovery batch failed: {batch}")
                        continue
                    for img in batch:
                        if img not in seen:
                            seen.add(img)
                            all_images.append(img)
                        if len(all_images) >= n:
                            break
                    if len(all_images) >= n:
                        break
                logger.info(
                    f"Image recovery attempts completed: finals={len(all_images)}/{n}, "
                    f"attempts={extra_attempts}"
                )

        if len(all_images) < n:
            logger.error(
                f"Image generation failed after recovery attempts: finals={len(all_images)}/{n}, "
                f"blocked_parallel_attempts={int(get_config('image.blocked_parallel_attempts') or 5)}"
            )
            raise UpstreamException(
                "Image generation blocked or no valid final image",
                details={
                    "error_code": "blocked_no_final_image",
                    "final_images": len(all_images),
                    "requested": n,
                },
            )

        # Usage accounting is best-effort; a failure must not fail the request.
        try:
            await token_mgr.consume(token, self._get_effort(model_info))
        except Exception as e:
            logger.warning(f"Failed to consume token: {e}")

        selected = self._select_images(all_images, n)
        usage_override = {
            "total_tokens": 0,
            "input_tokens": 0,
            "output_tokens": 0,
            "input_tokens_details": {"text_tokens": 0, "image_tokens": 0},
        }
        return ImageGenerationResult(
            stream=False, data=selected, usage_override=usage_override
        )

    @staticmethod
    def _get_effort(model_info: Any) -> EffortType:
        """Map the model's cost tier to a token-consumption effort level."""
        return (
            EffortType.HIGH
            if (model_info and model_info.cost.value == "high")
            else EffortType.LOW
        )

    @staticmethod
    def _select_images(images: List[str], n: int) -> List[str]:
        """Return exactly ``n`` entries, padding shortfalls with "error"."""
        if len(images) >= n:
            return images[:n]
        selected = images.copy()
        while len(selected) < n:
            selected.append("error")
        return selected
394
+
395
+
396
class ImageWSBaseProcessor(BaseProcessor):
    """Shared helpers for WebSocket image processors.

    Normalizes the requested response format, resolves the field name used
    in SSE payloads, and provides utilities to persist base64 blobs under
    ``DATA_DIR/tmp/image`` and expose them as downloadable file URLs.
    """

    def __init__(self, model: str, token: str = "", response_format: str = "b64_json"):
        # "base64" is accepted as a caller-facing alias of "b64_json".
        if response_format == "base64":
            response_format = "b64_json"
        super().__init__(model, token)
        self.response_format = response_format
        # Only two formats remain after the alias normalization above, so the
        # former `elif response_format == "base64"` branch was unreachable.
        self.response_field = "url" if response_format == "url" else "b64_json"
        self._image_dir: Optional[Path] = None

    def _ensure_image_dir(self) -> Path:
        """Create (once) and return the local image cache directory."""
        if self._image_dir is None:
            base_dir = DATA_DIR / "tmp" / "image"
            base_dir.mkdir(parents=True, exist_ok=True)
            self._image_dir = base_dir
        return self._image_dir

    def _strip_base64(self, blob: str) -> str:
        """Drop a leading ``data:...;base64,`` header from *blob*, if present."""
        if not blob:
            return ""
        if "," in blob and "base64" in blob.split(",", 1)[0]:
            return blob.split(",", 1)[1]
        return blob

    def _guess_ext(self, blob: str) -> Optional[str]:
        """Guess the image extension from a data-URI header or base64 magic bytes."""
        if not blob:
            return None
        header = ""
        data = blob
        if "," in blob and "base64" in blob.split(",", 1)[0]:
            header, data = blob.split(",", 1)
            header = header.lower()
        if "image/png" in header:
            return "png"
        if "image/jpeg" in header or "image/jpg" in header:
            return "jpg"
        # Base64 of the PNG signature and of the JPEG SOI marker, respectively.
        if data.startswith("iVBORw0KGgo"):
            return "png"
        if data.startswith("/9j/"):
            return "jpg"
        return None

    def _filename(self, image_id: str, is_final: bool, ext: Optional[str] = None) -> str:
        """Build the on-disk filename; default extension depends on finality."""
        if ext:
            ext = ext.lower()
            if ext == "jpeg":
                ext = "jpg"
        if not ext:
            ext = "jpg" if is_final else "png"
        return f"{image_id}.{ext}"

    def _build_file_url(self, filename: str) -> str:
        """Build the (absolute when app_url is configured) download URL.

        Fixed: the path previously interpolated a literal placeholder instead
        of the ``filename`` argument, producing broken URLs.
        """
        app_url = get_config("app.app_url")
        if app_url:
            return f"{app_url.rstrip('/')}/v1/files/image/{filename}"
        return f"/v1/files/image/{filename}"

    async def _save_blob(
        self, image_id: str, blob: str, is_final: bool, ext: Optional[str] = None
    ) -> str:
        """Decode and persist a base64 blob; return its file URL ('' when empty)."""
        data = self._strip_base64(blob)
        if not data:
            return ""
        image_dir = self._ensure_image_dir()
        ext = ext or self._guess_ext(blob)
        filename = self._filename(image_id, is_final, ext=ext)
        filepath = image_dir / filename

        def _write_file():
            with open(filepath, "wb") as f:
                f.write(base64.b64decode(data))

        # Offload blocking file I/O so the event loop is not stalled.
        await asyncio.to_thread(_write_file)
        return self._build_file_url(filename)

    def _pick_best(self, existing: Optional[Dict], incoming: Dict) -> Dict:
        """Choose between two frames of the same image: finals win, then size."""
        if not existing:
            return incoming
        if incoming.get("is_final") and not existing.get("is_final"):
            return incoming
        if existing.get("is_final") and not incoming.get("is_final"):
            return existing
        if incoming.get("blob_size", 0) > existing.get("blob_size", 0):
            return incoming
        return existing

    async def _to_output(self, image_id: str, item: Dict) -> str:
        """Render one image item per the configured response format ('' on error)."""
        try:
            if self.response_format == "url":
                return await self._save_blob(
                    image_id,
                    item.get("blob", ""),
                    item.get("is_final", False),
                    ext=item.get("ext"),
                )
            return self._strip_base64(item.get("blob", ""))
        except Exception as e:
            logger.warning(f"Image output failed: {e}")
            return ""
501
+
502
+
503
class ImageWSStreamProcessor(ImageWSBaseProcessor):
    """WebSocket image stream processor.

    Translates upstream WS events into SSE output: partial/medium frames are
    emitted progressively (except in chat format, which exposes only finals),
    and the best final frame per slot is emitted once the stream ends.
    """

    def __init__(
        self,
        model: str,
        token: str = "",
        n: int = 1,
        response_format: str = "b64_json",
        size: str = "1024x1024",
        chat_format: bool = False,
    ):
        super().__init__(model, token, response_format)
        self.n = n
        self.size = size
        # When True, emit OpenAI chat.completion.chunk events instead of
        # image_generation.* events.
        self.chat_format = chat_format
        # For n == 1: the first image_id seen becomes the single target slot.
        self._target_id: Optional[str] = None
        # image_id -> output slot index (first-come, capped at n).
        self._index_map: Dict[str, int] = {}
        # image_id -> last partial_image_index emitted.
        self._partial_map: Dict[str, int] = {}
        # image_ids whose first (initial) partial has been sent.
        self._initial_sent: set[str] = set()
        self._id_generated: bool = False
        self._response_id: str = ""

    def _assign_index(self, image_id: str) -> Optional[int]:
        """Assign a stable slot index to ``image_id``; None once n slots are used."""
        if image_id in self._index_map:
            return self._index_map[image_id]
        if len(self._index_map) >= self.n:
            return None
        self._index_map[image_id] = len(self._index_map)
        return self._index_map[image_id]

    def _sse(self, event: str, data: dict) -> str:
        """Serialize one named SSE event."""
        return f"event: {event}\ndata: {orjson.dumps(data).decode()}\n\n"

    async def process(self, response: AsyncIterable[dict]) -> AsyncGenerator[str, None]:
        """Consume upstream WS events and yield SSE strings.

        Raises UpstreamException on upstream rate limits so callers can
        rotate tokens; other upstream errors are surfaced as an SSE "error"
        event followed by stream termination.
        """
        images: Dict[str, Dict] = {}
        emitted_chat_chunk = False

        async for item in response:
            if item.get("type") == "error":
                message = item.get("error") or "Upstream error"
                code = item.get("error_code") or "upstream_error"
                status = item.get("status")
                # Rate limits must propagate as exceptions for token rotation.
                if code == "rate_limit_exceeded" or status == 429:
                    raise UpstreamException(message, details=item)
                yield self._sse(
                    "error",
                    {
                        "error": {
                            "message": message,
                            "type": "server_error",
                            "code": code,
                        }
                    },
                )
                return
            if item.get("type") != "image":
                continue

            image_id = item.get("image_id")
            if not image_id:
                continue

            if self.n == 1:
                if self._target_id is None:
                    self._target_id = image_id
                index = 0 if image_id == self._target_id else None
            else:
                index = self._assign_index(image_id)

            # Track the best frame seen per image for the final emission pass.
            images[image_id] = self._pick_best(images.get(image_id), item)

            if index is None:
                continue

            if item.get("stage") != "final":
                # Chat Completions image stream should only expose final results.
                if self.chat_format:
                    continue
                if image_id not in self._initial_sent:
                    self._initial_sent.add(image_id)
                    stage = item.get("stage") or "preview"
                    if stage == "medium":
                        partial_index = 1
                        self._partial_map[image_id] = 1
                    else:
                        partial_index = 0
                        self._partial_map[image_id] = 0
                else:
                    stage = item.get("stage") or "partial"
                    # Skip repeated previews after the first frame was sent.
                    if stage == "preview":
                        continue
                    partial_index = self._partial_map.get(image_id, 0)
                    if stage == "medium":
                        partial_index = max(partial_index, 1)
                    self._partial_map[image_id] = partial_index

                if self.response_format == "url":
                    partial_id = f"{image_id}-{stage}-{partial_index}"
                    partial_out = await self._save_blob(
                        partial_id,
                        item.get("blob", ""),
                        False,
                        ext=item.get("ext"),
                    )
                else:
                    partial_out = self._strip_base64(item.get("blob", ""))

                if self.chat_format and partial_out:
                    partial_out = wrap_image_content(partial_out, self.response_format)

                if not partial_out:
                    continue

                if self.chat_format:
                    # OpenAI ChatCompletion chunk format for partial
                    if not self._id_generated:
                        self._response_id = make_response_id()
                        self._id_generated = True
                    emitted_chat_chunk = True
                    yield self._sse(
                        "chat.completion.chunk",
                        make_chat_chunk(
                            self._response_id,
                            self.model,
                            partial_out,
                            index=index,
                        ),
                    )
                else:
                    # Original image_generation format
                    yield self._sse(
                        "image_generation.partial_image",
                        {
                            "type": "image_generation.partial_image",
                            self.response_field: partial_out,
                            "created_at": int(time.time()),
                            "size": self.size,
                            "index": index,
                            "partial_image_index": partial_index,
                            "image_id": image_id,
                            "stage": stage,
                        },
                    )

        # Stream ended: select the frame(s) to emit as final output.
        if self.n == 1:
            target_item = images.get(self._target_id) if self._target_id else None
            if target_item and target_item.get("is_final", False):
                selected = [(self._target_id, target_item)]
            elif images:
                # No final for the target: fall back to the best frame overall
                # (finals first, then largest blob).
                selected = [
                    max(
                        images.items(),
                        key=lambda x: (
                            x[1].get("is_final", False),
                            x[1].get("blob_size", 0),
                        ),
                    )
                ]
            else:
                selected = []
        else:
            selected = [
                (image_id, images[image_id])
                for image_id in self._index_map
                if image_id in images and images[image_id].get("is_final", False)
            ]

        for image_id, item in selected:
            if self.response_format == "url":
                final_image_id = image_id
                # Keep original imagine image name for imagine chat stream output.
                if self.model != "grok-imagine-1.0-fast":
                    final_image_id = f"{image_id}-final"
                output = await self._save_blob(
                    final_image_id,
                    item.get("blob", ""),
                    item.get("is_final", False),
                    ext=item.get("ext"),
                )
                if self.chat_format and output:
                    output = wrap_image_content(output, self.response_format)
            else:
                output = await self._to_output(image_id, item)
                if self.chat_format and output:
                    output = wrap_image_content(output, self.response_format)

            if not output:
                continue

            if self.n == 1:
                index = 0
            else:
                index = self._index_map.get(image_id, 0)

            if not self._id_generated:
                self._response_id = make_response_id()
                self._id_generated = True

            if self.chat_format:
                # OpenAI ChatCompletion chunk format
                emitted_chat_chunk = True
                yield self._sse(
                    "chat.completion.chunk",
                    make_chat_chunk(
                        self._response_id,
                        self.model,
                        output,
                        index=index,
                        is_final=True,
                    ),
                )
            else:
                # Original image_generation format
                yield self._sse(
                    "image_generation.completed",
                    {
                        "type": "image_generation.completed",
                        self.response_field: output,
                        "created_at": int(time.time()),
                        "size": self.size,
                        "index": index,
                        "image_id": image_id,
                        "stage": "final",
                        "usage": {
                            "total_tokens": 0,
                            "input_tokens": 0,
                            "output_tokens": 0,
                            "input_tokens_details": {"text_tokens": 0, "image_tokens": 0},
                        },
                    },
                )

        if self.chat_format:
            if not self._id_generated:
                self._response_id = make_response_id()
                self._id_generated = True
            # Guarantee at least one (possibly empty) final chunk in chat mode.
            if not emitted_chat_chunk:
                yield self._sse(
                    "chat.completion.chunk",
                    make_chat_chunk(
                        self._response_id,
                        self.model,
                        "",
                        index=0,
                        is_final=True,
                    ),
                )
            yield "data: [DONE]\n\n"
752
+
753
+
754
class ImageWSCollectProcessor(ImageWSBaseProcessor):
    """Aggregates a WebSocket image stream into a list of final outputs."""

    def __init__(
        self, model: str, token: str = "", n: int = 1, response_format: str = "b64_json"
    ):
        super().__init__(model, token, response_format)
        # Maximum number of final images to return (falsy disables the cap).
        self.n = n

    async def process(self, response: AsyncIterable[dict]) -> List[str]:
        """Drain the upstream stream and return up to ``n`` rendered finals.

        Raises UpstreamException as soon as an error event arrives.
        """
        best_by_id: Dict[str, Dict] = {}

        async for event in response:
            kind = event.get("type")
            if kind == "error":
                raise UpstreamException(
                    event.get("error") or "Upstream error", details=event
                )
            if kind != "image":
                continue
            image_id = event.get("image_id")
            if image_id:
                # Keep only the best frame observed per image id.
                best_by_id[image_id] = self._pick_best(
                    best_by_id.get(image_id), event
                )

        # Final frames only, largest blobs first, capped at n.
        finals = [it for it in best_by_id.values() if it.get("is_final", False)]
        finals.sort(key=lambda it: it.get("blob_size", 0), reverse=True)
        if self.n:
            del finals[self.n:]

        outputs: List[str] = []
        for it in finals:
            rendered = await self._to_output(it.get("image_id", ""), it)
            if rendered:
                outputs.append(rendered)
        return outputs
792
+
793
+
794
+ __all__ = ["ImageGenerationService"]
app/services/grok/services/image_edit.py ADDED
@@ -0,0 +1,567 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Grok image edit service.
3
+ """
4
+
5
+ import asyncio
6
+ import os
7
+ import random
8
+ import re
9
+ import time
10
+ from dataclasses import dataclass
11
+ from typing import AsyncGenerator, AsyncIterable, List, Union, Any
12
+
13
+ import orjson
14
+ from curl_cffi.requests.errors import RequestsError
15
+
16
+ from app.core.config import get_config
17
+ from app.core.exceptions import (
18
+ AppException,
19
+ ErrorType,
20
+ UpstreamException,
21
+ StreamIdleTimeoutError,
22
+ )
23
+ from app.core.logger import logger
24
+ from app.services.grok.utils.process import (
25
+ BaseProcessor,
26
+ _with_idle_timeout,
27
+ _normalize_line,
28
+ _collect_images,
29
+ _is_http2_error,
30
+ )
31
+ from app.services.grok.utils.upload import UploadService
32
+ from app.services.grok.utils.retry import pick_token, rate_limited
33
+ from app.services.grok.utils.response import make_response_id, make_chat_chunk, wrap_image_content
34
+ from app.services.grok.services.chat import GrokChatService
35
+ from app.services.grok.services.video import VideoService
36
+ from app.services.grok.utils.stream import wrap_stream_with_usage
37
+ from app.services.token import EffortType
38
+
39
+
40
@dataclass
class ImageEditResult:
    """Outcome of an image-edit call.

    ``data`` is an async SSE generator when ``stream`` is True, otherwise
    the list of rendered images.
    """

    stream: bool
    data: Union[AsyncGenerator[str, None], List[str]]
44
+
45
+
46
+ class ImageEditService:
47
+ """Image edit orchestration service."""
48
+
49
    async def edit(
        self,
        *,
        token_mgr: Any,
        token: str,
        model_info: Any,
        prompt: str,
        images: List[str],
        n: int,
        response_format: str,
        stream: bool,
        chat_format: bool = False,
    ) -> ImageEditResult:
        """Edit reference images according to ``prompt``.

        Uploads up to the 3 most recent references, resolves a parent post
        id, then drives the Grok chat pipeline with image-edit overrides.
        Rotates tokens on upstream 429s up to the configured retry count.
        Raises AppException (429) when no token is available and re-raises
        the last upstream error otherwise.
        """
        if len(images) > 3:
            # Upstream supports at most 3 references; keep the newest ones.
            logger.info(
                "Image edit received %d references; using the most recent 3",
                len(images),
            )
            images = images[-3:]

        max_token_retries = int(get_config("retry.max_retry") or 3)
        tried_tokens: set[str] = set()
        last_error: Exception | None = None

        for attempt in range(max_token_retries):
            # Honor the caller-supplied token only on the first attempt.
            preferred = token if attempt == 0 else None
            current_token = await pick_token(
                token_mgr, model_info.model_id, tried_tokens, preferred=preferred
            )
            if not current_token:
                if last_error:
                    raise last_error
                raise AppException(
                    message="No available tokens. Please try again later.",
                    error_type=ErrorType.RATE_LIMIT.value,
                    code="rate_limit_exceeded",
                    status_code=429,
                )

            tried_tokens.add(current_token)
            try:
                image_urls = await self._upload_images(images, current_token)
                parent_post_id = await self._get_parent_post_id(
                    current_token, image_urls
                )

                model_config_override = {
                    "modelMap": {
                        "imageEditModel": "imagine",
                        "imageEditModelConfig": {
                            "imageReferences": image_urls,
                        },
                    }
                }
                if parent_post_id:
                    model_config_override["modelMap"]["imageEditModelConfig"][
                        "parentPostId"
                    ] = parent_post_id

                tool_overrides = {"imageGen": True}

                if stream:
                    response = await GrokChatService().chat(
                        token=current_token,
                        message=prompt,
                        model=model_info.grok_model,
                        mode=None,
                        stream=True,
                        tool_overrides=tool_overrides,
                        model_config_override=model_config_override,
                    )
                    processor = ImageStreamProcessor(
                        model_info.model_id,
                        current_token,
                        n=n,
                        response_format=response_format,
                        chat_format=chat_format,
                    )
                    return ImageEditResult(
                        stream=True,
                        data=wrap_stream_with_usage(
                            processor.process(response),
                            token_mgr,
                            current_token,
                            model_info.model_id,
                        ),
                    )

                images_out = await self._collect_images(
                    token=current_token,
                    prompt=prompt,
                    model_info=model_info,
                    n=n,
                    response_format=response_format,
                    tool_overrides=tool_overrides,
                    model_config_override=model_config_override,
                )
                # Usage accounting is best-effort and must not fail the request.
                try:
                    effort = (
                        EffortType.HIGH
                        if (model_info and model_info.cost.value == "high")
                        else EffortType.LOW
                    )
                    await token_mgr.consume(current_token, effort)
                    logger.debug(
                        f"Image edit completed, recorded usage (effort={effort.value})"
                    )
                except Exception as e:
                    logger.warning(f"Failed to record image edit usage: {e}")
                return ImageEditResult(stream=False, data=images_out)

            except UpstreamException as e:
                last_error = e
                if rate_limited(e):
                    await token_mgr.mark_rate_limited(current_token)
                    logger.warning(
                        f"Token {current_token[:10]}... rate limited (429), "
                        f"trying next token (attempt {attempt + 1}/{max_token_retries})"
                    )
                    continue
                raise

        if last_error:
            raise last_error
        raise AppException(
            message="No available tokens. Please try again later.",
            error_type=ErrorType.RATE_LIMIT.value,
            code="rate_limit_exceeded",
            status_code=429,
        )
179
+
180
+ async def _upload_images(self, images: List[str], token: str) -> List[str]:
181
+ image_urls: List[str] = []
182
+ upload_service = UploadService()
183
+ try:
184
+ for image in images:
185
+ _, file_uri = await upload_service.upload_file(image, token)
186
+ if file_uri:
187
+ if file_uri.startswith("http"):
188
+ image_urls.append(file_uri)
189
+ else:
190
+ image_urls.append(
191
+ f"https://assets.grok.com/{file_uri.lstrip('/')}"
192
+ )
193
+ finally:
194
+ await upload_service.close()
195
+
196
+ if not image_urls:
197
+ raise AppException(
198
+ message="Image upload failed",
199
+ error_type=ErrorType.SERVER.value,
200
+ code="upload_failed",
201
+ )
202
+
203
+ return image_urls
204
+
205
+ async def _get_parent_post_id(self, token: str, image_urls: List[str]) -> str:
206
+ parent_post_id = None
207
+ try:
208
+ media_service = VideoService()
209
+ parent_post_id = await media_service.create_image_post(token, image_urls[0])
210
+ logger.debug(f"Parent post ID: {parent_post_id}")
211
+ except Exception as e:
212
+ logger.warning(f"Create image post failed: {e}")
213
+
214
+ if parent_post_id:
215
+ return parent_post_id
216
+
217
+ for url in image_urls:
218
+ match = re.search(r"/generated/([a-f0-9-]+)/", url)
219
+ if match:
220
+ parent_post_id = match.group(1)
221
+ logger.debug(f"Parent post ID: {parent_post_id}")
222
+ break
223
+ match = re.search(r"/users/[^/]+/([a-f0-9-]+)/content", url)
224
+ if match:
225
+ parent_post_id = match.group(1)
226
+ logger.debug(f"Parent post ID: {parent_post_id}")
227
+ break
228
+
229
+ return parent_post_id or ""
230
+
231
    async def _collect_images(
        self,
        *,
        token: str,
        prompt: str,
        model_info: Any,
        n: int,
        response_format: str,
        tool_overrides: dict,
        model_config_override: dict,
    ) -> List[str]:
        """Run the upstream edit chat (possibly several calls in parallel)
        and collect the resulting image payloads.

        Each upstream call is assumed to yield up to two images, so
        ceil(n / 2) calls are issued. Always returns exactly ``n`` entries;
        a shortfall is padded with the sentinel string "error".
        """
        # Each upstream call produces at most 2 images — TODO confirm upstream contract.
        calls_needed = (n + 1) // 2

        async def _call_edit():
            # One streaming edit request, drained into a list of images.
            response = await GrokChatService().chat(
                token=token,
                message=prompt,
                model=model_info.grok_model,
                mode=None,
                stream=True,
                tool_overrides=tool_overrides,
                model_config_override=model_config_override,
            )
            processor = ImageCollectProcessor(
                model_info.model_id, token, response_format=response_format
            )
            return await processor.process(response)

        last_error: Exception | None = None
        rate_limit_error: Exception | None = None

        if calls_needed == 1:
            all_images = await _call_edit()
        else:
            # Fan out concurrently and tolerate partial failures.
            tasks = [_call_edit() for _ in range(calls_needed)]
            results = await asyncio.gather(*tasks, return_exceptions=True)

            all_images: List[str] = []
            for result in results:
                if isinstance(result, Exception):
                    logger.error(f"Concurrent call failed: {result}")
                    last_error = result
                    if rate_limited(result):
                        # Remember rate limits separately — they take
                        # precedence when reporting a total failure below.
                        rate_limit_error = result
                elif isinstance(result, list):
                    all_images.extend(result)

        if not all_images:
            # Nothing succeeded: surface the most actionable error first.
            if rate_limit_error:
                raise rate_limit_error
            if last_error:
                raise last_error
            raise UpstreamException(
                "Image edit returned no results", details={"error": "empty_result"}
            )

        if len(all_images) >= n:
            return all_images[:n]

        # Fewer images than requested: pad with "error" placeholders so the
        # caller always receives exactly n entries.
        selected_images = all_images.copy()
        while len(selected_images) < n:
            selected_images.append("error")
        return selected_images
294
+
295
+
296
class ImageStreamProcessor(BaseProcessor):
    """HTTP image stream processor.

    Consumes the upstream NDJSON stream and re-emits it as SSE events,
    either in OpenAI image_generation format or (when ``chat_format`` is
    set) as chat.completion chunks with the image wrapped for chat output.
    """

    def __init__(
        self, model: str, token: str = "", n: int = 1, response_format: str = "b64_json", chat_format: bool = False
    ):
        super().__init__(model, token)
        self.partial_index = 0
        self.n = n
        # With a single requested image only index 0 is forwarded.
        self.target_index = 0 if n == 1 else None
        self.response_format = response_format
        self.chat_format = chat_format
        # Response id is generated lazily on first emitted chunk.
        self._id_generated = False
        self._response_id = ""
        # Name of the field carrying the image payload in emitted events.
        if response_format == "url":
            self.response_field = "url"
        elif response_format == "base64":
            self.response_field = "base64"
        else:
            self.response_field = "b64_json"

    def _sse(self, event: str, data: dict) -> str:
        """Build SSE response (one "event:"/"data:" frame)."""
        return f"event: {event}\ndata: {orjson.dumps(data).decode()}\n\n"

    async def process(
        self, response: AsyncIterable[bytes]
    ) -> AsyncGenerator[str, None]:
        """Process stream response.

        Yields SSE strings. Raises UpstreamException on idle timeout or
        upstream request errors; client cancellation is swallowed.
        """
        final_images = []
        emitted_chat_chunk = False
        idle_timeout = get_config("image.stream_timeout")

        try:
            async for line in _with_idle_timeout(response, idle_timeout, self.model):
                line = _normalize_line(line)
                if not line:
                    continue
                try:
                    data = orjson.loads(line)
                except orjson.JSONDecodeError:
                    # Ignore keep-alives / non-JSON noise.
                    continue

                resp = data.get("result", {}).get("response", {})

                # Image generation progress
                if img := resp.get("streamingImageGenerationResponse"):
                    image_index = img.get("imageIndex", 0)
                    progress = img.get("progress", 0)

                    if self.n == 1 and image_index != self.target_index:
                        continue

                    out_index = 0 if self.n == 1 else image_index

                    # Progress frames are only meaningful for the native
                    # image_generation protocol, not for chat chunks.
                    if not self.chat_format:
                        yield self._sse(
                            "image_generation.partial_image",
                            {
                                "type": "image_generation.partial_image",
                                self.response_field: "",
                                "index": out_index,
                                "progress": progress,
                            },
                        )
                    continue

                # modelResponse — carries the final image URLs.
                if mr := resp.get("modelResponse"):
                    if urls := _collect_images(mr):
                        for url in urls:
                            if self.response_format == "url":
                                processed = await self.process_url(url, "image")
                                if processed:
                                    final_images.append(processed)
                                continue
                            try:
                                dl_service = self._get_dl()
                                base64_data = await dl_service.parse_b64(
                                    url, self.token, "image"
                                )
                                if base64_data:
                                    # Strip a possible data-URI prefix.
                                    if "," in base64_data:
                                        b64 = base64_data.split(",", 1)[1]
                                    else:
                                        b64 = base64_data
                                    final_images.append(b64)
                            except Exception as e:
                                logger.warning(
                                    f"Failed to convert image to base64, falling back to URL: {e}"
                                )
                                processed = await self.process_url(url, "image")
                                if processed:
                                    final_images.append(processed)
                    continue

            # Stream drained — emit one completion event per collected image.
            for index, img_data in enumerate(final_images):
                if self.n == 1:
                    if index != self.target_index:
                        continue
                    out_index = 0
                else:
                    out_index = index

                # Wrap in markdown format for chat
                output = img_data
                if self.chat_format and output:
                    output = wrap_image_content(output, self.response_format)

                if not self._id_generated:
                    self._response_id = make_response_id()
                    self._id_generated = True

                if self.chat_format:
                    # OpenAI ChatCompletion chunk format
                    emitted_chat_chunk = True
                    yield self._sse(
                        "chat.completion.chunk",
                        make_chat_chunk(
                            self._response_id,
                            self.model,
                            output,
                            index=out_index,
                            is_final=True,
                        ),
                    )
                else:
                    # Original image_generation format
                    yield self._sse(
                        "image_generation.completed",
                        {
                            "type": "image_generation.completed",
                            self.response_field: img_data,
                            "index": out_index,
                            "usage": {
                                "total_tokens": 0,
                                "input_tokens": 0,
                                "output_tokens": 0,
                                "input_tokens_details": {
                                    "text_tokens": 0,
                                    "image_tokens": 0,
                                },
                            },
                        },
                    )

            if self.chat_format:
                # Chat streams must always end with a final (possibly empty)
                # chunk followed by the [DONE] sentinel.
                if not self._id_generated:
                    self._response_id = make_response_id()
                    self._id_generated = True
                if not emitted_chat_chunk:
                    yield self._sse(
                        "chat.completion.chunk",
                        make_chat_chunk(
                            self._response_id,
                            self.model,
                            "",
                            index=0,
                            is_final=True,
                        ),
                    )
                yield "data: [DONE]\n\n"
        except asyncio.CancelledError:
            # Client disconnected — nothing to clean up beyond finally.
            logger.debug("Image stream cancelled by client")
        except StreamIdleTimeoutError as e:
            raise UpstreamException(
                message=f"Image stream idle timeout after {e.idle_seconds}s",
                status_code=504,
                details={
                    "error": str(e),
                    "type": "stream_idle_timeout",
                    "idle_seconds": e.idle_seconds,
                },
            )
        except RequestsError as e:
            if _is_http2_error(e):
                logger.warning(f"HTTP/2 stream error in image: {e}")
                raise UpstreamException(
                    message="Upstream connection closed unexpectedly",
                    status_code=502,
                    details={"error": str(e), "type": "http2_stream_error"},
                )
            logger.error(f"Image stream request error: {e}")
            raise UpstreamException(
                message=f"Upstream request failed: {e}",
                status_code=502,
                details={"error": str(e)},
            )
        except Exception as e:
            logger.error(
                f"Image stream processing error: {e}",
                extra={"error_type": type(e).__name__},
            )
            raise
        finally:
            await self.close()
492
+
493
+
494
class ImageCollectProcessor(BaseProcessor):
    """HTTP image non-stream processor.

    Drains the upstream NDJSON response and returns the final images as a
    list of URLs or base64 strings, depending on ``response_format``.
    """

    def __init__(self, model: str, token: str = "", response_format: str = "b64_json"):
        # "base64" is accepted as an alias for "b64_json".
        if response_format == "base64":
            response_format = "b64_json"
        super().__init__(model, token)
        self.response_format = response_format

    async def process(self, response: AsyncIterable[bytes]) -> List[str]:
        """Process and collect images.

        Best-effort: all errors are logged, and whatever was collected so
        far is returned rather than raised.
        """
        images = []
        idle_timeout = get_config("image.stream_timeout")

        try:
            async for line in _with_idle_timeout(response, idle_timeout, self.model):
                line = _normalize_line(line)
                if not line:
                    continue
                try:
                    data = orjson.loads(line)
                except orjson.JSONDecodeError:
                    # Ignore keep-alives / non-JSON noise.
                    continue

                resp = data.get("result", {}).get("response", {})

                # Only the final modelResponse carries the generated images.
                if mr := resp.get("modelResponse"):
                    if urls := _collect_images(mr):
                        for url in urls:
                            if self.response_format == "url":
                                processed = await self.process_url(url, "image")
                                if processed:
                                    images.append(processed)
                                continue
                            try:
                                dl_service = self._get_dl()
                                base64_data = await dl_service.parse_b64(
                                    url, self.token, "image"
                                )
                                if base64_data:
                                    # Strip a possible data-URI prefix.
                                    if "," in base64_data:
                                        b64 = base64_data.split(",", 1)[1]
                                    else:
                                        b64 = base64_data
                                    images.append(b64)
                            except Exception as e:
                                logger.warning(
                                    f"Failed to convert image to base64, falling back to URL: {e}"
                                )
                                processed = await self.process_url(url, "image")
                                if processed:
                                    images.append(processed)

        except asyncio.CancelledError:
            logger.debug("Image collect cancelled by client")
        except StreamIdleTimeoutError as e:
            logger.warning(f"Image collect idle timeout: {e}")
        except RequestsError as e:
            if _is_http2_error(e):
                logger.warning(f"HTTP/2 stream error in image collect: {e}")
            else:
                logger.error(f"Image collect request error: {e}")
        except Exception as e:
            logger.error(
                f"Image collect processing error: {e}",
                extra={"error_type": type(e).__name__},
            )
        finally:
            await self.close()

        return images
565
+
566
+
567
+ __all__ = ["ImageEditService", "ImageEditResult"]
app/services/grok/services/model.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Grok 模型管理服务
3
+ """
4
+
5
+ from enum import Enum
6
+ from typing import Optional, Tuple, List
7
+ from pydantic import BaseModel, Field
8
+
9
+ from app.core.exceptions import ValidationException
10
+
11
+
12
class Tier(str, Enum):
    """Model tier; selects which SSO token pool serves the model."""

    BASIC = "basic"
    SUPER = "super"
18
+
19
+ class Cost(str, Enum):
20
+ """计费类型"""
21
+
22
+ LOW = "low"
23
+ HIGH = "high"
24
+
25
+
26
+ class ModelInfo(BaseModel):
27
+ """模型信息"""
28
+
29
+ model_id: str
30
+ grok_model: str
31
+ model_mode: str
32
+ tier: Tier = Field(default=Tier.BASIC)
33
+ cost: Cost = Field(default=Cost.LOW)
34
+ display_name: str
35
+ description: str = ""
36
+ is_image: bool = False
37
+ is_image_edit: bool = False
38
+ is_video: bool = False
39
+
40
+
41
+ class ModelService:
42
+ """模型管理服务"""
43
+
44
+ MODELS = [
45
+ ModelInfo(
46
+ model_id="grok-3",
47
+ grok_model="grok-3",
48
+ model_mode="MODEL_MODE_GROK_3",
49
+ tier=Tier.BASIC,
50
+ cost=Cost.LOW,
51
+ display_name="GROK-3",
52
+ is_image=False,
53
+ is_image_edit=False,
54
+ is_video=False,
55
+ ),
56
+ ModelInfo(
57
+ model_id="grok-3-mini",
58
+ grok_model="grok-3",
59
+ model_mode="MODEL_MODE_GROK_3_MINI_THINKING",
60
+ tier=Tier.BASIC,
61
+ cost=Cost.LOW,
62
+ display_name="GROK-3-MINI",
63
+ is_image=False,
64
+ is_image_edit=False,
65
+ is_video=False,
66
+ ),
67
+ ModelInfo(
68
+ model_id="grok-3-thinking",
69
+ grok_model="grok-3",
70
+ model_mode="MODEL_MODE_GROK_3_THINKING",
71
+ tier=Tier.BASIC,
72
+ cost=Cost.LOW,
73
+ display_name="GROK-3-THINKING",
74
+ is_image=False,
75
+ is_image_edit=False,
76
+ is_video=False,
77
+ ),
78
+ ModelInfo(
79
+ model_id="grok-4",
80
+ grok_model="grok-4",
81
+ model_mode="MODEL_MODE_GROK_4",
82
+ tier=Tier.BASIC,
83
+ cost=Cost.LOW,
84
+ display_name="GROK-4",
85
+ is_image=False,
86
+ is_image_edit=False,
87
+ is_video=False,
88
+ ),
89
+ ModelInfo(
90
+ model_id="grok-4-mini",
91
+ grok_model="grok-4-mini",
92
+ model_mode="MODEL_MODE_GROK_4_MINI_THINKING",
93
+ tier=Tier.BASIC,
94
+ cost=Cost.LOW,
95
+ display_name="GROK-4-MINI",
96
+ is_image=False,
97
+ is_image_edit=False,
98
+ is_video=False,
99
+ ),
100
+ ModelInfo(
101
+ model_id="grok-4-thinking",
102
+ grok_model="grok-4",
103
+ model_mode="MODEL_MODE_GROK_4_THINKING",
104
+ tier=Tier.BASIC,
105
+ cost=Cost.LOW,
106
+ display_name="GROK-4-THINKING",
107
+ is_image=False,
108
+ is_image_edit=False,
109
+ is_video=False,
110
+ ),
111
+ ModelInfo(
112
+ model_id="grok-4-heavy",
113
+ grok_model="grok-4",
114
+ model_mode="MODEL_MODE_HEAVY",
115
+ tier=Tier.SUPER,
116
+ cost=Cost.HIGH,
117
+ display_name="GROK-4-HEAVY",
118
+ is_image=False,
119
+ is_image_edit=False,
120
+ is_video=False,
121
+ ),
122
+ ModelInfo(
123
+ model_id="grok-4.1-mini",
124
+ grok_model="grok-4-1-thinking-1129",
125
+ model_mode="MODEL_MODE_GROK_4_1_MINI_THINKING",
126
+ tier=Tier.BASIC,
127
+ cost=Cost.LOW,
128
+ display_name="GROK-4.1-MINI",
129
+ is_image=False,
130
+ is_image_edit=False,
131
+ is_video=False,
132
+ ),
133
+ ModelInfo(
134
+ model_id="grok-4.1-fast",
135
+ grok_model="grok-4-1-thinking-1129",
136
+ model_mode="MODEL_MODE_FAST",
137
+ tier=Tier.BASIC,
138
+ cost=Cost.LOW,
139
+ display_name="GROK-4.1-FAST",
140
+ is_image=False,
141
+ is_image_edit=False,
142
+ is_video=False,
143
+ ),
144
+ ModelInfo(
145
+ model_id="grok-4.1-expert",
146
+ grok_model="grok-4-1-thinking-1129",
147
+ model_mode="MODEL_MODE_EXPERT",
148
+ tier=Tier.BASIC,
149
+ cost=Cost.HIGH,
150
+ display_name="GROK-4.1-EXPERT",
151
+ is_image=False,
152
+ is_image_edit=False,
153
+ is_video=False,
154
+ ),
155
+ ModelInfo(
156
+ model_id="grok-4.1-thinking",
157
+ grok_model="grok-4-1-thinking-1129",
158
+ model_mode="MODEL_MODE_GROK_4_1_THINKING",
159
+ tier=Tier.BASIC,
160
+ cost=Cost.HIGH,
161
+ display_name="GROK-4.1-THINKING",
162
+ is_image=False,
163
+ is_image_edit=False,
164
+ is_video=False,
165
+ ),
166
+ ModelInfo(
167
+ model_id="grok-4.20-beta",
168
+ grok_model="grok-420",
169
+ model_mode="MODEL_MODE_GROK_420",
170
+ tier=Tier.BASIC,
171
+ cost=Cost.LOW,
172
+ display_name="GROK-4.20-BETA",
173
+ is_image=False,
174
+ is_image_edit=False,
175
+ is_video=False,
176
+ ),
177
+ ModelInfo(
178
+ model_id="grok-imagine-1.0-fast",
179
+ grok_model="grok-3",
180
+ model_mode="MODEL_MODE_FAST",
181
+ tier=Tier.BASIC,
182
+ cost=Cost.HIGH,
183
+ display_name="Grok Image Fast",
184
+ description="Imagine waterfall image generation model for chat completions",
185
+ is_image=True,
186
+ is_image_edit=False,
187
+ is_video=False,
188
+ ),
189
+ ModelInfo(
190
+ model_id="grok-imagine-1.0",
191
+ grok_model="grok-3",
192
+ model_mode="MODEL_MODE_FAST",
193
+ tier=Tier.BASIC,
194
+ cost=Cost.HIGH,
195
+ display_name="Grok Image",
196
+ description="Image generation model",
197
+ is_image=True,
198
+ is_image_edit=False,
199
+ is_video=False,
200
+ ),
201
+ ModelInfo(
202
+ model_id="grok-imagine-1.0-edit",
203
+ grok_model="imagine-image-edit",
204
+ model_mode="MODEL_MODE_FAST",
205
+ tier=Tier.BASIC,
206
+ cost=Cost.HIGH,
207
+ display_name="Grok Image Edit",
208
+ description="Image edit model",
209
+ is_image=False,
210
+ is_image_edit=True,
211
+ is_video=False,
212
+ ),
213
+ ModelInfo(
214
+ model_id="grok-imagine-1.0-video",
215
+ grok_model="grok-3",
216
+ model_mode="MODEL_MODE_FAST",
217
+ tier=Tier.BASIC,
218
+ cost=Cost.HIGH,
219
+ display_name="Grok Video",
220
+ description="Video generation model",
221
+ is_image=False,
222
+ is_image_edit=False,
223
+ is_video=True,
224
+ ),
225
+ ]
226
+
227
+ _map = {m.model_id: m for m in MODELS}
228
+
229
+ @classmethod
230
+ def get(cls, model_id: str) -> Optional[ModelInfo]:
231
+ """获取模型信息"""
232
+ return cls._map.get(model_id)
233
+
234
+ @classmethod
235
+ def list(cls) -> list[ModelInfo]:
236
+ """获取所有模型"""
237
+ return list(cls._map.values())
238
+
239
+ @classmethod
240
+ def valid(cls, model_id: str) -> bool:
241
+ """模型是否有效"""
242
+ return model_id in cls._map
243
+
244
+ @classmethod
245
+ def to_grok(cls, model_id: str) -> Tuple[str, str]:
246
+ """转换为 Grok 参数"""
247
+ model = cls.get(model_id)
248
+ if not model:
249
+ raise ValidationException(f"Invalid model ID: {model_id}")
250
+ return model.grok_model, model.model_mode
251
+
252
+ @classmethod
253
+ def pool_for_model(cls, model_id: str) -> str:
254
+ """根据模型选择 Token 池"""
255
+ model = cls.get(model_id)
256
+ if model and model.tier == Tier.SUPER:
257
+ return "ssoSuper"
258
+ return "ssoBasic"
259
+
260
+ @classmethod
261
+ def pool_candidates_for_model(cls, model_id: str) -> List[str]:
262
+ """按优先级返回可用 Token 池列表"""
263
+ model = cls.get(model_id)
264
+ if model and model.tier == Tier.SUPER:
265
+ return ["ssoSuper"]
266
+ # 基础模型优先使用 basic 池,缺失时可回退到 super 池
267
+ return ["ssoBasic", "ssoSuper"]
268
+
269
+
270
+ __all__ = ["ModelService"]
app/services/grok/services/responses.py ADDED
@@ -0,0 +1,824 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Responses API bridge service (OpenAI-compatible).
3
+ """
4
+
5
+ import time
6
+ import uuid
7
+ from typing import Any, AsyncGenerator, Dict, List, Optional
8
+
9
+ import orjson
10
+
11
+ from app.services.grok.services.chat import ChatService
12
+ from app.services.grok.utils import process as proc_base
13
+
14
+
15
# Input item types that carry the output of a previously-executed tool call;
# these are mapped onto role="tool" chat messages.
_TOOL_OUTPUT_TYPES = {
    "tool_output",
    "function_call_output",
    "tool_call_output",
    "input_tool_output",
}

# OpenAI built-in tool types that this bridge emulates by exposing them to
# the chat backend as plain function tools (see _normalize_tools_for_chat).
_BUILTIN_TOOL_TYPES = {
    "web_search",
    "web_search_2025_08_26",
    "file_search",
    "code_interpreter",
}
28
+
29
+
30
+ def _now_ts() -> int:
31
+ return int(time.time())
32
+
33
+
34
+ def _new_response_id() -> str:
35
+ return f"resp_{uuid.uuid4().hex[:24]}"
36
+
37
+
38
+ def _new_message_id() -> str:
39
+ return f"msg_{uuid.uuid4().hex[:24]}"
40
+
41
+
42
+ def _new_tool_call_id() -> str:
43
+ return f"call_{uuid.uuid4().hex[:24]}"
44
+
45
+
46
+ def _new_function_call_id() -> str:
47
+ return f"fc_{uuid.uuid4().hex[:24]}"
48
+
49
+
50
+ def _normalize_tool_choice(tool_choice: Any) -> Any:
51
+ if isinstance(tool_choice, dict):
52
+ t_type = tool_choice.get("type")
53
+ if t_type and t_type != "function":
54
+ return {"type": "function", "function": {"name": t_type}}
55
+ return tool_choice
56
+
57
+
58
def _normalize_tools_for_chat(tools: Optional[List[Dict[str, Any]]]) -> Optional[List[Dict[str, Any]]]:
    """Convert a Responses-API tool list into chat-completions function tools.

    Function tools pass through unchanged; supported built-in tools are
    emulated as single-argument function tools; anything else is dropped.
    Returns None when nothing usable remains.
    """

    def _as_function(name: str, description: str, arg: str) -> Dict[str, Any]:
        # Single string-argument function-tool stub for a built-in tool.
        return {
            "type": "function",
            "function": {
                "name": name,
                "description": description,
                "parameters": {
                    "type": "object",
                    "properties": {arg: {"type": "string"}},
                    "required": [arg],
                },
            },
        }

    if not tools:
        return None

    converted: List[Dict[str, Any]] = []
    for tool in tools:
        if not isinstance(tool, dict):
            continue
        tool_type = tool.get("type")
        if tool_type == "function":
            converted.append(tool)
        elif tool_type in _BUILTIN_TOOL_TYPES:
            if tool_type.startswith("web_search"):
                converted.append(
                    _as_function(
                        tool_type,
                        "Search the web for information and return results.",
                        "query",
                    )
                )
            elif tool_type == "file_search":
                converted.append(
                    _as_function(
                        tool_type,
                        "Search provided files for relevant information.",
                        "query",
                    )
                )
            elif tool_type == "code_interpreter":
                converted.append(
                    _as_function(
                        tool_type,
                        "Execute code to solve tasks and return results.",
                        "code",
                    )
                )
    return converted or None
116
+
117
+
118
+ def _content_item_from_input(item: Dict[str, Any]) -> Optional[Dict[str, Any]]:
119
+ if not isinstance(item, dict):
120
+ return None
121
+ item_type = item.get("type")
122
+
123
+ if item_type in {"input_text", "text", "output_text"}:
124
+ text = item.get("text") or item.get("content") or ""
125
+ return {"type": "text", "text": text}
126
+
127
+ if item_type in {"input_image", "image", "image_url", "output_image"}:
128
+ image_url = item.get("image_url")
129
+ url = ""
130
+ detail = None
131
+ if isinstance(image_url, dict):
132
+ url = image_url.get("url") or ""
133
+ detail = image_url.get("detail")
134
+ elif isinstance(image_url, str):
135
+ url = image_url
136
+ else:
137
+ url = item.get("url") or item.get("image") or ""
138
+
139
+ if not url:
140
+ return None
141
+ image_payload = {"url": url}
142
+ if detail:
143
+ image_payload["detail"] = detail
144
+ return {"type": "image_url", "image_url": image_payload}
145
+
146
+ if item_type in {"input_file", "file"}:
147
+ file_data = item.get("file_data")
148
+ file_id = item.get("file_id")
149
+ if not file_data and isinstance(item.get("file"), dict):
150
+ file_data = item["file"].get("file_data")
151
+ file_id = item["file"].get("file_id")
152
+ file_payload: Dict[str, Any] = {}
153
+ if file_data:
154
+ file_payload["file_data"] = file_data
155
+ if file_id:
156
+ file_payload["file_id"] = file_id
157
+ if not file_payload:
158
+ return None
159
+ return {"type": "file", "file": file_payload}
160
+
161
+ if item_type in {"input_audio", "audio"}:
162
+ audio = item.get("audio") or {}
163
+ data = audio.get("data") or item.get("data")
164
+ if not data:
165
+ return None
166
+ return {"type": "input_audio", "input_audio": {"data": data}}
167
+
168
+ return None
169
+
170
+
171
def _message_from_item(item: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Turn an explicit message item (type == "message", or a bare
    role/content dict) into a chat message; None when it is not a message."""
    if not isinstance(item, dict):
        return None

    is_message = item.get("type") == "message" or (
        "role" in item and "content" in item
    )
    if not is_message:
        return None

    return {
        "role": item.get("role") or "user",
        "content": _coerce_content(item.get("content")),
    }
184
+
185
+
186
def _coerce_content(content: Any) -> Any:
    """Normalize Responses-style message content for chat completions.

    Strings pass through; a dict or list becomes a list of content blocks
    (or "" when nothing usable survives); anything else is stringified.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content

    items = [content] if isinstance(content, dict) else content
    if not isinstance(items, list):
        return str(items)

    blocks: List[Dict[str, Any]] = []
    for entry in items:
        if not isinstance(entry, dict):
            continue
        if entry.get("type") in {"input_text", "output_text"}:
            blocks.append({"type": "text", "text": entry.get("text", "")})
        elif converted := _content_item_from_input(entry):
            blocks.append(converted)
    return blocks or ""
204
+
205
+
206
def _coerce_input_to_messages(input_value: Any) -> List[Dict[str, Any]]:
    """Convert the Responses-API ``input`` value into a chat message list.

    Accepts a plain string, a single item dict, or a list mixing message
    items, tool outputs, content blocks, and bare strings. Consecutive
    loose content blocks are batched into a single user message.
    """
    if input_value is None:
        return []
    if isinstance(input_value, str):
        return [{"role": "user", "content": input_value}]

    if isinstance(input_value, dict):
        msg = _message_from_item(input_value)
        if msg:
            return [msg]
        content_item = _content_item_from_input(input_value)
        if content_item:
            return [{"role": "user", "content": [content_item]}]
        return []

    if not isinstance(input_value, list):
        # Last resort: stringify unknown scalar input.
        return [{"role": "user", "content": str(input_value)}]

    messages: List[Dict[str, Any]] = []
    pending_blocks: List[Dict[str, Any]] = []

    def _flush_pending():
        # Emit accumulated loose content blocks as one user message.
        nonlocal pending_blocks
        if pending_blocks:
            messages.append({"role": "user", "content": pending_blocks})
            pending_blocks = []

    for item in input_value:
        if isinstance(item, dict):
            # Explicit message item.
            msg = _message_from_item(item)
            if msg:
                _flush_pending()
                messages.append(msg)
                continue

            # Tool output item -> role="tool" message.
            item_type = item.get("type")
            if item_type in _TOOL_OUTPUT_TYPES:
                _flush_pending()
                call_id = (
                    item.get("call_id")
                    or item.get("tool_call_id")
                    or item.get("id")
                    or _new_tool_call_id()
                )
                output = item.get("output") or item.get("content") or ""
                messages.append({"role": "tool", "tool_call_id": call_id, "content": output})
                continue

            # Loose content block: buffer until the next real message.
            block = _content_item_from_input(item)
            if block:
                pending_blocks.append(block)
            continue

        if isinstance(item, str):
            pending_blocks.append({"type": "text", "text": item})

    _flush_pending()
    return messages
264
+
265
+
266
def _build_output_message(
    text: str,
    *,
    message_id: Optional[str] = None,
    status: str = "completed",
) -> Dict[str, Any]:
    """Wrap assistant text into a Responses-API output "message" item."""
    text_block = {
        "type": "output_text",
        "text": text,
        "annotations": [],
    }
    return {
        "id": message_id or _new_message_id(),
        "type": "message",
        "role": "assistant",
        "status": status,
        "content": [text_block],
    }
286
+
287
+
288
def _build_output_tool_call(
    tool_call: Dict[str, Any],
    *,
    item_id: Optional[str] = None,
    status: str = "completed",
) -> Dict[str, Any]:
    """Convert a chat-completions tool call into a Responses-API
    "function_call" output item."""
    function = tool_call.get("function") or {}
    return {
        "id": item_id or _new_function_call_id(),
        "type": "function_call",
        "status": status,
        "call_id": tool_call.get("id") or _new_tool_call_id(),
        "name": function.get("name"),
        "arguments": function.get("arguments"),
    }
305
+
306
+
307
def _build_response_object(
    *,
    model: str,
    output_text: Optional[str] = None,
    tool_calls: Optional[List[Dict[str, Any]]] = None,
    response_id: Optional[str] = None,
    usage: Optional[Dict[str, Any]] = None,
    created_at: Optional[int] = None,
    completed_at: Optional[int] = None,
    status: str = "completed",
    instructions: Optional[str] = None,
    max_output_tokens: Optional[int] = None,
    parallel_tool_calls: Optional[bool] = None,
    previous_response_id: Optional[str] = None,
    reasoning_effort: Optional[str] = None,
    store: Optional[bool] = None,
    temperature: Optional[float] = None,
    tool_choice: Optional[Any] = None,
    tools: Optional[List[Dict[str, Any]]] = None,
    top_p: Optional[float] = None,
    truncation: Optional[str] = None,
    user: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Assemble a full OpenAI Responses-API "response" object.

    Missing ids/timestamps are generated; unset optional parameters fall
    back to the OpenAI defaults (parallel_tool_calls/store True,
    temperature/top_p 1.0, tool_choice "auto", truncation "disabled").
    """
    response_id = response_id or _new_response_id()
    created_at = created_at or _now_ts()
    # Stamp completion time only for already-completed responses.
    if status == "completed" and completed_at is None:
        completed_at = _now_ts()

    # Output items: optional assistant message followed by tool-call items.
    output: List[Dict[str, Any]] = []
    if output_text is not None:
        output.append(_build_output_message(output_text))

    if tool_calls:
        for call in tool_calls:
            output.append(_build_output_tool_call(call))

    return {
        "id": response_id,
        "object": "response",
        "created_at": created_at,
        "completed_at": completed_at,
        "status": status,
        "error": None,
        "incomplete_details": None,
        "instructions": instructions,
        "max_output_tokens": max_output_tokens,
        "model": model,
        "output": output,
        "parallel_tool_calls": True if parallel_tool_calls is None else parallel_tool_calls,
        "previous_response_id": previous_response_id,
        "reasoning": {"effort": reasoning_effort, "summary": None},
        "store": True if store is None else store,
        "temperature": 1.0 if temperature is None else temperature,
        "text": {"format": {"type": "text"}},
        "tool_choice": tool_choice or "auto",
        "tools": tools or [],
        "top_p": 1.0 if top_p is None else top_p,
        "truncation": truncation or "disabled",
        "usage": usage,
        "user": user,
        "metadata": metadata or {},
    }
370
+
371
+
372
+ class ResponseStreamAdapter:
373
    def __init__(
        self,
        *,
        model: str,
        response_id: str,
        created_at: int,
        instructions: Optional[str],
        max_output_tokens: Optional[int],
        parallel_tool_calls: Optional[bool],
        previous_response_id: Optional[str],
        reasoning_effort: Optional[str],
        store: Optional[bool],
        temperature: Optional[float],
        tool_choice: Optional[Any],
        tools: Optional[List[Dict[str, Any]]],
        top_p: Optional[float],
        truncation: Optional[str],
        user: Optional[str],
        metadata: Optional[Dict[str, Any]],
    ):
        """Capture the request parameters echoed back in every streamed
        response payload, and initialize streaming accumulation state."""
        self.model = model
        self.response_id = response_id
        self.created_at = created_at
        self.instructions = instructions
        self.max_output_tokens = max_output_tokens
        self.parallel_tool_calls = parallel_tool_calls
        self.previous_response_id = previous_response_id
        self.reasoning_effort = reasoning_effort
        self.store = store
        self.temperature = temperature
        self.tool_choice = tool_choice
        self.tools = tools
        self.top_p = top_p
        self.truncation = truncation
        self.user = user
        self.metadata = metadata

        # Accumulated assistant text deltas.
        self.output_text_parts: List[str] = []
        # Tool calls keyed by their chat-chunk index.
        self.tool_calls_by_index: Dict[int, Dict[str, Any]] = {}
        # Output items emitted for tool calls, keyed by the same index.
        self.tool_items: Dict[int, Dict[str, Any]] = {}
        # Next free slot in the response "output" array.
        self.next_output_index = 0
        self.content_index = 0
        self.message_id = _new_message_id()
        # Whether the assistant message item has been opened yet.
        self.message_started = False
        self.message_output_index: Optional[int] = None
418
+
419
+ def _event(self, event_type: str, payload: Dict[str, Any]) -> str:
420
+ return f"event: {event_type}\ndata: {orjson.dumps(payload).decode()}\n\n"
421
+
422
+ def _response_payload(self, *, status: str, output_text: Optional[str], usage: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
423
+ tool_calls = None
424
+ if status == "completed" and self.tool_calls_by_index:
425
+ tool_calls = [
426
+ self.tool_calls_by_index[idx]
427
+ for idx in sorted(self.tool_calls_by_index.keys())
428
+ ]
429
+ return _build_response_object(
430
+ model=self.model,
431
+ output_text=output_text,
432
+ tool_calls=tool_calls,
433
+ response_id=self.response_id,
434
+ usage=usage,
435
+ created_at=self.created_at,
436
+ status=status,
437
+ instructions=self.instructions,
438
+ max_output_tokens=self.max_output_tokens,
439
+ parallel_tool_calls=self.parallel_tool_calls,
440
+ previous_response_id=self.previous_response_id,
441
+ reasoning_effort=self.reasoning_effort,
442
+ store=self.store,
443
+ temperature=self.temperature,
444
+ tool_choice=self.tool_choice,
445
+ tools=self.tools,
446
+ top_p=self.top_p,
447
+ truncation=self.truncation,
448
+ user=self.user,
449
+ metadata=self.metadata,
450
+ )
451
+
452
+ def _alloc_output_index(self) -> int:
453
+ idx = self.next_output_index
454
+ self.next_output_index += 1
455
+ return idx
456
+
457
+ def created_event(self) -> str:
458
+ payload = {
459
+ "type": "response.created",
460
+ "response": self._response_payload(status="in_progress", output_text=None, usage=None),
461
+ }
462
+ return self._event("response.created", payload)
463
+
464
+ def in_progress_event(self) -> str:
465
+ payload = {
466
+ "type": "response.in_progress",
467
+ "response": self._response_payload(status="in_progress", output_text=None, usage=None),
468
+ }
469
+ return self._event("response.in_progress", payload)
470
+
471
+ def ensure_message_started(self) -> List[str]:
472
+ if self.message_started:
473
+ return []
474
+ self.message_started = True
475
+ self.message_output_index = self._alloc_output_index()
476
+ item = _build_output_message("", message_id=self.message_id, status="in_progress")
477
+ item["content"] = []
478
+ events = [
479
+ self._event(
480
+ "response.output_item.added",
481
+ {
482
+ "type": "response.output_item.added",
483
+ "response_id": self.response_id,
484
+ "output_index": self.message_output_index,
485
+ "item": item,
486
+ },
487
+ ),
488
+ self._event(
489
+ "response.content_part.added",
490
+ {
491
+ "type": "response.content_part.added",
492
+ "response_id": self.response_id,
493
+ "item_id": self.message_id,
494
+ "output_index": self.message_output_index,
495
+ "content_index": self.content_index,
496
+ "part": {"type": "output_text", "text": "", "annotations": []},
497
+ },
498
+ ),
499
+ ]
500
+ return events
501
+
502
+ def output_delta_event(self, delta: str) -> str:
503
+ return self._event(
504
+ "response.output_text.delta",
505
+ {
506
+ "type": "response.output_text.delta",
507
+ "response_id": self.response_id,
508
+ "item_id": self.message_id,
509
+ "output_index": self.message_output_index,
510
+ "content_index": self.content_index,
511
+ "delta": delta,
512
+ },
513
+ )
514
+
515
+ def output_done_events(self, text: str) -> List[str]:
516
+ if self.message_output_index is None:
517
+ return []
518
+ return [
519
+ self._event(
520
+ "response.output_text.done",
521
+ {
522
+ "type": "response.output_text.done",
523
+ "response_id": self.response_id,
524
+ "item_id": self.message_id,
525
+ "output_index": self.message_output_index,
526
+ "content_index": self.content_index,
527
+ "text": text,
528
+ },
529
+ ),
530
+ self._event(
531
+ "response.content_part.done",
532
+ {
533
+ "type": "response.content_part.done",
534
+ "response_id": self.response_id,
535
+ "item_id": self.message_id,
536
+ "output_index": self.message_output_index,
537
+ "content_index": self.content_index,
538
+ "part": {"type": "output_text", "text": text, "annotations": []},
539
+ },
540
+ ),
541
+ self._event(
542
+ "response.output_item.done",
543
+ {
544
+ "type": "response.output_item.done",
545
+ "response_id": self.response_id,
546
+ "output_index": self.message_output_index,
547
+ "item": _build_output_message(
548
+ text, message_id=self.message_id, status="completed"
549
+ ),
550
+ },
551
+ ),
552
+ ]
553
+
554
+ def ensure_tool_item(self, tool_index: int, call_id: str, name: Optional[str]) -> List[str]:
555
+ if tool_index in self.tool_items:
556
+ item = self.tool_items[tool_index]
557
+ if name and not item.get("name"):
558
+ item["name"] = name
559
+ return []
560
+ output_index = self._alloc_output_index()
561
+ item_id = _new_function_call_id()
562
+ self.tool_items[tool_index] = {
563
+ "item_id": item_id,
564
+ "output_index": output_index,
565
+ "call_id": call_id,
566
+ "name": name,
567
+ "arguments": "",
568
+ }
569
+ tool_item = _build_output_tool_call(
570
+ {"id": call_id, "function": {"name": name, "arguments": ""}},
571
+ item_id=item_id,
572
+ status="in_progress",
573
+ )
574
+ return [
575
+ self._event(
576
+ "response.output_item.added",
577
+ {
578
+ "type": "response.output_item.added",
579
+ "response_id": self.response_id,
580
+ "output_index": output_index,
581
+ "item": tool_item,
582
+ },
583
+ )
584
+ ]
585
+
586
+ def tool_arguments_delta_event(self, tool_index: int, delta: str) -> Optional[str]:
587
+ if not delta:
588
+ return None
589
+ item = self.tool_items.get(tool_index)
590
+ if not item:
591
+ return None
592
+ item["arguments"] += delta
593
+ return self._event(
594
+ "response.function_call_arguments.delta",
595
+ {
596
+ "type": "response.function_call_arguments.delta",
597
+ "response_id": self.response_id,
598
+ "item_id": item["item_id"],
599
+ "output_index": item["output_index"],
600
+ "delta": delta,
601
+ },
602
+ )
603
+
604
+ def tool_arguments_done_events(self) -> List[str]:
605
+ events: List[str] = []
606
+ for tool_index, item in sorted(
607
+ self.tool_items.items(), key=lambda kv: kv[1]["output_index"]
608
+ ):
609
+ events.append(
610
+ self._event(
611
+ "response.function_call_arguments.done",
612
+ {
613
+ "type": "response.function_call_arguments.done",
614
+ "response_id": self.response_id,
615
+ "item_id": item["item_id"],
616
+ "output_index": item["output_index"],
617
+ "arguments": item["arguments"],
618
+ },
619
+ )
620
+ )
621
+ tool_item = _build_output_tool_call(
622
+ {
623
+ "id": item["call_id"],
624
+ "function": {"name": item.get("name"), "arguments": item["arguments"]},
625
+ },
626
+ item_id=item["item_id"],
627
+ status="completed",
628
+ )
629
+ events.append(
630
+ self._event(
631
+ "response.output_item.done",
632
+ {
633
+ "type": "response.output_item.done",
634
+ "response_id": self.response_id,
635
+ "output_index": item["output_index"],
636
+ "item": tool_item,
637
+ },
638
+ )
639
+ )
640
+ return events
641
+
642
+ def record_tool_call(self, tool_index: int, call_id: str, name: Optional[str], arguments_delta: str) -> None:
643
+ tool_call = self.tool_calls_by_index.get(tool_index)
644
+ if not tool_call:
645
+ tool_call = {
646
+ "id": call_id or _new_tool_call_id(),
647
+ "type": "function",
648
+ "function": {"name": name, "arguments": ""},
649
+ }
650
+ self.tool_calls_by_index[tool_index] = tool_call
651
+ if name and not tool_call["function"].get("name"):
652
+ tool_call["function"]["name"] = name
653
+ if arguments_delta:
654
+ tool_call["function"]["arguments"] += arguments_delta
655
+
656
+ def completed_event(self, usage: Optional[Dict[str, Any]] = None) -> str:
657
+ response = self._response_payload(
658
+ status="completed",
659
+ output_text="".join(self.output_text_parts) if self.message_started else None,
660
+ usage=usage
661
+ or {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0},
662
+ )
663
+ payload = {"type": "response.completed", "response": response}
664
+ return self._event("response.completed", payload)
665
+
666
+
667
class ResponsesService:
    """OpenAI Responses API facade implemented on top of ChatService."""

    @staticmethod
    async def create(
        *,
        model: str,
        input_value: Any,
        instructions: Optional[str] = None,
        stream: bool = False,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        tool_choice: Any = None,
        parallel_tool_calls: Optional[bool] = None,
        reasoning_effort: Optional[str] = None,
        max_output_tokens: Optional[int] = None,
        metadata: Optional[Dict[str, Any]] = None,
        user: Optional[str] = None,
        store: Optional[bool] = None,
        previous_response_id: Optional[str] = None,
        truncation: Optional[str] = None,
    ) -> Any:
        """Create a response from Responses-API-style input.

        Converts `input_value` (+ optional `instructions` as a leading system
        message) into chat messages, delegates to ChatService.completions,
        and adapts the result back to the Responses API shape.

        Returns:
            A completed response dict when stream is False, otherwise an
            async generator yielding SSE event strings.

        Raises:
            ValueError: empty input, or a stream/non-stream mismatch from the
                underlying chat service.
        """
        messages = _coerce_input_to_messages(input_value)
        if instructions:
            messages = [{"role": "system", "content": instructions}] + messages

        if not messages:
            raise ValueError("input is required")

        # Translate Responses-style tools/tool_choice into chat-completions shape.
        normalized_tools = _normalize_tools_for_chat(tools)
        normalized_tool_choice = _normalize_tool_choice(tool_choice)

        # Only forward parameters the caller actually set.
        chat_kwargs: Dict[str, Any] = {
            "model": model,
            "messages": messages,
            "stream": stream,
        }
        if temperature is not None:
            chat_kwargs["temperature"] = temperature
        if top_p is not None:
            chat_kwargs["top_p"] = top_p
        if normalized_tools is not None:
            chat_kwargs["tools"] = normalized_tools
        if normalized_tool_choice is not None:
            chat_kwargs["tool_choice"] = normalized_tool_choice
        if parallel_tool_calls is not None:
            chat_kwargs["parallel_tool_calls"] = parallel_tool_calls
        if reasoning_effort is not None:
            chat_kwargs["reasoning_effort"] = reasoning_effort

        result = await ChatService.completions(**chat_kwargs)

        if not stream:
            # Non-stream path: lift message content + tool calls straight into
            # a completed response object.
            if not isinstance(result, dict):
                raise ValueError("Unexpected stream response for non-stream request")
            choice = (result.get("choices") or [{}])[0]
            message = choice.get("message") or {}
            content = message.get("content") or ""
            tool_calls = message.get("tool_calls")
            return _build_response_object(
                model=model,
                output_text=content,
                tool_calls=tool_calls,
                usage=result.get("usage")
                or {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0},
                status="completed",
                instructions=instructions,
                max_output_tokens=max_output_tokens,
                parallel_tool_calls=parallel_tool_calls,
                previous_response_id=previous_response_id,
                reasoning_effort=reasoning_effort,
                store=store,
                temperature=temperature,
                tool_choice=tool_choice,
                tools=tools,
                top_p=top_p,
                truncation=truncation,
                user=user,
                metadata=metadata,
            )

        if not hasattr(result, "__aiter__"):
            raise ValueError("Unexpected non-stream response for stream request")

        created_at = _now_ts()
        response_id = _new_response_id()
        adapter = ResponseStreamAdapter(
            model=model,
            response_id=response_id,
            created_at=created_at,
            instructions=instructions,
            max_output_tokens=max_output_tokens,
            parallel_tool_calls=parallel_tool_calls,
            previous_response_id=previous_response_id,
            reasoning_effort=reasoning_effort,
            store=store,
            temperature=temperature,
            tool_choice=tool_choice,
            tools=tools,
            top_p=top_p,
            truncation=truncation,
            user=user,
            metadata=metadata,
        )

        async def _stream() -> AsyncGenerator[str, None]:
            """Re-emit the chat-completions stream as Responses-API SSE events."""
            yield adapter.created_event()
            yield adapter.in_progress_event()
            async for chunk in result:
                # proc_base is presumably the shared stream-processing helpers
                # module imported above (out of view) — TODO confirm.
                line = proc_base._normalize_line(chunk)
                if not line:
                    continue
                try:
                    data = orjson.loads(line)
                except orjson.JSONDecodeError:
                    # Skip keep-alives / partial frames that are not JSON.
                    continue

                if data.get("object") == "chat.completion.chunk":
                    delta = (data.get("choices") or [{}])[0].get("delta") or {}
                    if "content" in delta and delta["content"]:
                        # Open the message item lazily on the first text delta.
                        for event in adapter.ensure_message_started():
                            yield event
                        adapter.output_text_parts.append(delta["content"])
                        yield adapter.output_delta_event(delta["content"])
                    tool_calls = delta.get("tool_calls")
                    if isinstance(tool_calls, list):
                        for tool in tool_calls:
                            if not isinstance(tool, dict):
                                continue
                            tool_index = tool.get("index", 0)
                            call_id = tool.get("id") or _new_tool_call_id()
                            fn = tool.get("function") or {}
                            name = fn.get("name")
                            args_delta = fn.get("arguments") or ""
                            # Record for the final payload AND emit live events.
                            adapter.record_tool_call(
                                tool_index, call_id, name, args_delta
                            )
                            for event in adapter.ensure_tool_item(
                                tool_index, call_id, name
                            ):
                                yield event
                            delta_event = adapter.tool_arguments_delta_event(
                                tool_index, args_delta
                            )
                            if delta_event:
                                yield delta_event

            # Upstream exhausted: close message and tool items, then complete.
            full_text = "".join(adapter.output_text_parts)
            if full_text and adapter.message_started:
                for event in adapter.output_done_events(full_text):
                    yield event
            for event in adapter.tool_arguments_done_events():
                yield event
            yield adapter.completed_event()

        return _stream()
822
+
823
+
824
# Explicit public API: only the service facade is exported.
__all__ = ["ResponsesService"]
app/services/grok/services/video.py ADDED
@@ -0,0 +1,688 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Grok video generation service.
3
+ """
4
+
5
+ import asyncio
6
+ import uuid
7
+ import re
8
+ from typing import Any, AsyncGenerator, AsyncIterable, Optional
9
+
10
+ import orjson
11
+ from curl_cffi.requests.errors import RequestsError
12
+
13
+ from app.core.logger import logger
14
+ from app.core.config import get_config
15
+ from app.core.exceptions import (
16
+ UpstreamException,
17
+ AppException,
18
+ ValidationException,
19
+ ErrorType,
20
+ StreamIdleTimeoutError,
21
+ )
22
+ from app.services.grok.services.model import ModelService
23
+ from app.services.token import get_token_manager, EffortType
24
+ from app.services.grok.utils.stream import wrap_stream_with_usage
25
+ from app.services.grok.utils.process import (
26
+ BaseProcessor,
27
+ _with_idle_timeout,
28
+ _normalize_line,
29
+ _is_http2_error,
30
+ )
31
+ from app.services.grok.utils.retry import rate_limited
32
+ from app.services.reverse.app_chat import AppChatReverse
33
+ from app.services.reverse.media_post import MediaPostReverse
34
+ from app.services.reverse.video_upscale import VideoUpscaleReverse
35
+ from app.services.reverse.utils.session import ResettableSession
36
+ from app.services.token.manager import BASIC_POOL_NAME
37
+
38
# Lazily-built concurrency limiter shared by all video reverse calls; rebuilt
# whenever the configured limit changes.
_VIDEO_SEMAPHORE = None
_VIDEO_SEM_VALUE = 0

def _get_video_semaphore() -> asyncio.Semaphore:
    """Return the reverse-endpoint concurrency semaphore for the video service."""
    global _VIDEO_SEMAPHORE, _VIDEO_SEM_VALUE
    configured = max(1, int(get_config("video.concurrent")))
    if configured != _VIDEO_SEM_VALUE:
        # Config changed (or first call): replace the semaphore.
        _VIDEO_SEM_VALUE = configured
        _VIDEO_SEMAPHORE = asyncio.Semaphore(configured)
    return _VIDEO_SEMAPHORE
49
+
50
+
51
def _new_session() -> ResettableSession:
    """Create a fresh session, impersonating the configured browser when set."""
    impersonation = get_config("proxy.browser")
    if not impersonation:
        return ResettableSession()
    return ResettableSession(impersonate=impersonation)
56
+
57
+
58
class VideoService:
    """Video generation service built on the reverse-engineered Grok endpoints.

    Text-to-video and image-to-video share one request pipeline; the only
    difference is which media post anchors the generation.
    """

    # Preset -> upstream mode flag. Unknown presets fall back to
    # "--mode=custom" (same behavior as the original per-method dicts).
    _MODE_FLAGS = {
        "fun": "--mode=extremely-crazy",
        "normal": "--mode=normal",
        "spicy": "--mode=extremely-spicy-or-crazy",
    }

    def __init__(self):
        # Reserved for a future per-request timeout; currently unused.
        self.timeout = None

    async def create_post(
        self,
        token: str,
        prompt: str,
        media_type: str = "MEDIA_POST_TYPE_VIDEO",
        media_url: Optional[str] = None,
    ) -> str:
        """Create a media post and return its post ID.

        Args:
            token: Grok SSO token (without the "sso=" prefix).
            prompt: Generation prompt; only forwarded for video posts.
            media_type: "MEDIA_POST_TYPE_VIDEO" or "MEDIA_POST_TYPE_IMAGE".
            media_url: Source media URL; required for image posts.

        Raises:
            ValidationException: image post requested without media_url.
            UpstreamException: upstream failure or missing post ID.
        """
        try:
            if media_type == "MEDIA_POST_TYPE_IMAGE" and not media_url:
                raise ValidationException("media_url is required for image posts")

            prompt_value = prompt if media_type == "MEDIA_POST_TYPE_VIDEO" else ""
            media_value = media_url or ""

            async with _new_session() as session:
                async with _get_video_semaphore():
                    response = await MediaPostReverse.request(
                        session,
                        token,
                        media_type,
                        media_value,
                        prompt=prompt_value,
                    )

            post_id = response.json().get("post", {}).get("id", "")
            if not post_id:
                raise UpstreamException("No post ID in response")

            logger.info(f"Media post created: {post_id} (type={media_type})")
            return post_id

        except AppException:
            raise
        except Exception as e:
            logger.error(f"Create post error: {e}")
            raise UpstreamException(f"Create post error: {str(e)}")

    async def create_image_post(self, token: str, image_url: str) -> str:
        """Create an image post (for image-to-video) and return its post ID."""
        return await self.create_post(
            token, prompt="", media_type="MEDIA_POST_TYPE_IMAGE", media_url=image_url
        )

    @classmethod
    def _mode_message(cls, prompt: str, preset: str) -> str:
        """Append the upstream mode flag for *preset* to the prompt."""
        return f"{prompt} {cls._MODE_FLAGS.get(preset, '--mode=custom')}"

    @staticmethod
    def _video_model_config(
        post_id: str, aspect_ratio: str, resolution_name: str, video_length: int
    ) -> dict:
        """Build the model_config_override payload for one videoGen request."""
        return {
            "modelMap": {
                "videoGenModelConfig": {
                    "aspectRatio": aspect_ratio,
                    "parentPostId": post_id,
                    "resolutionName": resolution_name,
                    "videoLength": video_length,
                }
            }
        }

    def _generation_stream(
        self,
        token: str,
        message: str,
        model_config_override: dict,
        post_id: str,
    ) -> AsyncGenerator[bytes, None]:
        """Start a videoGen chat request and stream its raw response lines.

        Shared by text-to-video and image-to-video (the original duplicated
        this closure in both methods). The session is now always closed in a
        finally block — the original leaked it when the stream completed
        without raising.
        """

        async def _stream():
            session = _new_session()
            try:
                # Hold the semaphore for the whole stream, as before.
                async with _get_video_semaphore():
                    stream_response = await AppChatReverse.request(
                        session,
                        token,
                        message=message,
                        model="grok-3",
                        tool_overrides={"videoGen": True},
                        model_config_override=model_config_override,
                    )
                    logger.info(f"Video generation started: post_id={post_id}")
                    async for line in stream_response:
                        yield line
            except Exception as e:
                logger.error(f"Video generation error: {e}")
                if isinstance(e, AppException):
                    raise
                raise UpstreamException(f"Video generation error: {str(e)}")
            finally:
                try:
                    await session.close()
                except Exception:
                    pass

        return _stream()

    async def generate(
        self,
        token: str,
        prompt: str,
        aspect_ratio: str = "3:2",
        video_length: int = 6,
        resolution_name: str = "480p",
        preset: str = "normal",
    ) -> AsyncGenerator[bytes, None]:
        """Generate a video from a text prompt; returns the raw line stream."""
        logger.info(
            f"Video generation: prompt='{prompt[:50]}...', ratio={aspect_ratio}, length={video_length}s, preset={preset}"
        )
        post_id = await self.create_post(token, prompt)
        message = self._mode_message(prompt, preset)
        model_config_override = self._video_model_config(
            post_id, aspect_ratio, resolution_name, video_length
        )
        return self._generation_stream(token, message, model_config_override, post_id)

    async def generate_from_image(
        self,
        token: str,
        prompt: str,
        image_url: str,
        aspect_ratio: str = "3:2",
        video_length: int = 6,
        resolution: str = "480p",
        preset: str = "normal",
    ) -> AsyncGenerator[bytes, None]:
        """Generate a video anchored on an uploaded image; returns the raw line stream."""
        logger.info(
            f"Image to video: prompt='{prompt[:50]}...', image={image_url[:80]}"
        )
        post_id = await self.create_image_post(token, image_url)
        message = self._mode_message(prompt, preset)
        model_config_override = self._video_model_config(
            post_id, aspect_ratio, resolution, video_length
        )
        return self._generation_stream(token, message, model_config_override, post_id)

    @staticmethod
    async def completions(
        model: str,
        messages: list,
        stream: Optional[bool] = None,
        reasoning_effort: str | None = None,
        aspect_ratio: str = "3:2",
        video_length: int = 6,
        resolution: str = "480p",
        preset: str = "normal",
    ):
        """Video generation entrypoint with token selection and 429 retry.

        Picks a token suited to the requested resolution/length, runs the
        generation, and returns either an SSE stream (stream=True) or a
        collected chat-completion dict.

        Raises:
            AppException: no token available (or only rate-limited failures).
            UpstreamException: non-rate-limit upstream failure.
        """
        # Get token via intelligent routing.
        token_mgr = await get_token_manager()
        await token_mgr.reload_if_stale()

        max_token_retries = int(get_config("retry.max_retry"))
        last_error: Exception | None = None

        # Reasoning display: explicit effort overrides the app-level default.
        if reasoning_effort is None:
            show_think = get_config("app.thinking")
        else:
            show_think = reasoning_effort != "none"
        is_stream = stream if stream is not None else get_config("app.stream")

        # Extract content (local import avoids a circular dependency with chat).
        from app.services.grok.services.chat import MessageExtractor
        from app.services.grok.utils.upload import UploadService

        prompt, file_attachments, image_attachments = MessageExtractor.extract(messages)

        # Pool candidates depend only on the model, so compute them once.
        pool_candidates = ModelService.pool_candidates_for_model(model)

        for attempt in range(max_token_retries):
            # Select token based on video requirements and pool candidates.
            token_info = token_mgr.get_token_for_video(
                resolution=resolution,
                video_length=video_length,
                pool_candidates=pool_candidates,
            )

            if not token_info:
                if last_error:
                    raise last_error
                raise AppException(
                    message="No available tokens. Please try again later.",
                    error_type=ErrorType.RATE_LIMIT.value,
                    code="rate_limit_exceeded",
                    status_code=429,
                )

            # Extract token string from TokenInfo.
            token = token_info.token
            if token.startswith("sso="):
                token = token[4:]
            pool_name = token_mgr.get_pool_name_for_token(token)
            # Basic-pool tokens cannot generate 720p directly; upscale afterwards.
            should_upscale = resolution == "720p" and pool_name == BASIC_POOL_NAME

            try:
                # Handle image attachments (single reference image supported).
                image_url = None
                if image_attachments:
                    upload_service = UploadService()
                    try:
                        if len(image_attachments) > 1:
                            logger.info(
                                "Video generation supports a single reference image; using the first one."
                            )
                        attach_data = image_attachments[0]
                        _, file_uri = await upload_service.upload_file(
                            attach_data, token
                        )
                        image_url = f"https://assets.grok.com/{file_uri}"
                        logger.info(f"Image uploaded for video: {image_url}")
                    finally:
                        await upload_service.close()

                # Generate video.
                service = VideoService()
                if image_url:
                    response = await service.generate_from_image(
                        token,
                        prompt,
                        image_url,
                        aspect_ratio,
                        video_length,
                        resolution,
                        preset,
                    )
                else:
                    response = await service.generate(
                        token,
                        prompt,
                        aspect_ratio,
                        video_length,
                        resolution,
                        preset,
                    )

                # Process response.
                if is_stream:
                    processor = VideoStreamProcessor(
                        model,
                        token,
                        show_think,
                        upscale_on_finish=should_upscale,
                    )
                    # Usage accounting is handled by the stream wrapper.
                    return wrap_stream_with_usage(
                        processor.process(response), token_mgr, token, model
                    )

                result = await VideoCollectProcessor(
                    model, token, upscale_on_finish=should_upscale
                ).process(response)
                try:
                    model_info = ModelService.get(model)
                    effort = (
                        EffortType.HIGH
                        if (model_info and model_info.cost.value == "high")
                        else EffortType.LOW
                    )
                    await token_mgr.consume(token, effort)
                    logger.debug(
                        f"Video completed, recorded usage (effort={effort.value})"
                    )
                except Exception as e:
                    # Usage tracking is best-effort; never fail the request for it.
                    logger.warning(f"Failed to record video usage: {e}")
                return result

            except UpstreamException as e:
                last_error = e
                if rate_limited(e):
                    # 429 from upstream: park this token and try the next one.
                    await token_mgr.mark_rate_limited(token)
                    logger.warning(
                        f"Token {token[:10]}... rate limited (429), "
                        f"trying next token (attempt {attempt + 1}/{max_token_retries})"
                    )
                    continue
                raise

        if last_error:
            raise last_error
        raise AppException(
            message="No available tokens. Please try again later.",
            error_type=ErrorType.RATE_LIMIT.value,
            code="rate_limit_exceeded",
            status_code=429,
        )
374
+
375
+
376
class VideoStreamProcessor(BaseProcessor):
    """Video stream response processor.

    Converts the upstream videoGen line stream into OpenAI-style chat
    completion SSE chunks, wrapping reasoning tokens in <think> tags and
    rendering the final video URL (optionally HD-upscaled first).
    """

    def __init__(
        self,
        model: str,
        token: str = "",
        show_think: Optional[bool] = None,
        upscale_on_finish: bool = False,
    ):
        super().__init__(model, token)
        # Upstream responseId, echoed as the chunk id once known.
        self.response_id: Optional[str] = None
        # True while a <think> tag is open in the emitted text.
        self.think_opened: bool = False
        # The first chunk must carry the assistant role exactly once.
        self.role_sent: bool = False

        # Emit reasoning/progress text when True.
        self.show_think = bool(show_think)
        # Request an HD upscale once generation reaches 100%.
        self.upscale_on_finish = bool(upscale_on_finish)

    @staticmethod
    def _extract_video_id(video_url: str) -> str:
        """Extract the generation UUID from an asset URL; "" when not found.

        Two URL layouts are recognized: .../generated/<id>/... and
        .../<id>/generated_video...
        """
        if not video_url:
            return ""
        match = re.search(r"/generated/([0-9a-fA-F-]{32,36})/", video_url)
        if match:
            return match.group(1)
        match = re.search(r"/([0-9a-fA-F-]{32,36})/generated_video", video_url)
        if match:
            return match.group(1)
        return ""

    async def _upscale_video_url(self, video_url: str) -> str:
        """Best-effort HD upscale; returns the original URL on any failure."""
        if not video_url or not self.upscale_on_finish:
            return video_url
        video_id = self._extract_video_id(video_url)
        if not video_id:
            logger.warning("Video upscale skipped: unable to extract video id")
            return video_url
        try:
            async with _new_session() as session:
                response = await VideoUpscaleReverse.request(
                    session, self.token, video_id
                )
                payload = response.json() if response is not None else {}
                hd_url = payload.get("hdMediaUrl") if isinstance(payload, dict) else None
                if hd_url:
                    logger.info(f"Video upscale completed: {hd_url}")
                    return hd_url
        except Exception as e:
            # Upscale is optional; never fail the stream because of it.
            logger.warning(f"Video upscale failed: {e}")
        return video_url

    def _sse(self, content: str = "", role: Optional[str] = None, finish: Optional[str] = None) -> str:
        """Build one chat.completion.chunk SSE frame.

        Passing role emits the role announcement with empty content; passing
        finish emits the final chunk with an empty delta.
        """
        delta = {}
        if role:
            delta["role"] = role
            delta["content"] = ""
        elif content:
            delta["content"] = content

        chunk = {
            # Fall back to a random id until the upstream responseId arrives.
            "id": self.response_id or f"chatcmpl-{uuid.uuid4().hex[:24]}",
            "object": "chat.completion.chunk",
            "created": self.created,
            "model": self.model,
            "choices": [
                {"index": 0, "delta": delta, "logprobs": None, "finish_reason": finish}
            ],
        }
        return f"data: {orjson.dumps(chunk).decode()}\n\n"

    async def process(
        self, response: AsyncIterable[bytes]
    ) -> AsyncGenerator[str, None]:
        """Process video stream response.

        Yields SSE chunk strings. Progress updates and reasoning tokens are
        emitted inside <think> tags when show_think is on; the finished video
        is rendered (after optional upscale) as regular content.
        """
        idle_timeout = get_config("video.stream_timeout")

        try:
            async for line in _with_idle_timeout(response, idle_timeout, self.model):
                line = _normalize_line(line)
                if not line:
                    continue
                try:
                    data = orjson.loads(line)
                except orjson.JSONDecodeError:
                    # Skip keep-alives / non-JSON frames.
                    continue

                resp = data.get("result", {}).get("response", {})
                is_thinking = bool(resp.get("isThinking"))

                if rid := resp.get("responseId"):
                    self.response_id = rid

                if not self.role_sent:
                    yield self._sse(role="assistant")
                    self.role_sent = True

                # Plain text token from the model.
                if token := resp.get("token"):
                    if is_thinking:
                        if not self.show_think:
                            continue
                        if not self.think_opened:
                            yield self._sse("<think>\n")
                            self.think_opened = True
                    else:
                        if self.think_opened:
                            yield self._sse("\n</think>\n")
                            self.think_opened = False
                    yield self._sse(token)
                    continue

                # Video generation progress / completion frame.
                if video_resp := resp.get("streamingVideoGenerationResponse"):
                    progress = video_resp.get("progress", 0)

                    if is_thinking:
                        if not self.show_think:
                            continue
                        if not self.think_opened:
                            yield self._sse("<think>\n")
                            self.think_opened = True
                    else:
                        if self.think_opened:
                            yield self._sse("\n</think>\n")
                            self.think_opened = False
                    if self.show_think:
                        # Progress message (Chinese): "Generating video, progress N%".
                        yield self._sse(f"正在生成视频中,当前进度{progress}%\n")

                    if progress == 100:
                        video_url = video_resp.get("videoUrl", "")
                        thumbnail_url = video_resp.get("thumbnailImageUrl", "")

                        if self.think_opened:
                            yield self._sse("\n</think>\n")
                            self.think_opened = False

                        if video_url:
                            if self.upscale_on_finish:
                                # Status message (Chinese): "Upscaling video resolution".
                                yield self._sse("正在对视频进行超分辨率\n")
                                video_url = await self._upscale_video_url(video_url)
                            dl_service = self._get_dl()
                            rendered = await dl_service.render_video(
                                video_url, self.token, thumbnail_url
                            )
                            yield self._sse(rendered)

                        logger.info(f"Video generated: {video_url}")
                    continue

            # Stream ended: close any dangling think tag, then finish.
            if self.think_opened:
                yield self._sse("</think>\n")
            yield self._sse(finish="stop")
            yield "data: [DONE]\n\n"
        except asyncio.CancelledError:
            logger.debug(
                "Video stream cancelled by client", extra={"model": self.model}
            )
        except StreamIdleTimeoutError as e:
            raise UpstreamException(
                message=f"Video stream idle timeout after {e.idle_seconds}s",
                status_code=504,
                details={
                    "error": str(e),
                    "type": "stream_idle_timeout",
                    "idle_seconds": e.idle_seconds,
                },
            )
        except RequestsError as e:
            if _is_http2_error(e):
                logger.warning(
                    f"HTTP/2 stream error in video: {e}", extra={"model": self.model}
                )
                raise UpstreamException(
                    message="Upstream connection closed unexpectedly",
                    status_code=502,
                    details={"error": str(e), "type": "http2_stream_error"},
                )
            logger.error(
                f"Video stream request error: {e}", extra={"model": self.model}
            )
            raise UpstreamException(
                message=f"Upstream request failed: {e}",
                status_code=502,
                details={"error": str(e)},
            )
        except Exception as e:
            logger.error(
                f"Video stream processing error: {e}",
                extra={"model": self.model, "error_type": type(e).__name__},
            )
        finally:
            await self.close()
567
+
568
+
569
class VideoCollectProcessor(BaseProcessor):
    """Video non-stream response processor.

    Drains the upstream video-generation stream to completion and returns a
    single OpenAI-style ``chat.completion`` dict once progress reaches 100%.
    """

    def __init__(self, model: str, token: str = "", upscale_on_finish: bool = False):
        super().__init__(model, token)
        # Whether to request an HD (upscaled) variant once generation finishes.
        self.upscale_on_finish = bool(upscale_on_finish)

    @staticmethod
    def _extract_video_id(video_url: str) -> str:
        """Extract the generated-video UUID from an assets URL; "" when absent."""
        if not video_url:
            return ""
        # Two known URL layouts: .../generated/<id>/... and .../<id>/generated_video...
        match = re.search(r"/generated/([0-9a-fA-F-]{32,36})/", video_url)
        if match:
            return match.group(1)
        match = re.search(r"/([0-9a-fA-F-]{32,36})/generated_video", video_url)
        if match:
            return match.group(1)
        return ""

    async def _upscale_video_url(self, video_url: str) -> str:
        """Best-effort upscale: return the HD URL, or the original URL on any failure."""
        if not video_url or not self.upscale_on_finish:
            return video_url
        video_id = self._extract_video_id(video_url)
        if not video_id:
            logger.warning("Video upscale skipped: unable to extract video id")
            return video_url
        try:
            async with _new_session() as session:
                response = await VideoUpscaleReverse.request(
                    session, self.token, video_id
                )
                payload = response.json() if response is not None else {}
                hd_url = payload.get("hdMediaUrl") if isinstance(payload, dict) else None
                if hd_url:
                    logger.info(f"Video upscale completed: {hd_url}")
                    return hd_url
        except Exception as e:
            # Upscaling is optional; never fail the request because of it.
            logger.warning(f"Video upscale failed: {e}")
        return video_url

    async def process(self, response: AsyncIterable[bytes]) -> dict[str, Any]:
        """Process and collect video response.

        Consumes the upstream byte stream (one JSON object per line), waits
        for ``progress == 100``, optionally upscales, then renders the final
        video into chat content. All errors are logged; on failure the
        returned payload simply carries empty content.
        """
        response_id = ""
        content = ""
        idle_timeout = get_config("video.stream_timeout")

        try:
            async for line in _with_idle_timeout(response, idle_timeout, self.model):
                line = _normalize_line(line)
                if not line:
                    continue
                try:
                    data = orjson.loads(line)
                except orjson.JSONDecodeError:
                    # Ignore keep-alive / non-JSON lines.
                    continue

                resp = data.get("result", {}).get("response", {})

                if video_resp := resp.get("streamingVideoGenerationResponse"):
                    if video_resp.get("progress") == 100:
                        response_id = resp.get("responseId", "")
                        video_url = video_resp.get("videoUrl", "")
                        thumbnail_url = video_resp.get("thumbnailImageUrl", "")

                        if video_url:
                            if self.upscale_on_finish:
                                video_url = await self._upscale_video_url(video_url)
                            dl_service = self._get_dl()
                            content = await dl_service.render_video(
                                video_url, self.token, thumbnail_url
                            )
                            logger.info(f"Video generated: {video_url}")

        except asyncio.CancelledError:
            logger.debug(
                "Video collect cancelled by client", extra={"model": self.model}
            )
        except StreamIdleTimeoutError as e:
            logger.warning(
                f"Video collect idle timeout: {e}", extra={"model": self.model}
            )
        except RequestsError as e:
            if _is_http2_error(e):
                logger.warning(
                    f"HTTP/2 stream error in video collect: {e}",
                    extra={"model": self.model},
                )
            else:
                logger.error(
                    f"Video collect request error: {e}", extra={"model": self.model}
                )
        except Exception as e:
            logger.error(
                f"Video collect processing error: {e}",
                extra={"model": self.model, "error_type": type(e).__name__},
            )
        finally:
            # Always release the shared download session.
            await self.close()

        return {
            "id": response_id,
            "object": "chat.completion",
            "created": self.created,
            "model": self.model,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": content,
                        "refusal": None,
                    },
                    "finish_reason": "stop",
                }
            ],
            # Upstream does not report token usage for video generation.
            "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
        }
686
+
687
+
688
# Public API of this module.
__all__ = ["VideoService"]
app/services/grok/services/voice.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Grok Voice Mode Service
3
+ """
4
+
5
+ from typing import Any, Dict
6
+
7
+ from app.core.config import get_config
8
+ from app.services.reverse.ws_livekit import LivekitTokenReverse
9
+ from app.services.reverse.utils.session import ResettableSession
10
+
11
+
12
class VoiceService:
    """Voice Mode Service (LiveKit).

    Requests LiveKit access tokens from the upstream reverse endpoint for
    Grok voice-mode sessions.
    """

    async def get_token(
        self,
        token: str,
        voice: str = "ara",
        personality: str = "assistant",
        speed: float = 1.0,
    ) -> Dict[str, Any]:
        """Fetch a LiveKit token for the given SSO token and voice settings."""
        impersonation = get_config("proxy.browser")
        async with ResettableSession(impersonate=impersonation) as session:
            upstream = await LivekitTokenReverse.request(
                session,
                token=token,
                voice=voice,
                personality=personality,
                speed=speed,
            )
            return upstream.json()
app/services/grok/utils/cache.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Local cache utilities.
3
+ """
4
+
5
+ from typing import Any, Dict
6
+
7
+ from app.core.storage import DATA_DIR
8
+
9
# Extensions counted as cached media when computing stats / listings.
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"}
VIDEO_EXTS = {".mp4", ".mov", ".m4v", ".webm", ".avi", ".mkv"}
11
+
12
+
13
class CacheService:
    """Local cache service.

    Manages the on-disk media cache under ``DATA_DIR/tmp`` with separate
    sub-directories for images and videos: stats, listing with pagination,
    single-file deletion and bulk clearing.
    """

    def __init__(self):
        base_dir = DATA_DIR / "tmp"
        self.image_dir = base_dir / "image"
        self.video_dir = base_dir / "video"
        # Create both directories eagerly so later operations can assume them.
        self.image_dir.mkdir(parents=True, exist_ok=True)
        self.video_dir.mkdir(parents=True, exist_ok=True)

    def _cache_dir(self, media_type: str):
        """Return the cache directory for *media_type* ("image"; anything else -> video)."""
        return self.image_dir if media_type == "image" else self.video_dir

    def _allowed_exts(self, media_type: str):
        """Return the extensions treated as media files for *media_type*."""
        return IMAGE_EXTS if media_type == "image" else VIDEO_EXTS

    def get_stats(self, media_type: str = "image") -> Dict[str, Any]:
        """Return ``{"count", "size_mb"}`` for cached media of *media_type*."""
        cache_dir = self._cache_dir(media_type)
        if not cache_dir.exists():
            return {"count": 0, "size_mb": 0.0}

        allowed = self._allowed_exts(media_type)
        files = [
            f for f in cache_dir.glob("*") if f.is_file() and f.suffix.lower() in allowed
        ]
        total_size = sum(f.stat().st_size for f in files)
        return {"count": len(files), "size_mb": round(total_size / 1024 / 1024, 2)}

    def list_files(
        self, media_type: str = "image", page: int = 1, page_size: int = 1000
    ) -> Dict[str, Any]:
        """List cached files (newest first) with simple pagination.

        Returns ``{"total", "page", "page_size", "items"}``; each item carries
        name, size_bytes, mtime_ms and a ``view_url`` served by ``/v1/files``.
        """
        cache_dir = self._cache_dir(media_type)
        if not cache_dir.exists():
            return {"total": 0, "page": page, "page_size": page_size, "items": []}

        allowed = self._allowed_exts(media_type)
        files = [
            f for f in cache_dir.glob("*") if f.is_file() and f.suffix.lower() in allowed
        ]

        items = []
        for f in files:
            try:
                stat = f.stat()
                items.append(
                    {
                        "name": f.name,
                        "size_bytes": stat.st_size,
                        "mtime_ms": int(stat.st_mtime * 1000),
                    }
                )
            except Exception:
                # File may vanish between glob and stat; skip it.
                continue

        items.sort(key=lambda x: x["mtime_ms"], reverse=True)

        total = len(items)
        start = max(0, (page - 1) * page_size)
        paged = items[start : start + page_size]

        for item in paged:
            item["view_url"] = f"/v1/files/{media_type}/{item['name']}"

        return {"total": total, "page": page, "page_size": page_size, "items": paged}

    def delete_file(self, media_type: str, name: str) -> Dict[str, Any]:
        """Delete one cached file by name; returns ``{"deleted": bool}``.

        The name is sanitized so API callers cannot traverse outside the
        cache directory.
        """
        cache_dir = self._cache_dir(media_type)
        # Neutralize both separator styles ("/" and "\") so crafted names such
        # as "..\\..\\x" cannot escape cache_dir on any platform. The original
        # code only replaced "/", leaving Windows-style traversal possible.
        safe_name = name.replace("/", "-").replace("\\", "-")
        file_path = cache_dir / safe_name

        if file_path.exists():
            try:
                file_path.unlink()
                return {"deleted": True}
            except Exception:
                # Deletion races / permission errors fall through to False.
                pass
        return {"deleted": False}

    def clear(self, media_type: str = "image") -> Dict[str, Any]:
        """Delete every file in *media_type*'s cache dir; return count/size removed."""
        cache_dir = self._cache_dir(media_type)
        if not cache_dir.exists():
            return {"count": 0, "size_mb": 0.0}

        files = list(cache_dir.glob("*"))
        total_size = sum(f.stat().st_size for f in files if f.is_file())
        count = 0

        for f in files:
            if f.is_file():
                try:
                    f.unlink()
                    count += 1
                except Exception:
                    pass

        return {"count": count, "size_mb": round(total_size / 1024 / 1024, 2)}
108
+
109
+
110
# Public API of this module.
__all__ = ["CacheService"]
app/services/grok/utils/download.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Download service.
3
+
4
+ Download service for assets.grok.com.
5
+ """
6
+
7
+ import asyncio
8
+ import base64
9
+ import hashlib
10
+ import os
11
+ from pathlib import Path
12
+ from typing import List, Optional, Tuple
13
+ from urllib.parse import urlparse
14
+
15
+ import aiofiles
16
+
17
+ from app.core.logger import logger
18
+ from app.core.storage import DATA_DIR
19
+ from app.core.config import get_config
20
+ from app.core.exceptions import AppException
21
+ from app.services.reverse.assets_download import AssetsDownloadReverse
22
+ from app.services.reverse.utils.session import ResettableSession
23
+ from app.services.grok.utils.locks import _get_download_semaphore, _file_lock
24
+
25
+
26
class DownloadService:
    """Assets download service.

    Downloads files from assets.grok.com into a local on-disk cache
    (``DATA_DIR/tmp/{image,video}``), resolves asset paths to externally
    served URLs, and renders image/video markup for chat responses.
    A single upstream session is lazily created and reused.
    """

    def __init__(self):
        self._session: Optional[ResettableSession] = None
        base_dir = DATA_DIR / "tmp"
        self.image_dir = base_dir / "image"
        self.video_dir = base_dir / "video"
        self.image_dir.mkdir(parents=True, exist_ok=True)
        self.video_dir.mkdir(parents=True, exist_ok=True)
        # Guards against scheduling overlapping cache-cleanup tasks.
        self._cleanup_running = False

    async def create(self) -> ResettableSession:
        """Create or reuse a session."""
        if self._session is None:
            browser = get_config("proxy.browser")
            if browser:
                self._session = ResettableSession(impersonate=browser)
            else:
                self._session = ResettableSession()
        return self._session

    async def close(self):
        """Close the session."""
        if self._session:
            await self._session.close()
            self._session = None

    async def resolve_url(
        self, path_or_url: str, token: str, media_type: str = "image"
    ) -> str:
        """Resolve an asset path or absolute URL to the URL to expose to clients.

        When ``app.app_url`` is configured, the asset is downloaded into the
        local cache and a self-hosted ``/v1/files/...`` URL is returned;
        otherwise the upstream assets URL is returned unchanged.
        """
        asset_url = path_or_url
        path = path_or_url
        if path_or_url.startswith("http"):
            parsed = urlparse(path_or_url)
            path = parsed.path or ""
            asset_url = path_or_url
        else:
            if not path_or_url.startswith("/"):
                path_or_url = f"/{path_or_url}"
            path = path_or_url
            asset_url = f"https://assets.grok.com{path_or_url}"

        app_url = get_config("app.app_url")
        if app_url:
            # Ensure the file exists locally before handing out the local URL.
            await self.download_file(asset_url, token, media_type)
            return f"{app_url.rstrip('/')}/v1/files/{media_type}{path}"
        return asset_url

    async def render_image(
        self, url: str, token: str, image_id: str = "image"
    ) -> str:
        """Render an image as markdown, honoring ``app.image_format``.

        Supported formats: "base64" (inline data URI), "url"/"markdown"
        (resolved URL). On any error, falls back to the resolved-URL form.
        """
        fmt = get_config("app.image_format")
        fmt = fmt.lower() if isinstance(fmt, str) else "url"
        if fmt not in ("base64", "url", "markdown"):
            fmt = "url"
        try:
            if fmt == "base64":
                data_uri = await self.parse_b64(url, token, "image")
                return f"![{image_id}]({data_uri})"
            final_url = await self.resolve_url(url, token, "image")
            return f"![{image_id}]({final_url})"
        except Exception as e:
            logger.warning(f"Image render failed, fallback to URL: {e}")
            final_url = await self.resolve_url(url, token, "image")
            return f"![{image_id}]({final_url})"

    async def render_video(
        self, video_url: str, token: str, thumbnail_url: str = ""
    ) -> str:
        """Render a video per ``app.video_format``: plain URL, markdown, or HTML."""
        fmt = get_config("app.video_format")
        fmt = fmt.lower() if isinstance(fmt, str) else "url"
        if fmt not in ("url", "markdown", "html"):
            fmt = "url"
        final_video_url = await self.resolve_url(video_url, token, "video")
        final_thumb_url = ""
        if thumbnail_url:
            final_thumb_url = await self.resolve_url(thumbnail_url, token, "image")
        if fmt == "url":
            return f"{final_video_url}\n"
        if fmt == "markdown":
            return f"[video]({final_video_url})"
        import html

        # HTML format: escape URLs to keep the generated tag well-formed.
        safe_video_url = html.escape(final_video_url)
        safe_thumbnail_url = html.escape(final_thumb_url)
        poster_attr = f' poster="{safe_thumbnail_url}"' if safe_thumbnail_url else ""
        return f'''<video id="video" controls="" preload="none"{poster_attr}>
<source id="mp4" src="{safe_video_url}" type="video/mp4">
</video>'''

    async def parse_b64(self, file_path: str, token: str, media_type: str = "image") -> str:
        """Download and return data URI.

        Raises:
            AppException: for invalid paths (empty or already a data URI).
            Any download error is logged and re-raised.
        """
        try:
            if not isinstance(file_path, str) or not file_path.strip():
                raise AppException("Invalid file path", code="invalid_file_path")
            if file_path.startswith("data:"):
                raise AppException("Invalid file path", code="invalid_file_path")
            file_path = self._normalize_path(file_path)
            # Per-path lock so concurrent requests for the same asset serialize.
            lock_name = f"dl_b64_{hashlib.sha1(file_path.encode()).hexdigest()[:16]}"
            lock_timeout = max(1, int(get_config("asset.download_timeout")))
            async with _get_download_semaphore():
                async with _file_lock(lock_name, timeout=lock_timeout):
                    session = await self.create()
                    response = await AssetsDownloadReverse.request(
                        session, token, file_path
                    )

                    # Prefer streaming when the response supports it.
                    if hasattr(response, "aiter_content"):
                        data = bytearray()
                        async for chunk in response.aiter_content():
                            if chunk:
                                data.extend(chunk)
                        raw = bytes(data)
                    else:
                        raw = response.content

                    content_type = response.headers.get(
                        "content-type", "application/octet-stream"
                    ).split(";")[0]
                    data_uri = f"data:{content_type};base64,{base64.b64encode(raw).decode()}"

                    return data_uri
        except Exception as e:
            logger.error(f"Failed to convert {file_path} to base64: {e}")
            raise

    def _normalize_path(self, file_path: str) -> str:
        """Normalize URL or path to assets path for download.

        Accepts absolute http(s) URLs (query string preserved) or bare paths;
        always returns a leading-slash path.

        Raises:
            AppException: for empty input, data URIs, or non-http(s) schemes.
        """
        if not isinstance(file_path, str) or not file_path.strip():
            raise AppException("Invalid file path", code="invalid_file_path")

        value = file_path.strip()
        if value.startswith("data:"):
            raise AppException("Invalid file path", code="invalid_file_path")

        parsed = urlparse(value)
        if parsed.scheme or parsed.netloc:
            if not (
                parsed.scheme and parsed.netloc and parsed.scheme in ["http", "https"]
            ):
                raise AppException("Invalid file path", code="invalid_file_path")
            path = parsed.path or ""
            if parsed.query:
                path = f"{path}?{parsed.query}"
        else:
            path = value

        if not path:
            raise AppException("Invalid file path", code="invalid_file_path")
        if not path.startswith("/"):
            path = f"/{path}"

        return path

    async def download_file(self, file_path: str, token: str, media_type: str = "image") -> Tuple[Optional[Path], str]:
        """Download asset to local cache.

        Args:
            file_path: str, the path of the file to download.
            token: str, the SSO token.
            media_type: str, the media type of the file.

        Returns:
            Tuple[Optional[Path], str]: The path of the downloaded file and the MIME type.
        """
        async with _get_download_semaphore():
            file_path = self._normalize_path(file_path)
            cache_dir = self.image_dir if media_type == "image" else self.video_dir
            # Flatten the asset path into a single cache filename.
            filename = file_path.lstrip("/").replace("/", "-")
            cache_path = cache_dir / filename

            lock_name = (
                f"dl_{media_type}_{hashlib.sha1(str(cache_path).encode()).hexdigest()[:16]}"
            )
            lock_timeout = max(1, int(get_config("asset.download_timeout")))
            async with _file_lock(lock_name, timeout=lock_timeout):
                session = await self.create()
                response = await AssetsDownloadReverse.request(session, token, file_path)

                # Write to a temp file then atomically rename, so readers never
                # observe a partially written cache entry.
                tmp_path = cache_path.with_suffix(cache_path.suffix + ".tmp")
                try:
                    async with aiofiles.open(tmp_path, "wb") as f:
                        if hasattr(response, "aiter_content"):
                            async for chunk in response.aiter_content():
                                if chunk:
                                    await f.write(chunk)
                        else:
                            await f.write(response.content)
                    os.replace(tmp_path, cache_path)
                finally:
                    # Remove leftover temp file if the rename never happened.
                    if tmp_path.exists() and not cache_path.exists():
                        try:
                            tmp_path.unlink()
                        except Exception:
                            pass

                mime = response.headers.get(
                    "content-type", "application/octet-stream"
                ).split(";")[0]
                logger.info(f"Downloaded: {file_path}")

                # NOTE(review): fire-and-forget — the task reference is not
                # retained, so it could in principle be garbage-collected
                # before running; confirm whether that matters here.
                asyncio.create_task(self._check_limit())

                return cache_path, mime

    async def _check_limit(self):
        """Check the cache size limit and evict oldest files when exceeded.

        No-op when a cleanup is already running or auto-clean is disabled.
        Evicts least-recently-modified files until the cache is at 80% of
        the configured ``cache.limit_mb``.
        """
        if self._cleanup_running or not get_config("cache.enable_auto_clean"):
            return

        self._cleanup_running = True
        try:
            try:
                # Cross-process lock: only one worker cleans at a time.
                async with _file_lock("cache_cleanup", timeout=5):
                    limit_mb = get_config("cache.limit_mb")
                    total_size = 0
                    all_files: List[Tuple[Path, float, int]] = []

                    for d in [self.image_dir, self.video_dir]:
                        if d.exists():
                            for f in d.glob("*"):
                                if f.is_file():
                                    try:
                                        stat = f.stat()
                                        total_size += stat.st_size
                                        all_files.append(
                                            (f, stat.st_mtime, stat.st_size)
                                        )
                                    except Exception:
                                        pass
                    current_mb = total_size / 1024 / 1024

                    if current_mb <= limit_mb:
                        return

                    logger.info(
                        f"Cache limit exceeded ({current_mb:.2f}MB > {limit_mb}MB), cleaning..."
                    )
                    # Oldest-modified first.
                    all_files.sort(key=lambda x: x[1])

                    deleted_count = 0
                    deleted_size = 0
                    target_mb = limit_mb * 0.8

                    for f, _, size in all_files:
                        try:
                            f.unlink()
                            deleted_count += 1
                            deleted_size += size
                            total_size -= size
                            if (total_size / 1024 / 1024) <= target_mb:
                                break
                        except Exception:
                            pass

                    logger.info(
                        f"Cache cleanup: {deleted_count} files ({deleted_size / 1024 / 1024:.2f}MB)"
                    )
            except Exception as e:
                logger.warning(f"Cache cleanup failed: {e}")
        finally:
            self._cleanup_running = False
296
+
297
+
298
# Public API of this module.
__all__ = ["DownloadService"]
app/services/grok/utils/locks.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Shared locking helpers for assets operations.
3
+ """
4
+
5
+ import asyncio
6
+ import time
7
+ from contextlib import asynccontextmanager
8
+ from pathlib import Path
9
+
10
+ from app.core.config import get_config
11
+ from app.core.storage import DATA_DIR
12
+
13
+ try:
14
+ import fcntl
15
+ except ImportError:
16
+ fcntl = None
17
+
18
+
19
# Directory where inter-process lock files are created.
LOCK_DIR = DATA_DIR / ".locks"

# Lazily built global semaphores, stored with the config value each was built
# for so that a changed concurrency limit rebuilds the semaphore.
_UPLOAD_SEMAPHORE = None
_UPLOAD_SEM_VALUE = None
_DOWNLOAD_SEMAPHORE = None
_DOWNLOAD_SEM_VALUE = None
25
+
26
+
27
+ def _get_upload_semaphore() -> asyncio.Semaphore:
28
+ """Return global semaphore for upload operations."""
29
+ value = max(1, int(get_config("asset.upload_concurrent")))
30
+
31
+ global _UPLOAD_SEMAPHORE, _UPLOAD_SEM_VALUE
32
+ if _UPLOAD_SEMAPHORE is None or value != _UPLOAD_SEM_VALUE:
33
+ _UPLOAD_SEM_VALUE = value
34
+ _UPLOAD_SEMAPHORE = asyncio.Semaphore(value)
35
+ return _UPLOAD_SEMAPHORE
36
+
37
+
38
def _get_download_semaphore() -> asyncio.Semaphore:
    """Return the process-wide semaphore bounding concurrent downloads.

    Created lazily; rebuilt whenever the configured limit changes.
    """
    global _DOWNLOAD_SEMAPHORE, _DOWNLOAD_SEM_VALUE

    limit = max(1, int(get_config("asset.download_concurrent")))
    if _DOWNLOAD_SEMAPHORE is None or _DOWNLOAD_SEM_VALUE != limit:
        _DOWNLOAD_SEM_VALUE = limit
        _DOWNLOAD_SEMAPHORE = asyncio.Semaphore(limit)
    return _DOWNLOAD_SEMAPHORE
47
+
48
+
49
@asynccontextmanager
async def _file_lock(name: str, timeout: int = 10):
    """File lock guard.

    Cross-process advisory lock based on ``fcntl.flock``. Polls with a
    non-blocking acquire every 50ms until the lock is obtained or *timeout*
    seconds elapse, then raises TimeoutError. On platforms without fcntl
    (e.g. Windows) the guard degrades to a no-op.
    """
    if fcntl is None:
        # No fcntl available: run the guarded section unlocked.
        yield
        return

    LOCK_DIR.mkdir(parents=True, exist_ok=True)
    lock_path = Path(LOCK_DIR) / f"{name}.lock"
    fd = None
    locked = False
    start = time.monotonic()

    try:
        fd = open(lock_path, "a+")
        while True:
            try:
                # LOCK_NB keeps the event loop responsive; we sleep between tries.
                fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
                locked = True
                break
            except BlockingIOError:
                if time.monotonic() - start >= timeout:
                    break
                await asyncio.sleep(0.05)
        if not locked:
            raise TimeoutError(f"Failed to acquire lock: {name}")
        yield
    finally:
        if fd:
            if locked:
                try:
                    fcntl.flock(fd, fcntl.LOCK_UN)
                except Exception:
                    pass
            fd.close()
84
+
85
+
86
# Public API of this module (intentionally exports the underscored helpers).
__all__ = ["_get_upload_semaphore", "_get_download_semaphore", "_file_lock"]
app/services/grok/utils/process.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 响应处理器基类和通用工具
3
+ """
4
+
5
+ import asyncio
6
+ import time
7
+ from typing import Any, AsyncGenerator, Optional, AsyncIterable, List, TypeVar
8
+
9
+ from app.core.config import get_config
10
+ from app.core.logger import logger
11
+ from app.core.exceptions import StreamIdleTimeoutError
12
+ from app.services.grok.utils.download import DownloadService
13
+
14
+
15
# Generic item type yielded by the wrapped async iterators.
T = TypeVar("T")
16
+
17
+
18
+ def _is_http2_error(e: Exception) -> bool:
19
+ """检查是否为 HTTP/2 流错误"""
20
+ err_str = str(e).lower()
21
+ return "http/2" in err_str or "curl: (92)" in err_str or "stream" in err_str
22
+
23
+
24
+ def _normalize_line(line: Any) -> Optional[str]:
25
+ """规范化流式响应行,兼容 SSE data 前缀与空行"""
26
+ if line is None:
27
+ return None
28
+ if isinstance(line, (bytes, bytearray)):
29
+ text = line.decode("utf-8", errors="ignore")
30
+ else:
31
+ text = str(line)
32
+ text = text.strip()
33
+ if not text:
34
+ return None
35
+ if text.startswith("data:"):
36
+ text = text[5:].strip()
37
+ if text == "[DONE]":
38
+ return None
39
+ return text
40
+
41
+
42
+ def _collect_images(obj: Any) -> List[str]:
43
+ """递归收集响应中的图片 URL"""
44
+ urls: List[str] = []
45
+ seen = set()
46
+
47
+ def add(url: str):
48
+ if not url or url in seen:
49
+ return
50
+ seen.add(url)
51
+ urls.append(url)
52
+
53
+ def walk(value: Any):
54
+ if isinstance(value, dict):
55
+ for key, item in value.items():
56
+ if key in {"generatedImageUrls", "imageUrls", "imageURLs"}:
57
+ if isinstance(item, list):
58
+ for url in item:
59
+ if isinstance(url, str):
60
+ add(url)
61
+ elif isinstance(item, str):
62
+ add(item)
63
+ continue
64
+ walk(item)
65
+ elif isinstance(value, list):
66
+ for item in value:
67
+ walk(item)
68
+
69
+ walk(obj)
70
+ return urls
71
+
72
+
73
async def _with_idle_timeout(
    iterable: AsyncIterable[T], idle_timeout: float, model: str = ""
) -> AsyncGenerator[T, None]:
    """
    Wrap an async iterator with idle-timeout supervision.

    Args:
        iterable: the original async iterable
        idle_timeout: idle timeout in seconds; 0 (or less) disables the check
        model: model name (for logging only)

    Raises:
        StreamIdleTimeoutError: when no item arrives within idle_timeout seconds
    """
    if idle_timeout <= 0:
        # Fast path: no supervision requested.
        async for item in iterable:
            yield item
        return

    iterator = iterable.__aiter__()

    async def _maybe_aclose(it):
        # Best-effort close of the underlying generator, if it supports aclose().
        aclose = getattr(it, "aclose", None)
        if not aclose:
            return
        try:
            await aclose()
        except Exception:
            pass

    while True:
        try:
            item = await asyncio.wait_for(iterator.__anext__(), timeout=idle_timeout)
            yield item
        except asyncio.TimeoutError:
            logger.warning(
                f"Stream idle timeout after {idle_timeout}s",
                extra={"model": model, "idle_timeout": idle_timeout},
            )
            await _maybe_aclose(iterator)
            raise StreamIdleTimeoutError(idle_timeout)
        except asyncio.CancelledError:
            # Propagate cancellation, but close the upstream iterator first.
            await _maybe_aclose(iterator)
            raise
        except StopAsyncIteration:
            break
+ break
116
+
117
+
118
class BaseProcessor:
    """Base response processor.

    Holds the model/token pair shared by all processors and lazily creates a
    DownloadService for resolving asset URLs.
    """

    def __init__(self, model: str, token: str = ""):
        self.model = model
        self.token = token
        # Creation timestamp reused in OpenAI-style response payloads.
        self.created = int(time.time())
        self.app_url = get_config("app.app_url")
        self._dl_service: Optional[DownloadService] = None

    def _get_dl(self) -> DownloadService:
        """Return the (lazily created, reused) download service instance."""
        if self._dl_service is None:
            self._dl_service = DownloadService()
        return self._dl_service

    async def close(self):
        """Release download-service resources (its network session)."""
        if self._dl_service:
            await self._dl_service.close()
            self._dl_service = None

    async def process_url(self, path: str, media_type: str = "image") -> str:
        """Resolve an asset path/URL into the externally served URL."""
        dl_service = self._get_dl()
        return await dl_service.resolve_url(path, self.token, media_type)
+ return await dl_service.resolve_url(path, self.token, media_type)
144
+
145
+
146
# Public API of this module (underscored helpers exported deliberately).
__all__ = [
    "BaseProcessor",
    "_with_idle_timeout",
    "_normalize_line",
    "_collect_images",
    "_is_http2_error",
]
+ ]
app/services/grok/utils/response.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Response formatting utilities for OpenAI-compatible API responses.
3
+ """
4
+
5
+ import os
6
+ import time
7
+ import uuid
8
+ from typing import Optional
9
+
10
+
11
def make_response_id() -> str:
    """Generate a unique response ID (ms timestamp plus an 8-hex-char suffix)."""
    timestamp_ms = int(time.time() * 1000)
    suffix = os.urandom(4).hex()
    return "chatcmpl-" + str(timestamp_ms) + suffix
14
+
15
+
16
def make_chat_chunk(
    response_id: str,
    model: str,
    content: str,
    index: int = 0,
    role: str = "assistant",
    is_final: bool = False,
) -> dict:
    """
    Create an OpenAI-compatible chat completion chunk.

    Args:
        response_id: Unique response ID
        model: Model name
        content: Content to send
        index: Choice index
        role: Role (assistant)
        is_final: Whether this is the final chunk (includes finish_reason)

    Returns:
        Chat completion chunk dict
    """
    choice: dict = {
        "index": index,
        "delta": {"role": role, "content": content},
    }
    chunk: dict = {
        "id": response_id,
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": model,
        "choices": [choice],
    }

    # Final chunks carry a finish reason plus a zeroed usage stanza.
    if is_final:
        choice["finish_reason"] = "stop"
        chunk["usage"] = {
            "total_tokens": 0,
            "input_tokens": 0,
            "output_tokens": 0,
            "input_tokens_details": {"text_tokens": 0, "image_tokens": 0},
        }

    return chunk
66
+
67
+
68
def make_chat_response(
    model: str,
    content: str,
    response_id: Optional[str] = None,
    index: int = 0,
    usage: Optional[dict] = None,
) -> dict:
    """
    Create an OpenAI-compatible non-streaming chat completion response.

    Args:
        model: Model name
        content: Response content
        response_id: Unique response ID (generated if not provided)
        index: Choice index
        usage: Custom usage dict (defaults to zeros)

    Returns:
        Chat completion response dict
    """
    # Only generate an ID when the caller explicitly passed None.
    if response_id is None:
        response_id = f"chatcmpl-{uuid.uuid4().hex[:8]}"

    if usage is None:
        usage = {
            "total_tokens": 0,
            "input_tokens": 0,
            "output_tokens": 0,
            "input_tokens_details": {"text_tokens": 0, "image_tokens": 0},
        }

    message = {
        "role": "assistant",
        "content": content,
        "refusal": None,
    }
    return {
        "id": response_id,
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model,
        "choices": [
            {"index": index, "message": message, "finish_reason": "stop"}
        ],
        "usage": usage,
    }
117
+
118
+
119
def wrap_image_content(content: str, response_format: str = "url") -> str:
    """
    Wrap image content in markdown format for chat interface.

    Args:
        content: Image URL or base64 data
        response_format: "url" or "b64_json"/"base64"

    Returns:
        Markdown-wrapped image content
    """
    # Empty content passes through untouched.
    if not content:
        return content

    target = (
        content
        if response_format == "url"
        else f"data:image/png;base64,{content}"
    )
    return f"![image]({target})"
137
+
138
+
139
# Public API of this module.
__all__ = [
    "make_response_id",
    "make_chat_chunk",
    "make_chat_response",
    "wrap_image_content",
]
app/services/grok/utils/retry.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Retry helpers for token switching.
3
+ """
4
+
5
+ from typing import Optional, Set
6
+
7
+ from app.core.exceptions import UpstreamException
8
+ from app.services.grok.services.model import ModelService
9
+
10
+
11
+ async def pick_token(
12
+ token_mgr,
13
+ model_id: str,
14
+ tried: Set[str],
15
+ preferred: Optional[str] = None,
16
+ prefer_tags: Optional[Set[str]] = None,
17
+ ) -> Optional[str]:
18
+ if preferred and preferred not in tried:
19
+ return preferred
20
+
21
+ token = None
22
+ for pool_name in ModelService.pool_candidates_for_model(model_id):
23
+ token = token_mgr.get_token(pool_name, exclude=tried, prefer_tags=prefer_tags)
24
+ if token:
25
+ break
26
+
27
+ if not token and not tried:
28
+ result = await token_mgr.refresh_cooling_tokens()
29
+ if result.get("recovered", 0) > 0:
30
+ for pool_name in ModelService.pool_candidates_for_model(model_id):
31
+ token = token_mgr.get_token(pool_name, prefer_tags=prefer_tags)
32
+ if token:
33
+ break
34
+
35
+ return token
36
+
37
+
38
def rate_limited(error: Exception) -> bool:
    """Return True when *error* represents an upstream rate-limit rejection."""
    if not isinstance(error, UpstreamException):
        return False
    details = error.details or {}
    status = details.get("status")
    code = details.get("error_code")
    return status == 429 or code == "rate_limit_exceeded"
44
+
45
+
46
def transient_upstream(error: Exception) -> bool:
    """Whether error is likely transient and safe to retry with another token."""
    if not isinstance(error, UpstreamException):
        return False
    details = error.details or {}
    if details.get("status") in {408, 500, 502, 503, 504}:
        return True
    # Fall back to message sniffing for timeout / connection-level failures.
    message = str(details.get("error") or error).lower()
    for marker in (
        "timed out",
        "timeout",
        "connection reset",
        "temporarily unavailable",
        "http2",
    ):
        if marker in message:
            return True
    return False
64
+
65
+
66
+ __all__ = ["pick_token", "rate_limited", "transient_upstream"]
app/services/grok/utils/stream.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 流式响应通用工具
3
+ """
4
+
5
+ from typing import AsyncGenerator
6
+
7
+ from app.core.logger import logger
8
+ from app.services.grok.services.model import ModelService
9
+ from app.services.token import EffortType
10
+
11
+
12
async def wrap_stream_with_usage(
    stream: AsyncGenerator, token_mgr, token: str, model: str
) -> AsyncGenerator:
    """Proxy a streaming response and record token usage once it completes.

    Args:
        stream: The upstream AsyncGenerator of response chunks.
        token_mgr: TokenManager instance used to record consumption.
        token: Token string the stream was served with.
        model: Model name; determines the recorded effort level.
    """
    completed = False
    try:
        async for piece in stream:
            yield piece
        # Only reached when the upstream generator was fully drained.
        completed = True
    finally:
        # Early disconnects (GeneratorExit) skip accounting on purpose.
        if completed:
            try:
                info = ModelService.get(model)
                high_cost = bool(info) and info.cost.value == "high"
                effort = EffortType.HIGH if high_cost else EffortType.LOW
                await token_mgr.consume(token, effort)
                logger.debug(
                    f"Stream completed, recorded usage for token {token[:10]}... (effort={effort.value})"
                )
            except Exception as e:
                logger.warning(f"Failed to record stream usage: {e}")
44
+
45
+
46
# Public API of this module.
__all__ = ["wrap_stream_with_usage"]
app/services/grok/utils/tool_call.py ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tool call utilities for OpenAI-compatible function calling.
3
+
4
+ Provides prompt-based emulation of tool calls by injecting tool definitions
5
+ into the system prompt and parsing structured responses.
6
+ """
7
+
8
+ import json
9
+ import re
10
+ import uuid
11
+ from typing import Any, Dict, List, Optional, Tuple
12
+
13
+
14
def build_tool_prompt(
    tools: List[Dict[str, Any]],
    tool_choice: Optional[Any] = None,
    parallel_tool_calls: bool = True,
) -> str:
    """Generate a system prompt block describing available tools.

    Args:
        tools: List of OpenAI-format tool definitions.
        tool_choice: "auto", "required", "none", or {"type":"function","function":{"name":"..."}}.
        parallel_tool_calls: Whether multiple tool calls are allowed.

    Returns:
        System prompt string to prepend to the conversation.
    """
    if not tools:
        return ""

    # tool_choice="none" means don't mention tools at all
    if tool_choice == "none":
        return ""

    lines = [
        "# Available Tools",
        "",
        "You have access to the following tools. To call a tool, output a <tool_call> block with a JSON object containing \"name\" and \"arguments\".",
        "",
        "Format:",
        "<tool_call>",
        '{"name": "function_name", "arguments": {"param": "value"}}',
        "</tool_call>",
        "",
    ]

    if parallel_tool_calls:
        lines.append("You may make multiple tool calls in a single response by using multiple <tool_call> blocks.")
    else:
        # Fix: previously the False case was silently ignored, so the model
        # was never told that only a single tool call is allowed per turn.
        lines.append("Make at most one tool call per response.")
    lines.append("")

    # Describe each tool
    lines.append("## Tool Definitions")
    lines.append("")
    for tool in tools:
        # Only function-type tools are supported by this emulation layer.
        if tool.get("type") != "function":
            continue
        func = tool.get("function", {})
        name = func.get("name", "")
        desc = func.get("description", "")
        params = func.get("parameters", {})

        lines.append(f"### {name}")
        if desc:
            lines.append(f"{desc}")
        if params:
            # Embed the raw JSON schema; the model is expected to read it.
            lines.append(f"Parameters: {json.dumps(params, ensure_ascii=False)}")
        lines.append("")

    # Handle tool_choice directives
    if tool_choice == "required":
        lines.append("IMPORTANT: You MUST call at least one tool in your response. Do not respond with only text.")
    elif isinstance(tool_choice, dict):
        # Forced-function form: {"type":"function","function":{"name":"..."}}
        func_info = tool_choice.get("function", {})
        forced_name = func_info.get("name", "")
        if forced_name:
            lines.append(f"IMPORTANT: You MUST call the tool \"{forced_name}\" in your response.")
    else:
        # "auto" or default
        lines.append("Decide whether to call a tool based on the user's request. If you don't need a tool, respond normally with text only.")

    lines.append("")
    lines.append("When you call a tool, you may include text before or after the <tool_call> blocks, but the tool call blocks must be valid JSON.")

    return "\n".join(lines)
86
+
87
+
88
# Matches a single <tool_call>...</tool_call> block. DOTALL lets the JSON
# payload span lines; the non-greedy group keeps adjacent blocks separate.
_TOOL_CALL_RE = re.compile(
    r"<tool_call>\s*(.*?)\s*</tool_call>",
    re.DOTALL,
)
92
+
93
+
94
def _strip_code_fences(text: str) -> str:
    """Drop a wrapping markdown code fence (```lang ... ```) around *text*."""
    if not text:
        return text
    body = text.strip()
    if body.startswith("```"):
        # Remove the opening fence (with optional language tag), then the
        # closing fence at the end.
        body = re.sub(r"^```[a-zA-Z0-9_-]*\s*", "", body)
        body = re.sub(r"\s*```$", "", body)
    return body.strip()
102
+
103
+
104
def _extract_json_object(text: str) -> str:
    """Trim *text* down to the outermost ``{...}`` span, best effort.

    Falls back to returning the input (or a prefix-trimmed slice) when a
    well-formed span cannot be located.
    """
    if not text:
        return text
    first = text.find("{")
    if first == -1:
        # No object at all; leave the text for the caller to reject.
        return text
    last = text.rfind("}")
    if last == -1:
        # Opening brace but no close: keep the tail for brace balancing.
        return text[first:]
    if last < first:
        # A '}' before any '{' — structure is hopeless, pass through.
        return text
    return text[first : last + 1]
116
+
117
+
118
def _remove_trailing_commas(text: str) -> str:
    """Strip commas that directly precede a closing brace or bracket."""
    return re.sub(r",\s*([}\]])", r"\1", text) if text else text
122
+
123
+
124
def _balance_braces(text: str) -> str:
    """Append missing ``}`` so brace counts balance, ignoring braces in strings."""
    if not text:
        return text
    opened = 0
    closed = 0
    inside_string = False
    skip_next = False
    for char in text:
        if skip_next:
            # Previous character was a backslash inside a string literal.
            skip_next = False
        elif char == "\\" and inside_string:
            skip_next = True
        elif char == '"':
            inside_string = not inside_string
        elif inside_string:
            # Braces inside string literals don't affect nesting.
            pass
        elif char == "{":
            opened += 1
        elif char == "}":
            closed += 1
    if opened > closed:
        # Only missing closers are repaired; extra closers are left alone.
        text += "}" * (opened - closed)
    return text
150
+
151
+
152
def _repair_json(text: str) -> Optional[Any]:
    """Best-effort normalization of almost-JSON text; None if unrecoverable."""
    if not text:
        return None
    candidate = _strip_code_fences(text)
    candidate = _extract_json_object(candidate)
    # Normalize line endings, then collapse newlines so truncated
    # pretty-printed JSON parses as a single line.
    candidate = candidate.replace("\r\n", "\n").replace("\r", "\n")
    candidate = candidate.replace("\n", " ")
    candidate = _balance_braces(_remove_trailing_commas(candidate))
    try:
        return json.loads(candidate)
    except json.JSONDecodeError:
        return None
165
+
166
+
167
def parse_tool_call_block(
    raw_json: str,
    tools: Optional[List[Dict[str, Any]]] = None,
) -> Optional[Dict[str, Any]]:
    """Parse one ``<tool_call>`` payload into an OpenAI tool-call dict.

    Returns None when the payload is empty, unparseable (even after
    repair), lacks a ``name``, or names a function absent from *tools*
    when definitions are supplied.
    """
    if not raw_json:
        return None
    try:
        data = json.loads(raw_json)
    except json.JSONDecodeError:
        # Fall back to lenient repair for slightly malformed model output.
        data = _repair_json(raw_json)
    if not isinstance(data, dict):
        return None

    name = data.get("name")
    if not name:
        return None
    arguments = data.get("arguments", {})

    # Validate against the declared tool names when definitions are given.
    if tools:
        known = {
            t.get("function", {}).get("name")
            for t in tools
            if t.get("function", {}).get("name")
        }
        if known and name not in known:
            return None

    # OpenAI expects arguments as a JSON-encoded string.
    if isinstance(arguments, str):
        args_text = arguments
    else:
        args_text = json.dumps(arguments, ensure_ascii=False)

    return {
        "id": f"call_{uuid.uuid4().hex[:24]}",
        "type": "function",
        "function": {"name": name, "arguments": args_text},
    }
208
+
209
+
210
def parse_tool_calls(
    content: str,
    tools: Optional[List[Dict[str, Any]]] = None,
) -> Tuple[Optional[str], Optional[List[Dict[str, Any]]]]:
    """Parse tool call blocks from model output.

    Detects ``<tool_call>...</tool_call>`` blocks, parses JSON from each block,
    and returns OpenAI-format tool call objects.

    Args:
        content: Raw model output text.
        tools: Optional list of tool definitions for name validation.

    Returns:
        Tuple of (text_content, tool_calls_list).
        - text_content: text outside <tool_call> blocks (None if empty).
        - tool_calls_list: list of OpenAI tool call dicts, or None if no calls found.
    """
    if not content:
        return content, None

    blocks = list(_TOOL_CALL_RE.finditer(content))
    if not blocks:
        return content, None

    calls = [
        call
        for call in (
            parse_tool_call_block(m.group(1).strip(), tools) for m in blocks
        )
        if call
    ]
    if not calls:
        # Blocks were present but none parsed; return the text untouched.
        return content, None

    # Stitch together the prose surrounding the tool_call blocks.
    segments: List[str] = []
    cursor = 0
    for m in blocks:
        gap = content[cursor : m.start()]
        if gap.strip():
            segments.append(gap.strip())
        cursor = m.end()
    tail = content[cursor:]
    if tail.strip():
        segments.append(tail.strip())

    text_content = "\n".join(segments) if segments else None
    return text_content, calls
260
+
261
+
262
def format_tool_history(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Convert assistant messages with tool_calls and tool role messages into text format.

    Since Grok's web API only accepts a single message string, this converts
    tool-related messages back to a text representation for multi-turn conversations.

    Args:
        messages: List of OpenAI-format messages that may contain tool_calls and tool roles.

    Returns:
        List of messages with tool content converted to text format.
    """
    result = []
    for msg in messages:
        role = msg.get("role", "")
        content = msg.get("content")
        tool_calls = msg.get("tool_calls")

        if role == "assistant" and tool_calls:
            # Convert assistant tool_calls to a text representation.
            parts = []
            if content:
                parts.append(content if isinstance(content, str) else str(content))
            for tc in tool_calls:
                func = tc.get("function", {})
                tc_name = func.get("name", "")
                tc_args = func.get("arguments", "{}")
                # Fix: arguments may arrive as a dict instead of the JSON string
                # the OpenAI schema specifies; interpolating a dict directly
                # would emit Python repr (single quotes) and break the JSON
                # inside the reconstructed <tool_call> block.
                if not isinstance(tc_args, str):
                    tc_args = json.dumps(tc_args, ensure_ascii=False)
                parts.append(f'<tool_call>{{"name":"{tc_name}","arguments":{tc_args}}}</tool_call>')
            result.append({
                "role": "assistant",
                "content": "\n".join(parts),
            })

        elif role == "tool":
            # Convert the tool result into a plain-text user message.
            tool_name = msg.get("name") or "unknown"
            call_id = msg.get("tool_call_id") or ""
            content_str = content if isinstance(content, str) else json.dumps(content, ensure_ascii=False) if content else ""
            result.append({
                "role": "user",
                "content": f"tool ({tool_name}, {call_id}): {content_str}",
            })

        else:
            # Pass through ordinary messages unchanged.
            result.append(msg)

    return result
312
+
313
+
314
# Public API of this module.
__all__ = [
    "build_tool_prompt",
    "parse_tool_calls",
    "format_tool_history",
    "parse_tool_call_block",
]