smgc commited on
Commit
e80aec8
1 Parent(s): 64f391f

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +576 -0
main.py ADDED
@@ -0,0 +1,576 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import json
3
+ import sys
4
+ import uuid
5
+ import base64
6
+ import re
7
+ import os
8
+ import argparse
9
+ from datetime import datetime, timezone
10
+ from typing import List, Optional
11
+
12
+ import httpx
13
+ import uvicorn
14
+ from fastapi import (
15
+ BackgroundTasks,
16
+ FastAPI,
17
+ HTTPException,
18
+ Request,
19
+ Response,
20
+ status,
21
+ )
22
+ from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse
23
+ from fastapi.middleware.cors import CORSMiddleware
24
+ from fastapi.staticfiles import StaticFiles
25
+
26
+ from bearer_token import BearerTokenGenerator
27
+
28
+ # 模型列表
29
+ MODELS = ["gpt-4o", "gpt-4o-mini", "claude-3-5-sonnet", "claude"]
30
+
31
+ # 默认端口
32
+ INITIAL_PORT = 3000
33
+
34
+ # 外部API的URL
35
+ EXTERNAL_API_URL = "https://api.chaton.ai/chats/stream"
36
+
37
+ # 初始化FastAPI应用
38
+ app = FastAPI()
39
+
40
+ # 添加CORS中间件
41
+ app.add_middleware(
42
+ CORSMiddleware,
43
+ allow_origins=["*"], # 允许所有来源
44
+ allow_credentials=True,
45
+ allow_methods=["GET", "POST", "OPTIONS"], # 允许GET, POST, OPTIONS方法
46
+ allow_headers=["Content-Type", "Authorization"], # 允许的头部
47
+ )
48
+
49
+ # 挂载静态文件路由以提供 images 目录的内容
50
+ app.mount("/images", StaticFiles(directory="images"), name="images")
51
+
52
+ # 辅助函数
53
+ def send_error_response(message: str, status_code: int = 400):
54
+ """构建错误响应,并确保包含CORS头"""
55
+ error_json = {"error": message}
56
+ headers = {
57
+ "Access-Control-Allow-Origin": "*",
58
+ "Access-Control-Allow-Methods": "GET, POST, OPTIONS",
59
+ "Access-Control-Allow-Headers": "Content-Type, Authorization",
60
+ }
61
+ return JSONResponse(status_code=status_code, content=error_json, headers=headers)
62
+
63
+ def extract_path_from_markdown(markdown: str) -> Optional[str]:
64
+ """
65
+ 提取 Markdown 图片链接中的路径,匹配以 https://spc.unk/ 开头的 URL
66
+ """
67
+ pattern = re.compile(r'!\[.*?\]\(https://spc\.unk/(.*?)\)')
68
+ match = pattern.search(markdown)
69
+ if match:
70
+ return match.group(1)
71
+ return None
72
+
73
+ async def fetch_get_url_from_storage(storage_url: str) -> Optional[str]:
74
+ """
75
+ 从 storage URL 获取 JSON 并提取 getUrl
76
+ """
77
+ async with httpx.AsyncClient() as client:
78
+ try:
79
+ response = await client.get(storage_url)
80
+ if response.status_code != 200:
81
+ print(f"获取 storage URL 失败,状态码: {response.status_code}")
82
+ return None
83
+ json_response = response.json()
84
+ return json_response.get("getUrl")
85
+ except Exception as e:
86
+ print(f"Error fetching getUrl from storage: {e}")
87
+ return None
88
+
89
+ async def download_image(image_url: str) -> Optional[bytes]:
90
+ """
91
+ 下载图像
92
+ """
93
+ async with httpx.AsyncClient() as client:
94
+ try:
95
+ response = await client.get(image_url)
96
+ if response.status_code == 200:
97
+ return response.content
98
+ else:
99
+ print(f"下载图像失败,状态码: {response.status_code}")
100
+ return None
101
+ except Exception as e:
102
+ print(f"Error downloading image: {e}")
103
+ return None
104
+
105
+ def save_base64_image(base64_str: str, images_dir: str = "images") -> str:
106
+ """
107
+ 将Base64编码的图片保存到images目录,返回文件名
108
+ """
109
+ if not os.path.exists(images_dir):
110
+ os.makedirs(images_dir)
111
+ image_data = base64.b64decode(base64_str)
112
+ filename = f"{uuid.uuid4()}.png" # 默认保存为png格式
113
+ file_path = os.path.join(images_dir, filename)
114
+ with open(file_path, "wb") as f:
115
+ f.write(image_data)
116
+ return filename
117
+
118
+ def is_base64_image(url: str) -> bool:
119
+ """
120
+ 判断URL是否为Base64编码的图片
121
+ """
122
+ return url.startswith("data:image/")
123
+
124
+ # 根路径GET请求处理
125
+ @app.get("/", response_class=HTMLResponse)
126
+ async def read_root():
127
+ """返回欢迎页面"""
128
+ html_content = """
129
+ <html>
130
+ <head>
131
+ <title>Welcome to API</title>
132
+ </head>
133
+ <body>
134
+ <h1>Welcome to API</h1>
135
+ <p>This API is used to interact with the ChatGPT model. You can send messages to the model and receive responses.</p>
136
+ </body>
137
+ </html>
138
+ """
139
+ return HTMLResponse(content=html_content, status_code=200)
140
+
141
+ # 聊天完成处理
142
+ @app.post("/v1/chat/completions")
143
+ async def chat_completions(request: Request, background_tasks: BackgroundTasks):
144
+ """
145
+ 处理聊天完成请求
146
+ """
147
+ try:
148
+ request_body = await request.json()
149
+ except json.JSONDecodeError:
150
+ raise HTTPException(status_code=400, detail="Invalid JSON")
151
+
152
+ # 打印接收到的请求
153
+ print("Received Completion JSON:", json.dumps(request_body, ensure_ascii=False))
154
+
155
+ # 处理消息内容
156
+ messages = request_body.get("messages", [])
157
+ temperature = request_body.get("temperature", 1.0)
158
+ top_p = request_body.get("top_p", 1.0)
159
+ max_tokens = request_body.get("max_tokens", 8000)
160
+ model = request_body.get("model", "gpt-4o")
161
+ is_stream = request_body.get("stream", False) # 获取 stream 字段
162
+
163
+ has_image = False
164
+ has_text = False
165
+
166
+ # 清理和提取消息内容
167
+ cleaned_messages = []
168
+ for message in messages:
169
+ content = message.get("content", "")
170
+ if isinstance(content, list):
171
+ text_parts = []
172
+ images = []
173
+ for item in content:
174
+ if "text" in item:
175
+ text_parts.append(item.get("text", ""))
176
+ elif "image_url" in item:
177
+ has_image = True
178
+ image_info = item.get("image_url", {})
179
+ url = image_info.get("url", "")
180
+ if is_base64_image(url):
181
+ # 解码并保存图片
182
+ base64_str = url.split(",")[1]
183
+ filename = save_base64_image(base64_str)
184
+ base_url = app.state.base_url
185
+ image_url = f"{base_url}/images/{filename}"
186
+ images.append({"data": image_url})
187
+ else:
188
+ images.append({"data": url})
189
+ extracted_content = " ".join(text_parts).strip()
190
+ if extracted_content:
191
+ has_text = True
192
+ message["content"] = extracted_content
193
+ if images:
194
+ message["images"] = images
195
+ cleaned_messages.append(message)
196
+ print("Extracted:", extracted_content)
197
+ else:
198
+ if images:
199
+ has_image = True
200
+ message["content"] = ""
201
+ message["images"] = images
202
+ cleaned_messages.append(message)
203
+ print("Extracted image only.")
204
+ else:
205
+ print("Deleted message with empty content.")
206
+ elif isinstance(content, str):
207
+ content_str = content.strip()
208
+ if content_str:
209
+ has_text = True
210
+ message["content"] = content_str
211
+ cleaned_messages.append(message)
212
+ print("Retained content:", content_str)
213
+ else:
214
+ print("Deleted message with empty content.")
215
+ else:
216
+ print("Deleted non-expected type of content message.")
217
+
218
+ if not cleaned_messages:
219
+ raise HTTPException(status_code=400, detail="所有消息的内容均为空。")
220
+
221
+ # 验证模型
222
+ if model not in MODELS:
223
+ model = "gpt-4o"
224
+
225
+ # 构建新的请求JSON
226
+ new_request_json = {
227
+ "function_image_gen": False,
228
+ "function_web_search": True,
229
+ "max_tokens": max_tokens,
230
+ "model": model,
231
+ "source": "chat/free",
232
+ "temperature": temperature,
233
+ "top_p": top_p,
234
+ "messages": cleaned_messages,
235
+ }
236
+
237
+ modified_request_body = json.dumps(new_request_json, ensure_ascii=False)
238
+ print("Modified Request JSON:", modified_request_body)
239
+
240
+ # 获取Bearer Token
241
+ tmp_token = BearerTokenGenerator.get_bearer(modified_request_body)
242
+ if not tmp_token:
243
+ raise HTTPException(status_code=500, detail="无法生成 Bearer Token")
244
+
245
+ bearer_token, formatted_date = tmp_token
246
+
247
+ headers = {
248
+ "Date": formatted_date,
249
+ "Client-time-zone": "-05:00",
250
+ "Authorization": bearer_token,
251
+ "User-Agent": "ChatOn_Android/1.53.502",
252
+ "Accept-Language": "en-US",
253
+ "X-Cl-Options": "hb",
254
+ "Content-Type": "application/json; charset=UTF-8",
255
+ }
256
+
257
+ if is_stream:
258
+ # 流式响应处理
259
+ async def event_generator():
260
+ async with httpx.AsyncClient(timeout=None) as client_stream:
261
+ try:
262
+ async with client_stream.stream("POST", EXTERNAL_API_URL, headers=headers, content=modified_request_body) as streamed_response:
263
+ async for line in streamed_response.aiter_lines():
264
+ if line.startswith("data: "):
265
+ data = line[6:].strip()
266
+ if data == "[DONE]":
267
+ # 通知客户端流结束
268
+ yield "data: [DONE]\n\n"
269
+ break
270
+ try:
271
+ sse_json = json.loads(data)
272
+ if "choices" in sse_json:
273
+ for choice in sse_json["choices"]:
274
+ delta = choice.get("delta", {})
275
+ content = delta.get("content")
276
+ if content:
277
+ new_sse_json = {
278
+ "choices": [
279
+ {
280
+ "index": choice.get("index", 0),
281
+ "delta": {"content": content},
282
+ }
283
+ ],
284
+ "created": sse_json.get(
285
+ "created", int(datetime.now(timezone.utc).timestamp())
286
+ ),
287
+ "id": sse_json.get(
288
+ "id", str(uuid.uuid4())
289
+ ),
290
+ "model": sse_json.get("model", "gpt-4o"),
291
+ "system_fingerprint": f"fp_{uuid.uuid4().hex[:12]}",
292
+ }
293
+ new_sse_line = f"data: {json.dumps(new_sse_json, ensure_ascii=False)}\n\n"
294
+ yield new_sse_line
295
+ except json.JSONDecodeError:
296
+ print("JSON解析错误")
297
+ continue
298
+ except httpx.RequestError as exc:
299
+ print(f"外部API请求失败: {exc}")
300
+ yield f"data: {{\"error\": \"外部API请求失败: {str(exc)}\"}}\n\n"
301
+
302
+ return StreamingResponse(
303
+ event_generator(),
304
+ media_type="text/event-stream",
305
+ headers={
306
+ "Cache-Control": "no-cache",
307
+ "Connection": "keep-alive",
308
+ # CORS头已通过中间件处理,无需在这里重复添加
309
+ },
310
+ )
311
+ else:
312
+ # 非流式响应处理
313
+ async with httpx.AsyncClient(timeout=None) as client:
314
+ try:
315
+ response = await client.post(
316
+ EXTERNAL_API_URL,
317
+ headers=headers,
318
+ content=modified_request_body,
319
+ timeout=None
320
+ )
321
+
322
+ if response.status_code != 200:
323
+ raise HTTPException(
324
+ status_code=response.status_code,
325
+ detail=f"API 错误: {response.status_code}",
326
+ )
327
+
328
+ sse_lines = response.text.splitlines()
329
+ content_builder = ""
330
+ images_urls = []
331
+
332
+ for line in sse_lines:
333
+ if line.startswith("data: "):
334
+ data = line[6:].strip()
335
+ if data == "[DONE]":
336
+ break
337
+ try:
338
+ sse_json = json.loads(data)
339
+ if "choices" in sse_json:
340
+ for choice in sse_json["choices"]:
341
+ if "delta" in choice:
342
+ delta = choice["delta"]
343
+ if "content" in delta:
344
+ content_builder += delta["content"]
345
+ except json.JSONDecodeError:
346
+ print("JSON解析错误")
347
+ continue
348
+
349
+ openai_response = {
350
+ "id": f"chatcmpl-{uuid.uuid4()}",
351
+ "object": "chat.completion",
352
+ "created": int(datetime.now(timezone.utc).timestamp()),
353
+ "model": model,
354
+ "choices": [
355
+ {
356
+ "index": 0,
357
+ "message": {
358
+ "role": "assistant",
359
+ "content": content_builder,
360
+ },
361
+ "finish_reason": "stop",
362
+ }
363
+ ],
364
+ }
365
+
366
+ # 处理图片(如果有)
367
+ if has_image:
368
+ images = []
369
+ for message in cleaned_messages:
370
+ if "images" in message:
371
+ for img in message["images"]:
372
+ images.append({"data": img["data"]})
373
+ openai_response["choices"][0]["message"]["images"] = images
374
+
375
+ return JSONResponse(content=openai_response, status_code=200)
376
+ except httpx.RequestError as exc:
377
+ raise HTTPException(status_code=500, detail=f"请求失败: {str(exc)}")
378
+ except Exception as exc:
379
+ raise HTTPException(status_code=500, detail=f"内部服务器错误: {str(exc)}")
380
+
381
+ # 图像生成处理
382
+ @app.post("/v1/images/generations")
383
+ async def images_generations(request: Request):
384
+ """
385
+ 处理图像生成请求
386
+ """
387
+ try:
388
+ request_body = await request.json()
389
+ except json.JSONDecodeError:
390
+ return send_error_response("Invalid JSON", status_code=400)
391
+
392
+ print("Received Image Generations JSON:", json.dumps(request_body, ensure_ascii=False))
393
+
394
+ # 验证必需的字段
395
+ if "prompt" not in request_body:
396
+ return send_error_response("缺少必需的字段: prompt", status_code=400)
397
+
398
+ user_prompt = request_body.get("prompt", "").strip()
399
+ response_format = request_body.get("response_format", "b64_json").strip()
400
+
401
+ if not user_prompt:
402
+ return send_error_response("Prompt 不能为空。", status_code=400)
403
+
404
+ print(f"Prompt: {user_prompt}")
405
+
406
+ # 构建新的 TextToImage JSON 请求体
407
+ text_to_image_json = {
408
+ "function_image_gen": True,
409
+ "function_web_search": True,
410
+ "image_aspect_ratio": "1:1",
411
+ "image_style": "photographic", # 暂时固定 image_style
412
+ "max_tokens": 8000,
413
+ "messages": [
414
+ {
415
+ "content": "You are a helpful artist, please based on imagination draw a picture.",
416
+ "role": "system"
417
+ },
418
+ {
419
+ "content": "Draw: " + user_prompt,
420
+ "role": "user"
421
+ }
422
+ ],
423
+ "model": "gpt-4o", # 固定 model,只能gpt-4o或gpt-4o-mini
424
+ "source": "chat/pro_image" # 固定 source
425
+ }
426
+
427
+ modified_request_body = json.dumps(text_to_image_json, ensure_ascii=False)
428
+ print("Modified Request JSON:", modified_request_body)
429
+
430
+ # 获取Bearer Token
431
+ tmp_token = BearerTokenGenerator.get_bearer(modified_request_body, path="/chats/stream")
432
+ if not tmp_token:
433
+ return send_error_response("无法生成 Bearer Token", status_code=500)
434
+
435
+ bearer_token, formatted_date = tmp_token
436
+
437
+ headers = {
438
+ "Date": formatted_date,
439
+ "Client-time-zone": "-05:00",
440
+ "Authorization": bearer_token,
441
+ "User-Agent": "ChatOn_Android/1.53.502",
442
+ "Accept-Language": "en-US",
443
+ "X-Cl-Options": "hb",
444
+ "Content-Type": "application/json; charset=UTF-8",
445
+ }
446
+
447
+ async with httpx.AsyncClient(timeout=None) as client:
448
+ try:
449
+ response = await client.post(
450
+ EXTERNAL_API_URL, headers=headers, content=modified_request_body, timeout=None
451
+ )
452
+ if response.status_code != 200:
453
+ return send_error_response(f"API 错误: {response.status_code}", status_code=500)
454
+
455
+ # 初始化用于拼接 URL 的字符串
456
+ url_builder = ""
457
+
458
+ # 读取 SSE 流并拼接 URL
459
+ async for line in response.aiter_lines():
460
+ if line.startswith("data: "):
461
+ data = line[6:].strip()
462
+ if data == "[DONE]":
463
+ break
464
+ try:
465
+ sse_json = json.loads(data)
466
+ if "choices" in sse_json:
467
+ for choice in sse_json["choices"]:
468
+ delta = choice.get("delta", {})
469
+ content = delta.get("content")
470
+ if content:
471
+ url_builder += content
472
+ except json.JSONDecodeError:
473
+ print("JSON解析错误")
474
+ continue
475
+
476
+ image_markdown = url_builder
477
+ # Step 1: 检查Markdown文本是否为空
478
+ if not image_markdown:
479
+ print("无法从 SSE 流中构建图像 Markdown。")
480
+ return send_error_response("无法从 SSE 流中构建图像 Markdown。", status_code=500)
481
+
482
+ # Step 2, 3, 4, 5: 处理图像
483
+ extracted_path = extract_path_from_markdown(image_markdown)
484
+ if not extracted_path:
485
+ print("无法从 Markdown 中提取路径。")
486
+ return send_error_response("无法从 Markdown 中提取路径。", status_code=500)
487
+
488
+ print(f"提取的路径: {extracted_path}")
489
+
490
+ # Step 5: 拼接最终的存储URL
491
+ storage_url = f"https://api.chaton.ai/storage/{extracted_path}"
492
+ print(f"存储URL: {storage_url}")
493
+
494
+ # 获取最终下载URL
495
+ final_download_url = await fetch_get_url_from_storage(storage_url)
496
+ if not final_download_url:
497
+ return send_error_response("无法从 storage URL 获取最终下载链接。", status_code=500)
498
+
499
+ print(f"Final Download URL: {final_download_url}")
500
+
501
+ # 下载图像
502
+ image_bytes = await download_image(final_download_url)
503
+ if not image_bytes:
504
+ return send_error_response("无法从 URL 下载图像。", status_code=500)
505
+
506
+ # 转换为 Base64
507
+ image_base64 = base64.b64encode(image_bytes).decode('utf-8')
508
+
509
+ # 将图片保存到images目录并构建可访问的URL
510
+ filename = save_base64_image(image_base64)
511
+ base_url = app.state.base_url
512
+ accessible_url = f"{base_url}/images/{filename}"
513
+
514
+ # 根据 response_format 返回相应的响应
515
+ if response_format.lower() == "b64_json":
516
+ response_json = {
517
+ "data": [
518
+ {
519
+ "b64_json": image_base64
520
+ }
521
+ ]
522
+ }
523
+ return JSONResponse(content=response_json, status_code=200)
524
+ else:
525
+ # 构建包含可访问URL的响应
526
+ response_json = {
527
+ "data": [
528
+ {
529
+ "url": accessible_url
530
+ }
531
+ ]
532
+ }
533
+ return JSONResponse(content=response_json, status_code=200)
534
+ except httpx.RequestError as exc:
535
+ print(f"请求失败: {exc}")
536
+ return send_error_response(f"请求失败: {str(exc)}", status_code=500)
537
+ except Exception as exc:
538
+ print(f"内部服务器错误: {exc}")
539
+ return send_error_response(f"内部服务器错误: {str(exc)}", status_code=500)
540
+
541
+ # 运行服务器
542
+ def main():
543
+ parser = argparse.ArgumentParser(description="启动ChatOn API服务器")
544
+ parser.add_argument('--base_url', type=str, default='http://localhost', help='Base URL for accessing images')
545
+ parser.add_argument('--port', type=int, default=INITIAL_PORT, help='服务器监听端口')
546
+ args = parser.parse_args()
547
+
548
+ base_url = args.base_url
549
+ port = args.port
550
+
551
+ # 确保 images 目录存在
552
+ if not os.path.exists("images"):
553
+ os.makedirs("images")
554
+
555
+ # 设置 FastAPI 应用的 state
556
+ app.state.base_url = base_url
557
+
558
+ print(f"Server started on port {port} with base_url: {base_url}")
559
+
560
+ # 运行FastAPI应用
561
+ uvicorn.run(app, host="0.0.0.0", port=port)
562
+
563
+ async def get_available_port(start_port: int = INITIAL_PORT, end_port: int = 65535) -> int:
564
+ """查找可用的端口号"""
565
+ for port in range(start_port, end_port + 1):
566
+ try:
567
+ server = await asyncio.start_server(lambda r, w: None, host="0.0.0.0", port=port)
568
+ server.close()
569
+ await server.wait_closed()
570
+ return port
571
+ except OSError:
572
+ continue
573
+ raise RuntimeError(f"No available ports between {start_port} and {end_port}")
574
+
575
+ if __name__ == "__main__":
576
+ main()