Update app.py
Browse files
app.py
CHANGED
|
@@ -39,6 +39,10 @@ REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_TOKEN")
|
|
| 39 |
if not REPLICATE_API_TOKEN:
|
| 40 |
logger.error("REPLICATE_API_TOKEN not found in environment variables")
|
| 41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
# Replicate API配置
|
| 43 |
REPLICATE_BASE_URL = "https://api.replicate.com/v1"
|
| 44 |
DEFAULT_MODEL = "anthropic/claude-3.5-sonnet"
|
|
@@ -178,10 +182,10 @@ def decode_base64_file(data_url: str) -> tuple[str, str, str]:
|
|
| 178 |
logger.error(f"Failed to parse data URL: {e}")
|
| 179 |
return None, None, None
|
| 180 |
|
| 181 |
-
async def
|
| 182 |
"""
|
| 183 |
-
将 base64
|
| 184 |
-
|
| 185 |
"""
|
| 186 |
try:
|
| 187 |
# 从 base64 data URL 中提取纯 base64 数据
|
|
@@ -190,63 +194,91 @@ async def upload_image_to_temp_service(session: aiohttp.ClientSession, base64_da
|
|
| 190 |
else:
|
| 191 |
base64_content = base64_data
|
| 192 |
|
| 193 |
-
#
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
# 1. 使用临时文件服务(如 imgbb, imgur 等)
|
| 200 |
-
# 2. 使用自己的文件服务器
|
| 201 |
-
# 3. 修改为使用 claude-3.5-sonnet 作为替代
|
| 202 |
-
|
| 203 |
-
return None # 返回 None 表示上传失败
|
| 204 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
except Exception as e:
|
| 206 |
-
logger.error(f"Failed to upload image: {e}")
|
| 207 |
return None
|
| 208 |
|
| 209 |
-
def format_image_for_model(base64_data: str, model_config: Dict[str, Any]) -> str:
|
| 210 |
"""
|
| 211 |
根据模型配置格式化图片数据
|
| 212 |
"""
|
| 213 |
image_format = model_config.get("image_format", "data_url")
|
| 214 |
|
| 215 |
-
if image_format == "
|
| 216 |
-
#
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
#
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
if decoded_bytes.startswith(b'\xff\xd8\xff'):
|
| 226 |
-
# JPEG
|
| 227 |
-
return f"data:image/jpeg;base64,{base64_data}"
|
| 228 |
-
elif decoded_bytes.startswith(b'\x89PNG\r\n\x1a\n'):
|
| 229 |
-
# PNG
|
| 230 |
-
return f"data:image/png;base64,{base64_data}"
|
| 231 |
-
elif decoded_bytes.startswith(b'GIF87a') or decoded_bytes.startswith(b'GIF89a'):
|
| 232 |
-
# GIF
|
| 233 |
-
return f"data:image/gif;base64,{base64_data}"
|
| 234 |
-
elif decoded_bytes.startswith(b'RIFF') and b'WEBP' in decoded_bytes[:20]:
|
| 235 |
-
# WebP
|
| 236 |
-
return f"data:image/webp;base64,{base64_data}"
|
| 237 |
-
else:
|
| 238 |
-
# 默认使用 JPEG
|
| 239 |
-
return f"data:image/jpeg;base64,{base64_data}"
|
| 240 |
-
except Exception as e:
|
| 241 |
-
logger.warning(f"Failed to detect image format: {e}, using JPEG as default")
|
| 242 |
-
return f"data:image/jpeg;base64,{base64_data}"
|
| 243 |
|
| 244 |
-
elif image_format == "
|
| 245 |
-
|
| 246 |
-
return None
|
| 247 |
|
| 248 |
return base64_data
|
| 249 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 250 |
def extract_content_from_message(message: Dict[str, Any]) -> tuple[str, List[str], List[Dict[str, str]]]:
|
| 251 |
"""
|
| 252 |
从消息中提取文本内容、图片和文件
|
|
@@ -347,7 +379,7 @@ def format_files_for_prompt(files: List[Dict[str, str]]) -> str:
|
|
| 347 |
|
| 348 |
return "\n".join(file_sections)
|
| 349 |
|
| 350 |
-
async def transform_openai_to_replicate(openai_request: Dict[str, Any], model_override: str = None) -> Dict[str, Any]:
|
| 351 |
"""将OpenAI格式的请求转换为Replicate格式"""
|
| 352 |
try:
|
| 353 |
messages = openai_request.get("messages", [])
|
|
@@ -416,14 +448,12 @@ async def transform_openai_to_replicate(openai_request: Dict[str, Any], model_ov
|
|
| 416 |
# 处理图片格式
|
| 417 |
formatted_image = None
|
| 418 |
if has_images and primary_image:
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
else:
|
| 426 |
-
formatted_image = format_image_for_model(primary_image, model_config)
|
| 427 |
|
| 428 |
# 构建 Replicate 格式的输入
|
| 429 |
replicate_input = {}
|
|
@@ -459,7 +489,10 @@ async def transform_openai_to_replicate(openai_request: Dict[str, Any], model_ov
|
|
| 459 |
# 处理图片
|
| 460 |
if formatted_image:
|
| 461 |
replicate_input["image"] = formatted_image
|
| 462 |
-
|
|
|
|
|
|
|
|
|
|
| 463 |
|
| 464 |
# 只在有 system_prompt 时才添加
|
| 465 |
if system_prompt:
|
|
@@ -530,7 +563,10 @@ async def create_replicate_prediction(session: aiohttp.ClientSession, model: str
|
|
| 530 |
if "input" in log_data:
|
| 531 |
if "image" in log_data["input"]:
|
| 532 |
image_data = log_data["input"]["image"]
|
| 533 |
-
|
|
|
|
|
|
|
|
|
|
| 534 |
if "prompt" in log_data["input"] and len(log_data["input"]["prompt"]) > 1000:
|
| 535 |
log_data["input"]["prompt"] = log_data["input"]["prompt"][:1000] + "...[TRUNCATED]"
|
| 536 |
logger.info(f"Request data: {json.dumps(log_data, indent=2)}")
|
|
@@ -615,13 +651,14 @@ async def root():
|
|
| 615 |
"message": "Replicate API Proxy for LobeChat with Vision and File Support",
|
| 616 |
"status": "running",
|
| 617 |
"replicate_token_configured": bool(REPLICATE_API_TOKEN),
|
| 618 |
-
"
|
|
|
|
| 619 |
"supported_models": list(MODEL_CONFIGS.keys()),
|
| 620 |
"vision_support": True,
|
| 621 |
"file_support": True,
|
| 622 |
"supported_text_files": list(SUPPORTED_TEXT_EXTENSIONS),
|
| 623 |
"supported_image_files": list(SUPPORTED_IMAGE_EXTENSIONS),
|
| 624 |
-
"
|
| 625 |
}
|
| 626 |
|
| 627 |
@app.get("/health")
|
|
@@ -630,6 +667,7 @@ async def health():
|
|
| 630 |
return {
|
| 631 |
"status": "healthy",
|
| 632 |
"replicate_token": "configured" if REPLICATE_API_TOKEN else "missing",
|
|
|
|
| 633 |
"timestamp": asyncio.get_event_loop().time(),
|
| 634 |
"model_configs": MODEL_CONFIGS,
|
| 635 |
"supported_file_types": {
|
|
@@ -665,13 +703,13 @@ async def chat_completions(request: Request):
|
|
| 665 |
logger.info(f"Client parameters: max_tokens={body.get('max_tokens', 'not set')}, temperature={body.get('temperature', 'not set')}")
|
| 666 |
logger.info(f"Message count: {len(body.get('messages', []))}")
|
| 667 |
|
| 668 |
-
|
| 669 |
-
|
| 670 |
-
|
| 671 |
-
|
| 672 |
-
|
| 673 |
-
|
| 674 |
-
async
|
| 675 |
try:
|
| 676 |
# 创建预测
|
| 677 |
prediction = await create_replicate_prediction(session, model, replicate_data)
|
|
@@ -755,21 +793,20 @@ async def chat_completions(request: Request):
|
|
| 755 |
}
|
| 756 |
}
|
| 757 |
yield f"data: {json.dumps(error_response)}\n\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 758 |
|
| 759 |
-
|
| 760 |
-
|
| 761 |
-
media_type="text/event-stream",
|
| 762 |
-
headers={
|
| 763 |
-
"Cache-Control": "no-cache",
|
| 764 |
-
"Connection": "keep-alive",
|
| 765 |
-
"Access-Control-Allow-Origin": "*",
|
| 766 |
-
"X-Accel-Buffering": "no",
|
| 767 |
-
}
|
| 768 |
-
)
|
| 769 |
-
|
| 770 |
-
else:
|
| 771 |
-
# 非流式响应
|
| 772 |
-
async with aiohttp.ClientSession() as session:
|
| 773 |
# 创建预测
|
| 774 |
prediction = await create_replicate_prediction(session, model, replicate_data)
|
| 775 |
prediction_id = prediction.get('id')
|
|
|
|
| 39 |
if not REPLICATE_API_TOKEN:
|
| 40 |
logger.error("REPLICATE_API_TOKEN not found in environment variables")
|
| 41 |
|
| 42 |
+
# imgbb API 配置
|
| 43 |
+
IMGBB_API_KEY = "78f0c4360135e80c46b24b44e1e20a20"
|
| 44 |
+
IMGBB_API_URL = "https://api.imgbb.com/1/upload"
|
| 45 |
+
|
| 46 |
# Replicate API配置
|
| 47 |
REPLICATE_BASE_URL = "https://api.replicate.com/v1"
|
| 48 |
DEFAULT_MODEL = "anthropic/claude-3.5-sonnet"
|
|
|
|
| 182 |
logger.error(f"Failed to parse data URL: {e}")
|
| 183 |
return None, None, None
|
| 184 |
|
| 185 |
+
async def upload_image_to_imgbb(session: aiohttp.ClientSession, base64_data: str) -> str:
|
| 186 |
"""
|
| 187 |
+
将 base64 图片上传到 imgbb
|
| 188 |
+
返回图片的 URL
|
| 189 |
"""
|
| 190 |
try:
|
| 191 |
# 从 base64 data URL 中提取纯 base64 数据
|
|
|
|
| 194 |
else:
|
| 195 |
base64_content = base64_data
|
| 196 |
|
| 197 |
+
# 准备上传数据
|
| 198 |
+
data = {
|
| 199 |
+
'key': IMGBB_API_KEY,
|
| 200 |
+
'image': base64_content,
|
| 201 |
+
'expiration': 300 # 5分钟过期,避免永久占用存储
|
| 202 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
|
| 204 |
+
logger.info(f"Uploading image to imgbb, size: {len(base64_content)} chars")
|
| 205 |
+
|
| 206 |
+
# 上传到 imgbb
|
| 207 |
+
async with session.post(IMGBB_API_URL, data=data, timeout=30) as response:
|
| 208 |
+
if response.status == 200:
|
| 209 |
+
result = await response.json()
|
| 210 |
+
if result.get('success'):
|
| 211 |
+
image_url = result['data']['url']
|
| 212 |
+
logger.info(f"Image uploaded successfully: {image_url}")
|
| 213 |
+
return image_url
|
| 214 |
+
else:
|
| 215 |
+
logger.error(f"imgbb upload failed: {result}")
|
| 216 |
+
return None
|
| 217 |
+
else:
|
| 218 |
+
error_text = await response.text()
|
| 219 |
+
logger.error(f"imgbb upload error: {response.status} - {error_text}")
|
| 220 |
+
return None
|
| 221 |
+
|
| 222 |
+
except asyncio.TimeoutError:
|
| 223 |
+
logger.error("Timeout uploading image to imgbb")
|
| 224 |
+
return None
|
| 225 |
except Exception as e:
|
| 226 |
+
logger.error(f"Failed to upload image to imgbb: {e}")
|
| 227 |
return None
|
| 228 |
|
| 229 |
+
async def format_image_for_model(session: aiohttp.ClientSession, base64_data: str, model_config: Dict[str, Any]) -> str:
|
| 230 |
"""
|
| 231 |
根据模型配置格式化图片数据
|
| 232 |
"""
|
| 233 |
image_format = model_config.get("image_format", "data_url")
|
| 234 |
|
| 235 |
+
if image_format == "url":
|
| 236 |
+
# 需要上传图片到 imgbb 并返回 URL
|
| 237 |
+
image_url = await upload_image_to_imgbb(session, base64_data)
|
| 238 |
+
if image_url:
|
| 239 |
+
return image_url
|
| 240 |
+
else:
|
| 241 |
+
logger.error("Failed to upload image, falling back to data URL")
|
| 242 |
+
# 上传失败时降级到 data URL 格式
|
| 243 |
+
return format_image_as_data_url(base64_data)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
|
| 245 |
+
elif image_format == "data_url":
|
| 246 |
+
return format_image_as_data_url(base64_data)
|
|
|
|
| 247 |
|
| 248 |
return base64_data
|
| 249 |
|
| 250 |
+
def format_image_as_data_url(base64_data: str) -> str:
|
| 251 |
+
"""
|
| 252 |
+
将 base64 数据格式化为 data URL
|
| 253 |
+
"""
|
| 254 |
+
# 检查 base64 数据是否已经包含 data URL 前缀
|
| 255 |
+
if base64_data.startswith("data:"):
|
| 256 |
+
return base64_data
|
| 257 |
+
|
| 258 |
+
# 如果没有前缀,添加默认的 JPEG data URL 前缀
|
| 259 |
+
try:
|
| 260 |
+
# 解码 base64 数据的前几个字节来检测格式
|
| 261 |
+
decoded_bytes = base64.b64decode(base64_data[:100])
|
| 262 |
+
|
| 263 |
+
if decoded_bytes.startswith(b'\xff\xd8\xff'):
|
| 264 |
+
# JPEG
|
| 265 |
+
return f"data:image/jpeg;base64,{base64_data}"
|
| 266 |
+
elif decoded_bytes.startswith(b'\x89PNG\r\n\x1a\n'):
|
| 267 |
+
# PNG
|
| 268 |
+
return f"data:image/png;base64,{base64_data}"
|
| 269 |
+
elif decoded_bytes.startswith(b'GIF87a') or decoded_bytes.startswith(b'GIF89a'):
|
| 270 |
+
# GIF
|
| 271 |
+
return f"data:image/gif;base64,{base64_data}"
|
| 272 |
+
elif decoded_bytes.startswith(b'RIFF') and b'WEBP' in decoded_bytes[:20]:
|
| 273 |
+
# WebP
|
| 274 |
+
return f"data:image/webp;base64,{base64_data}"
|
| 275 |
+
else:
|
| 276 |
+
# 默认使用 JPEG
|
| 277 |
+
return f"data:image/jpeg;base64,{base64_data}"
|
| 278 |
+
except Exception as e:
|
| 279 |
+
logger.warning(f"Failed to detect image format: {e}, using JPEG as default")
|
| 280 |
+
return f"data:image/jpeg;base64,{base64_data}"
|
| 281 |
+
|
| 282 |
def extract_content_from_message(message: Dict[str, Any]) -> tuple[str, List[str], List[Dict[str, str]]]:
|
| 283 |
"""
|
| 284 |
从消息中提取文本内容、图片和文件
|
|
|
|
| 379 |
|
| 380 |
return "\n".join(file_sections)
|
| 381 |
|
| 382 |
+
async def transform_openai_to_replicate(session: aiohttp.ClientSession, openai_request: Dict[str, Any], model_override: str = None) -> Dict[str, Any]:
|
| 383 |
"""将OpenAI格式的请求转换为Replicate格式"""
|
| 384 |
try:
|
| 385 |
messages = openai_request.get("messages", [])
|
|
|
|
| 448 |
# 处理图片格式
|
| 449 |
formatted_image = None
|
| 450 |
if has_images and primary_image:
|
| 451 |
+
logger.info(f"Processing image for model {model} with format {model_config.get('image_format')}")
|
| 452 |
+
formatted_image = await format_image_for_model(session, primary_image, model_config)
|
| 453 |
+
|
| 454 |
+
if not formatted_image:
|
| 455 |
+
logger.error("Failed to format image for model")
|
| 456 |
+
raise HTTPException(status_code=500, detail="Failed to process image")
|
|
|
|
|
|
|
| 457 |
|
| 458 |
# 构建 Replicate 格式的输入
|
| 459 |
replicate_input = {}
|
|
|
|
| 489 |
# 处理图片
|
| 490 |
if formatted_image:
|
| 491 |
replicate_input["image"] = formatted_image
|
| 492 |
+
if formatted_image.startswith("http"):
|
| 493 |
+
logger.info(f"Added image URL to request for model {model}: {formatted_image}")
|
| 494 |
+
else:
|
| 495 |
+
logger.info(f"Added image data to request for model {model}: {formatted_image[:100]}...")
|
| 496 |
|
| 497 |
# 只在有 system_prompt 时才添加
|
| 498 |
if system_prompt:
|
|
|
|
| 563 |
if "input" in log_data:
|
| 564 |
if "image" in log_data["input"]:
|
| 565 |
image_data = log_data["input"]["image"]
|
| 566 |
+
if image_data.startswith("http"):
|
| 567 |
+
log_data["input"]["image"] = f"[IMAGE_URL: {image_data}]"
|
| 568 |
+
else:
|
| 569 |
+
log_data["input"]["image"] = f"[IMAGE_DATA_{len(image_data)}]"
|
| 570 |
if "prompt" in log_data["input"] and len(log_data["input"]["prompt"]) > 1000:
|
| 571 |
log_data["input"]["prompt"] = log_data["input"]["prompt"][:1000] + "...[TRUNCATED]"
|
| 572 |
logger.info(f"Request data: {json.dumps(log_data, indent=2)}")
|
|
|
|
| 651 |
"message": "Replicate API Proxy for LobeChat with Vision and File Support",
|
| 652 |
"status": "running",
|
| 653 |
"replicate_token_configured": bool(REPLICATE_API_TOKEN),
|
| 654 |
+
"imgbb_token_configured": bool(IMGBB_API_KEY),
|
| 655 |
+
"version": "1.2.0",
|
| 656 |
"supported_models": list(MODEL_CONFIGS.keys()),
|
| 657 |
"vision_support": True,
|
| 658 |
"file_support": True,
|
| 659 |
"supported_text_files": list(SUPPORTED_TEXT_EXTENSIONS),
|
| 660 |
"supported_image_files": list(SUPPORTED_IMAGE_EXTENSIONS),
|
| 661 |
+
"claude4_vision_support": "Full support via imgbb image hosting"
|
| 662 |
}
|
| 663 |
|
| 664 |
@app.get("/health")
|
|
|
|
| 667 |
return {
|
| 668 |
"status": "healthy",
|
| 669 |
"replicate_token": "configured" if REPLICATE_API_TOKEN else "missing",
|
| 670 |
+
"imgbb_token": "configured" if IMGBB_API_KEY else "missing",
|
| 671 |
"timestamp": asyncio.get_event_loop().time(),
|
| 672 |
"model_configs": MODEL_CONFIGS,
|
| 673 |
"supported_file_types": {
|
|
|
|
| 703 |
logger.info(f"Client parameters: max_tokens={body.get('max_tokens', 'not set')}, temperature={body.get('temperature', 'not set')}")
|
| 704 |
logger.info(f"Message count: {len(body.get('messages', []))}")
|
| 705 |
|
| 706 |
+
async with aiohttp.ClientSession() as session:
|
| 707 |
+
# 转换请求格式
|
| 708 |
+
replicate_data, model = await transform_openai_to_replicate(session, body)
|
| 709 |
+
|
| 710 |
+
if body.get("stream", False):
|
| 711 |
+
# 流式响应
|
| 712 |
+
async def generate_stream():
|
| 713 |
try:
|
| 714 |
# 创建预测
|
| 715 |
prediction = await create_replicate_prediction(session, model, replicate_data)
|
|
|
|
| 793 |
}
|
| 794 |
}
|
| 795 |
yield f"data: {json.dumps(error_response)}\n\n"
|
| 796 |
+
|
| 797 |
+
return StreamingResponse(
|
| 798 |
+
generate_stream(),
|
| 799 |
+
media_type="text/event-stream",
|
| 800 |
+
headers={
|
| 801 |
+
"Cache-Control": "no-cache",
|
| 802 |
+
"Connection": "keep-alive",
|
| 803 |
+
"Access-Control-Allow-Origin": "*",
|
| 804 |
+
"X-Accel-Buffering": "no",
|
| 805 |
+
}
|
| 806 |
+
)
|
| 807 |
|
| 808 |
+
else:
|
| 809 |
+
# 非流式响应
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 810 |
# 创建预测
|
| 811 |
prediction = await create_replicate_prediction(session, model, replicate_data)
|
| 812 |
prediction_id = prediction.get('id')
|