Spaces:
Running
Running
Create main.py
Browse files
main.py
ADDED
@@ -0,0 +1,576 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import json
|
3 |
+
import sys
|
4 |
+
import uuid
|
5 |
+
import base64
|
6 |
+
import re
|
7 |
+
import os
|
8 |
+
import argparse
|
9 |
+
from datetime import datetime, timezone
|
10 |
+
from typing import List, Optional
|
11 |
+
|
12 |
+
import httpx
|
13 |
+
import uvicorn
|
14 |
+
from fastapi import (
|
15 |
+
BackgroundTasks,
|
16 |
+
FastAPI,
|
17 |
+
HTTPException,
|
18 |
+
Request,
|
19 |
+
Response,
|
20 |
+
status,
|
21 |
+
)
|
22 |
+
from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse
|
23 |
+
from fastapi.middleware.cors import CORSMiddleware
|
24 |
+
from fastapi.staticfiles import StaticFiles
|
25 |
+
|
26 |
+
from bearer_token import BearerTokenGenerator
|
27 |
+
|
28 |
+
# 模型列表
|
29 |
+
MODELS = ["gpt-4o", "gpt-4o-mini", "claude-3-5-sonnet", "claude"]
|
30 |
+
|
31 |
+
# 默认端口
|
32 |
+
INITIAL_PORT = 3000
|
33 |
+
|
34 |
+
# 外部API的URL
|
35 |
+
EXTERNAL_API_URL = "https://api.chaton.ai/chats/stream"
|
36 |
+
|
37 |
+
# 初始化FastAPI应用
|
38 |
+
app = FastAPI()
|
39 |
+
|
40 |
+
# 添加CORS中间件
|
41 |
+
app.add_middleware(
|
42 |
+
CORSMiddleware,
|
43 |
+
allow_origins=["*"], # 允许所有来源
|
44 |
+
allow_credentials=True,
|
45 |
+
allow_methods=["GET", "POST", "OPTIONS"], # 允许GET, POST, OPTIONS方法
|
46 |
+
allow_headers=["Content-Type", "Authorization"], # 允许的头部
|
47 |
+
)
|
48 |
+
|
49 |
+
# 挂载静态文件路由以提供 images 目录的内容
|
50 |
+
app.mount("/images", StaticFiles(directory="images"), name="images")
|
51 |
+
|
52 |
+
# 辅助函数
|
53 |
+
def send_error_response(message: str, status_code: int = 400):
|
54 |
+
"""构建错误响应,并确保包含CORS头"""
|
55 |
+
error_json = {"error": message}
|
56 |
+
headers = {
|
57 |
+
"Access-Control-Allow-Origin": "*",
|
58 |
+
"Access-Control-Allow-Methods": "GET, POST, OPTIONS",
|
59 |
+
"Access-Control-Allow-Headers": "Content-Type, Authorization",
|
60 |
+
}
|
61 |
+
return JSONResponse(status_code=status_code, content=error_json, headers=headers)
|
62 |
+
|
63 |
+
def extract_path_from_markdown(markdown: str) -> Optional[str]:
|
64 |
+
"""
|
65 |
+
提取 Markdown 图片链接中的路径,匹配以 https://spc.unk/ 开头的 URL
|
66 |
+
"""
|
67 |
+
pattern = re.compile(r'!\[.*?\]\(https://spc\.unk/(.*?)\)')
|
68 |
+
match = pattern.search(markdown)
|
69 |
+
if match:
|
70 |
+
return match.group(1)
|
71 |
+
return None
|
72 |
+
|
73 |
+
async def fetch_get_url_from_storage(storage_url: str) -> Optional[str]:
|
74 |
+
"""
|
75 |
+
从 storage URL 获取 JSON 并提取 getUrl
|
76 |
+
"""
|
77 |
+
async with httpx.AsyncClient() as client:
|
78 |
+
try:
|
79 |
+
response = await client.get(storage_url)
|
80 |
+
if response.status_code != 200:
|
81 |
+
print(f"获取 storage URL 失败,状态码: {response.status_code}")
|
82 |
+
return None
|
83 |
+
json_response = response.json()
|
84 |
+
return json_response.get("getUrl")
|
85 |
+
except Exception as e:
|
86 |
+
print(f"Error fetching getUrl from storage: {e}")
|
87 |
+
return None
|
88 |
+
|
89 |
+
async def download_image(image_url: str) -> Optional[bytes]:
|
90 |
+
"""
|
91 |
+
下载图像
|
92 |
+
"""
|
93 |
+
async with httpx.AsyncClient() as client:
|
94 |
+
try:
|
95 |
+
response = await client.get(image_url)
|
96 |
+
if response.status_code == 200:
|
97 |
+
return response.content
|
98 |
+
else:
|
99 |
+
print(f"下载图像失败,状态码: {response.status_code}")
|
100 |
+
return None
|
101 |
+
except Exception as e:
|
102 |
+
print(f"Error downloading image: {e}")
|
103 |
+
return None
|
104 |
+
|
105 |
+
def save_base64_image(base64_str: str, images_dir: str = "images") -> str:
|
106 |
+
"""
|
107 |
+
将Base64编码的图片保存到images目录,返回文件名
|
108 |
+
"""
|
109 |
+
if not os.path.exists(images_dir):
|
110 |
+
os.makedirs(images_dir)
|
111 |
+
image_data = base64.b64decode(base64_str)
|
112 |
+
filename = f"{uuid.uuid4()}.png" # 默认保存为png格式
|
113 |
+
file_path = os.path.join(images_dir, filename)
|
114 |
+
with open(file_path, "wb") as f:
|
115 |
+
f.write(image_data)
|
116 |
+
return filename
|
117 |
+
|
118 |
+
def is_base64_image(url: str) -> bool:
|
119 |
+
"""
|
120 |
+
判断URL是否为Base64编码的图片
|
121 |
+
"""
|
122 |
+
return url.startswith("data:image/")
|
123 |
+
|
124 |
+
# 根路径GET请求处理
|
125 |
+
@app.get("/", response_class=HTMLResponse)
|
126 |
+
async def read_root():
|
127 |
+
"""返回欢迎页面"""
|
128 |
+
html_content = """
|
129 |
+
<html>
|
130 |
+
<head>
|
131 |
+
<title>Welcome to API</title>
|
132 |
+
</head>
|
133 |
+
<body>
|
134 |
+
<h1>Welcome to API</h1>
|
135 |
+
<p>This API is used to interact with the ChatGPT model. You can send messages to the model and receive responses.</p>
|
136 |
+
</body>
|
137 |
+
</html>
|
138 |
+
"""
|
139 |
+
return HTMLResponse(content=html_content, status_code=200)
|
140 |
+
|
141 |
+
# 聊天完成处理
|
142 |
+
@app.post("/v1/chat/completions")
|
143 |
+
async def chat_completions(request: Request, background_tasks: BackgroundTasks):
|
144 |
+
"""
|
145 |
+
处理聊天完成请求
|
146 |
+
"""
|
147 |
+
try:
|
148 |
+
request_body = await request.json()
|
149 |
+
except json.JSONDecodeError:
|
150 |
+
raise HTTPException(status_code=400, detail="Invalid JSON")
|
151 |
+
|
152 |
+
# 打印接收到的请求
|
153 |
+
print("Received Completion JSON:", json.dumps(request_body, ensure_ascii=False))
|
154 |
+
|
155 |
+
# 处理消息内容
|
156 |
+
messages = request_body.get("messages", [])
|
157 |
+
temperature = request_body.get("temperature", 1.0)
|
158 |
+
top_p = request_body.get("top_p", 1.0)
|
159 |
+
max_tokens = request_body.get("max_tokens", 8000)
|
160 |
+
model = request_body.get("model", "gpt-4o")
|
161 |
+
is_stream = request_body.get("stream", False) # 获取 stream 字段
|
162 |
+
|
163 |
+
has_image = False
|
164 |
+
has_text = False
|
165 |
+
|
166 |
+
# 清理和提取消息内容
|
167 |
+
cleaned_messages = []
|
168 |
+
for message in messages:
|
169 |
+
content = message.get("content", "")
|
170 |
+
if isinstance(content, list):
|
171 |
+
text_parts = []
|
172 |
+
images = []
|
173 |
+
for item in content:
|
174 |
+
if "text" in item:
|
175 |
+
text_parts.append(item.get("text", ""))
|
176 |
+
elif "image_url" in item:
|
177 |
+
has_image = True
|
178 |
+
image_info = item.get("image_url", {})
|
179 |
+
url = image_info.get("url", "")
|
180 |
+
if is_base64_image(url):
|
181 |
+
# 解码并保存图片
|
182 |
+
base64_str = url.split(",")[1]
|
183 |
+
filename = save_base64_image(base64_str)
|
184 |
+
base_url = app.state.base_url
|
185 |
+
image_url = f"{base_url}/images/{filename}"
|
186 |
+
images.append({"data": image_url})
|
187 |
+
else:
|
188 |
+
images.append({"data": url})
|
189 |
+
extracted_content = " ".join(text_parts).strip()
|
190 |
+
if extracted_content:
|
191 |
+
has_text = True
|
192 |
+
message["content"] = extracted_content
|
193 |
+
if images:
|
194 |
+
message["images"] = images
|
195 |
+
cleaned_messages.append(message)
|
196 |
+
print("Extracted:", extracted_content)
|
197 |
+
else:
|
198 |
+
if images:
|
199 |
+
has_image = True
|
200 |
+
message["content"] = ""
|
201 |
+
message["images"] = images
|
202 |
+
cleaned_messages.append(message)
|
203 |
+
print("Extracted image only.")
|
204 |
+
else:
|
205 |
+
print("Deleted message with empty content.")
|
206 |
+
elif isinstance(content, str):
|
207 |
+
content_str = content.strip()
|
208 |
+
if content_str:
|
209 |
+
has_text = True
|
210 |
+
message["content"] = content_str
|
211 |
+
cleaned_messages.append(message)
|
212 |
+
print("Retained content:", content_str)
|
213 |
+
else:
|
214 |
+
print("Deleted message with empty content.")
|
215 |
+
else:
|
216 |
+
print("Deleted non-expected type of content message.")
|
217 |
+
|
218 |
+
if not cleaned_messages:
|
219 |
+
raise HTTPException(status_code=400, detail="所有消息的内容均为空。")
|
220 |
+
|
221 |
+
# 验证模型
|
222 |
+
if model not in MODELS:
|
223 |
+
model = "gpt-4o"
|
224 |
+
|
225 |
+
# 构建新的请求JSON
|
226 |
+
new_request_json = {
|
227 |
+
"function_image_gen": False,
|
228 |
+
"function_web_search": True,
|
229 |
+
"max_tokens": max_tokens,
|
230 |
+
"model": model,
|
231 |
+
"source": "chat/free",
|
232 |
+
"temperature": temperature,
|
233 |
+
"top_p": top_p,
|
234 |
+
"messages": cleaned_messages,
|
235 |
+
}
|
236 |
+
|
237 |
+
modified_request_body = json.dumps(new_request_json, ensure_ascii=False)
|
238 |
+
print("Modified Request JSON:", modified_request_body)
|
239 |
+
|
240 |
+
# 获取Bearer Token
|
241 |
+
tmp_token = BearerTokenGenerator.get_bearer(modified_request_body)
|
242 |
+
if not tmp_token:
|
243 |
+
raise HTTPException(status_code=500, detail="无法生成 Bearer Token")
|
244 |
+
|
245 |
+
bearer_token, formatted_date = tmp_token
|
246 |
+
|
247 |
+
headers = {
|
248 |
+
"Date": formatted_date,
|
249 |
+
"Client-time-zone": "-05:00",
|
250 |
+
"Authorization": bearer_token,
|
251 |
+
"User-Agent": "ChatOn_Android/1.53.502",
|
252 |
+
"Accept-Language": "en-US",
|
253 |
+
"X-Cl-Options": "hb",
|
254 |
+
"Content-Type": "application/json; charset=UTF-8",
|
255 |
+
}
|
256 |
+
|
257 |
+
if is_stream:
|
258 |
+
# 流式响应处理
|
259 |
+
async def event_generator():
|
260 |
+
async with httpx.AsyncClient(timeout=None) as client_stream:
|
261 |
+
try:
|
262 |
+
async with client_stream.stream("POST", EXTERNAL_API_URL, headers=headers, content=modified_request_body) as streamed_response:
|
263 |
+
async for line in streamed_response.aiter_lines():
|
264 |
+
if line.startswith("data: "):
|
265 |
+
data = line[6:].strip()
|
266 |
+
if data == "[DONE]":
|
267 |
+
# 通知客户端流结束
|
268 |
+
yield "data: [DONE]\n\n"
|
269 |
+
break
|
270 |
+
try:
|
271 |
+
sse_json = json.loads(data)
|
272 |
+
if "choices" in sse_json:
|
273 |
+
for choice in sse_json["choices"]:
|
274 |
+
delta = choice.get("delta", {})
|
275 |
+
content = delta.get("content")
|
276 |
+
if content:
|
277 |
+
new_sse_json = {
|
278 |
+
"choices": [
|
279 |
+
{
|
280 |
+
"index": choice.get("index", 0),
|
281 |
+
"delta": {"content": content},
|
282 |
+
}
|
283 |
+
],
|
284 |
+
"created": sse_json.get(
|
285 |
+
"created", int(datetime.now(timezone.utc).timestamp())
|
286 |
+
),
|
287 |
+
"id": sse_json.get(
|
288 |
+
"id", str(uuid.uuid4())
|
289 |
+
),
|
290 |
+
"model": sse_json.get("model", "gpt-4o"),
|
291 |
+
"system_fingerprint": f"fp_{uuid.uuid4().hex[:12]}",
|
292 |
+
}
|
293 |
+
new_sse_line = f"data: {json.dumps(new_sse_json, ensure_ascii=False)}\n\n"
|
294 |
+
yield new_sse_line
|
295 |
+
except json.JSONDecodeError:
|
296 |
+
print("JSON解析错误")
|
297 |
+
continue
|
298 |
+
except httpx.RequestError as exc:
|
299 |
+
print(f"外部API请求失败: {exc}")
|
300 |
+
yield f"data: {{\"error\": \"外部API请求失败: {str(exc)}\"}}\n\n"
|
301 |
+
|
302 |
+
return StreamingResponse(
|
303 |
+
event_generator(),
|
304 |
+
media_type="text/event-stream",
|
305 |
+
headers={
|
306 |
+
"Cache-Control": "no-cache",
|
307 |
+
"Connection": "keep-alive",
|
308 |
+
# CORS头已通过中间件处理,无需在这里重复添加
|
309 |
+
},
|
310 |
+
)
|
311 |
+
else:
|
312 |
+
# 非流式响应处理
|
313 |
+
async with httpx.AsyncClient(timeout=None) as client:
|
314 |
+
try:
|
315 |
+
response = await client.post(
|
316 |
+
EXTERNAL_API_URL,
|
317 |
+
headers=headers,
|
318 |
+
content=modified_request_body,
|
319 |
+
timeout=None
|
320 |
+
)
|
321 |
+
|
322 |
+
if response.status_code != 200:
|
323 |
+
raise HTTPException(
|
324 |
+
status_code=response.status_code,
|
325 |
+
detail=f"API 错误: {response.status_code}",
|
326 |
+
)
|
327 |
+
|
328 |
+
sse_lines = response.text.splitlines()
|
329 |
+
content_builder = ""
|
330 |
+
images_urls = []
|
331 |
+
|
332 |
+
for line in sse_lines:
|
333 |
+
if line.startswith("data: "):
|
334 |
+
data = line[6:].strip()
|
335 |
+
if data == "[DONE]":
|
336 |
+
break
|
337 |
+
try:
|
338 |
+
sse_json = json.loads(data)
|
339 |
+
if "choices" in sse_json:
|
340 |
+
for choice in sse_json["choices"]:
|
341 |
+
if "delta" in choice:
|
342 |
+
delta = choice["delta"]
|
343 |
+
if "content" in delta:
|
344 |
+
content_builder += delta["content"]
|
345 |
+
except json.JSONDecodeError:
|
346 |
+
print("JSON解析错误")
|
347 |
+
continue
|
348 |
+
|
349 |
+
openai_response = {
|
350 |
+
"id": f"chatcmpl-{uuid.uuid4()}",
|
351 |
+
"object": "chat.completion",
|
352 |
+
"created": int(datetime.now(timezone.utc).timestamp()),
|
353 |
+
"model": model,
|
354 |
+
"choices": [
|
355 |
+
{
|
356 |
+
"index": 0,
|
357 |
+
"message": {
|
358 |
+
"role": "assistant",
|
359 |
+
"content": content_builder,
|
360 |
+
},
|
361 |
+
"finish_reason": "stop",
|
362 |
+
}
|
363 |
+
],
|
364 |
+
}
|
365 |
+
|
366 |
+
# 处理图片(如果有)
|
367 |
+
if has_image:
|
368 |
+
images = []
|
369 |
+
for message in cleaned_messages:
|
370 |
+
if "images" in message:
|
371 |
+
for img in message["images"]:
|
372 |
+
images.append({"data": img["data"]})
|
373 |
+
openai_response["choices"][0]["message"]["images"] = images
|
374 |
+
|
375 |
+
return JSONResponse(content=openai_response, status_code=200)
|
376 |
+
except httpx.RequestError as exc:
|
377 |
+
raise HTTPException(status_code=500, detail=f"请求失败: {str(exc)}")
|
378 |
+
except Exception as exc:
|
379 |
+
raise HTTPException(status_code=500, detail=f"内部服务器错误: {str(exc)}")
|
380 |
+
|
381 |
+
# 图像生成处理
|
382 |
+
@app.post("/v1/images/generations")
|
383 |
+
async def images_generations(request: Request):
|
384 |
+
"""
|
385 |
+
处理图像生成请求
|
386 |
+
"""
|
387 |
+
try:
|
388 |
+
request_body = await request.json()
|
389 |
+
except json.JSONDecodeError:
|
390 |
+
return send_error_response("Invalid JSON", status_code=400)
|
391 |
+
|
392 |
+
print("Received Image Generations JSON:", json.dumps(request_body, ensure_ascii=False))
|
393 |
+
|
394 |
+
# 验证必需的字段
|
395 |
+
if "prompt" not in request_body:
|
396 |
+
return send_error_response("缺少必需的字段: prompt", status_code=400)
|
397 |
+
|
398 |
+
user_prompt = request_body.get("prompt", "").strip()
|
399 |
+
response_format = request_body.get("response_format", "b64_json").strip()
|
400 |
+
|
401 |
+
if not user_prompt:
|
402 |
+
return send_error_response("Prompt 不能为空。", status_code=400)
|
403 |
+
|
404 |
+
print(f"Prompt: {user_prompt}")
|
405 |
+
|
406 |
+
# 构建新的 TextToImage JSON 请求体
|
407 |
+
text_to_image_json = {
|
408 |
+
"function_image_gen": True,
|
409 |
+
"function_web_search": True,
|
410 |
+
"image_aspect_ratio": "1:1",
|
411 |
+
"image_style": "photographic", # 暂时固定 image_style
|
412 |
+
"max_tokens": 8000,
|
413 |
+
"messages": [
|
414 |
+
{
|
415 |
+
"content": "You are a helpful artist, please based on imagination draw a picture.",
|
416 |
+
"role": "system"
|
417 |
+
},
|
418 |
+
{
|
419 |
+
"content": "Draw: " + user_prompt,
|
420 |
+
"role": "user"
|
421 |
+
}
|
422 |
+
],
|
423 |
+
"model": "gpt-4o", # 固定 model,只能gpt-4o或gpt-4o-mini
|
424 |
+
"source": "chat/pro_image" # 固定 source
|
425 |
+
}
|
426 |
+
|
427 |
+
modified_request_body = json.dumps(text_to_image_json, ensure_ascii=False)
|
428 |
+
print("Modified Request JSON:", modified_request_body)
|
429 |
+
|
430 |
+
# 获取Bearer Token
|
431 |
+
tmp_token = BearerTokenGenerator.get_bearer(modified_request_body, path="/chats/stream")
|
432 |
+
if not tmp_token:
|
433 |
+
return send_error_response("无法生成 Bearer Token", status_code=500)
|
434 |
+
|
435 |
+
bearer_token, formatted_date = tmp_token
|
436 |
+
|
437 |
+
headers = {
|
438 |
+
"Date": formatted_date,
|
439 |
+
"Client-time-zone": "-05:00",
|
440 |
+
"Authorization": bearer_token,
|
441 |
+
"User-Agent": "ChatOn_Android/1.53.502",
|
442 |
+
"Accept-Language": "en-US",
|
443 |
+
"X-Cl-Options": "hb",
|
444 |
+
"Content-Type": "application/json; charset=UTF-8",
|
445 |
+
}
|
446 |
+
|
447 |
+
async with httpx.AsyncClient(timeout=None) as client:
|
448 |
+
try:
|
449 |
+
response = await client.post(
|
450 |
+
EXTERNAL_API_URL, headers=headers, content=modified_request_body, timeout=None
|
451 |
+
)
|
452 |
+
if response.status_code != 200:
|
453 |
+
return send_error_response(f"API 错误: {response.status_code}", status_code=500)
|
454 |
+
|
455 |
+
# 初始化用于拼接 URL 的字符串
|
456 |
+
url_builder = ""
|
457 |
+
|
458 |
+
# 读取 SSE 流并拼接 URL
|
459 |
+
async for line in response.aiter_lines():
|
460 |
+
if line.startswith("data: "):
|
461 |
+
data = line[6:].strip()
|
462 |
+
if data == "[DONE]":
|
463 |
+
break
|
464 |
+
try:
|
465 |
+
sse_json = json.loads(data)
|
466 |
+
if "choices" in sse_json:
|
467 |
+
for choice in sse_json["choices"]:
|
468 |
+
delta = choice.get("delta", {})
|
469 |
+
content = delta.get("content")
|
470 |
+
if content:
|
471 |
+
url_builder += content
|
472 |
+
except json.JSONDecodeError:
|
473 |
+
print("JSON解析错误")
|
474 |
+
continue
|
475 |
+
|
476 |
+
image_markdown = url_builder
|
477 |
+
# Step 1: 检查Markdown文本是否为空
|
478 |
+
if not image_markdown:
|
479 |
+
print("无法从 SSE 流中构建图像 Markdown。")
|
480 |
+
return send_error_response("无法从 SSE 流中构建图像 Markdown。", status_code=500)
|
481 |
+
|
482 |
+
# Step 2, 3, 4, 5: 处理图像
|
483 |
+
extracted_path = extract_path_from_markdown(image_markdown)
|
484 |
+
if not extracted_path:
|
485 |
+
print("无法从 Markdown 中提取路径。")
|
486 |
+
return send_error_response("无法从 Markdown 中提取路径。", status_code=500)
|
487 |
+
|
488 |
+
print(f"提取的路径: {extracted_path}")
|
489 |
+
|
490 |
+
# Step 5: 拼接最终的存储URL
|
491 |
+
storage_url = f"https://api.chaton.ai/storage/{extracted_path}"
|
492 |
+
print(f"存储URL: {storage_url}")
|
493 |
+
|
494 |
+
# 获取最终下载URL
|
495 |
+
final_download_url = await fetch_get_url_from_storage(storage_url)
|
496 |
+
if not final_download_url:
|
497 |
+
return send_error_response("无法从 storage URL 获取最终下载链接。", status_code=500)
|
498 |
+
|
499 |
+
print(f"Final Download URL: {final_download_url}")
|
500 |
+
|
501 |
+
# 下载图像
|
502 |
+
image_bytes = await download_image(final_download_url)
|
503 |
+
if not image_bytes:
|
504 |
+
return send_error_response("无法从 URL 下载图像。", status_code=500)
|
505 |
+
|
506 |
+
# 转换为 Base64
|
507 |
+
image_base64 = base64.b64encode(image_bytes).decode('utf-8')
|
508 |
+
|
509 |
+
# 将图片保存到images目录并构建可访问的URL
|
510 |
+
filename = save_base64_image(image_base64)
|
511 |
+
base_url = app.state.base_url
|
512 |
+
accessible_url = f"{base_url}/images/{filename}"
|
513 |
+
|
514 |
+
# 根据 response_format 返回相应的响应
|
515 |
+
if response_format.lower() == "b64_json":
|
516 |
+
response_json = {
|
517 |
+
"data": [
|
518 |
+
{
|
519 |
+
"b64_json": image_base64
|
520 |
+
}
|
521 |
+
]
|
522 |
+
}
|
523 |
+
return JSONResponse(content=response_json, status_code=200)
|
524 |
+
else:
|
525 |
+
# 构建包含可访问URL的响应
|
526 |
+
response_json = {
|
527 |
+
"data": [
|
528 |
+
{
|
529 |
+
"url": accessible_url
|
530 |
+
}
|
531 |
+
]
|
532 |
+
}
|
533 |
+
return JSONResponse(content=response_json, status_code=200)
|
534 |
+
except httpx.RequestError as exc:
|
535 |
+
print(f"请求失败: {exc}")
|
536 |
+
return send_error_response(f"请求失败: {str(exc)}", status_code=500)
|
537 |
+
except Exception as exc:
|
538 |
+
print(f"内部服务器错误: {exc}")
|
539 |
+
return send_error_response(f"内部服务器错误: {str(exc)}", status_code=500)
|
540 |
+
|
541 |
+
# 运行服务器
|
542 |
+
def main():
|
543 |
+
parser = argparse.ArgumentParser(description="启动ChatOn API服务器")
|
544 |
+
parser.add_argument('--base_url', type=str, default='http://localhost', help='Base URL for accessing images')
|
545 |
+
parser.add_argument('--port', type=int, default=INITIAL_PORT, help='服务器监听端口')
|
546 |
+
args = parser.parse_args()
|
547 |
+
|
548 |
+
base_url = args.base_url
|
549 |
+
port = args.port
|
550 |
+
|
551 |
+
# 确保 images 目录存在
|
552 |
+
if not os.path.exists("images"):
|
553 |
+
os.makedirs("images")
|
554 |
+
|
555 |
+
# 设置 FastAPI 应用的 state
|
556 |
+
app.state.base_url = base_url
|
557 |
+
|
558 |
+
print(f"Server started on port {port} with base_url: {base_url}")
|
559 |
+
|
560 |
+
# 运行FastAPI应用
|
561 |
+
uvicorn.run(app, host="0.0.0.0", port=port)
|
562 |
+
|
563 |
+
async def get_available_port(start_port: int = INITIAL_PORT, end_port: int = 65535) -> int:
|
564 |
+
"""查找可用的端口号"""
|
565 |
+
for port in range(start_port, end_port + 1):
|
566 |
+
try:
|
567 |
+
server = await asyncio.start_server(lambda r, w: None, host="0.0.0.0", port=port)
|
568 |
+
server.close()
|
569 |
+
await server.wait_closed()
|
570 |
+
return port
|
571 |
+
except OSError:
|
572 |
+
continue
|
573 |
+
raise RuntimeError(f"No available ports between {start_port} and {end_port}")
|
574 |
+
|
575 |
+
if __name__ == "__main__":
|
576 |
+
main()
|