nomid2 committed on
Commit
d609f98
·
verified ·
1 Parent(s): a52668f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +118 -81
app.py CHANGED
@@ -39,6 +39,10 @@ REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_TOKEN")
39
  if not REPLICATE_API_TOKEN:
40
  logger.error("REPLICATE_API_TOKEN not found in environment variables")
41
 
 
 
 
 
42
  # Replicate API配置
43
  REPLICATE_BASE_URL = "https://api.replicate.com/v1"
44
  DEFAULT_MODEL = "anthropic/claude-3.5-sonnet"
@@ -178,10 +182,10 @@ def decode_base64_file(data_url: str) -> tuple[str, str, str]:
178
  logger.error(f"Failed to parse data URL: {e}")
179
  return None, None, None
180
 
181
- async def upload_image_to_temp_service(session: aiohttp.ClientSession, base64_data: str) -> str:
182
  """
183
- 将 base64 图片上传到临时图片托管服务
184
- 这里使用 imgbb 作为示例,你也可以使用其他服务
185
  """
186
  try:
187
  # 从 base64 data URL 中提取纯 base64 数据
@@ -190,63 +194,91 @@ async def upload_image_to_temp_service(session: aiohttp.ClientSession, base64_da
190
  else:
191
  base64_content = base64_data
192
 
193
- # 使用 imgbb API(需要免费注册获取 API key)
194
- # 这里暂时返回原始 data URL,你需要根据实际情况实现图片上传
195
- logger.warning("Image upload to external service not implemented, using workaround")
196
-
197
- # 临时解决方案:对于 Claude 4,我们需要找到另一种方式
198
- # 可以考虑:
199
- # 1. 使用临时文件服务(如 imgbb, imgur 等)
200
- # 2. 使用自己的文件服务器
201
- # 3. 修改为使用 claude-3.5-sonnet 作为替代
202
-
203
- return None # 返回 None 表示上传失败
204
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  except Exception as e:
206
- logger.error(f"Failed to upload image: {e}")
207
  return None
208
 
209
- def format_image_for_model(base64_data: str, model_config: Dict[str, Any]) -> str:
210
  """
211
  根据模型配置格式化图片数据
212
  """
213
  image_format = model_config.get("image_format", "data_url")
214
 
215
- if image_format == "data_url":
216
- # 检查 base64 数据是否已经包含 data URL 前缀
217
- if base64_data.startswith("data:"):
218
- return base64_data
219
-
220
- # 如果没有前缀,添加默认的 JPEG data URL 前缀
221
- try:
222
- # 解码 base64 数据的前几个字节来检测格式
223
- decoded_bytes = base64.b64decode(base64_data[:100])
224
-
225
- if decoded_bytes.startswith(b'\xff\xd8\xff'):
226
- # JPEG
227
- return f"data:image/jpeg;base64,{base64_data}"
228
- elif decoded_bytes.startswith(b'\x89PNG\r\n\x1a\n'):
229
- # PNG
230
- return f"data:image/png;base64,{base64_data}"
231
- elif decoded_bytes.startswith(b'GIF87a') or decoded_bytes.startswith(b'GIF89a'):
232
- # GIF
233
- return f"data:image/gif;base64,{base64_data}"
234
- elif decoded_bytes.startswith(b'RIFF') and b'WEBP' in decoded_bytes[:20]:
235
- # WebP
236
- return f"data:image/webp;base64,{base64_data}"
237
- else:
238
- # 默认使用 JPEG
239
- return f"data:image/jpeg;base64,{base64_data}"
240
- except Exception as e:
241
- logger.warning(f"Failed to detect image format: {e}, using JPEG as default")
242
- return f"data:image/jpeg;base64,{base64_data}"
243
 
244
- elif image_format == "url":
245
- # 对于需要 URL 的模型,返回 None 表示需要上传
246
- return None
247
 
248
  return base64_data
249
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
  def extract_content_from_message(message: Dict[str, Any]) -> tuple[str, List[str], List[Dict[str, str]]]:
251
  """
252
  从消息中提取文本内容、图片和文件
@@ -347,7 +379,7 @@ def format_files_for_prompt(files: List[Dict[str, str]]) -> str:
347
 
348
  return "\n".join(file_sections)
349
 
350
- async def transform_openai_to_replicate(openai_request: Dict[str, Any], model_override: str = None) -> Dict[str, Any]:
351
  """将OpenAI格式的请求转换为Replicate格式"""
352
  try:
353
  messages = openai_request.get("messages", [])
@@ -416,14 +448,12 @@ async def transform_openai_to_replicate(openai_request: Dict[str, Any], model_ov
416
  # 处理图片格式
417
  formatted_image = None
418
  if has_images and primary_image:
419
- if model_config.get("image_format") == "url":
420
- # Claude 4 需要 URL 格式,暂时降级到 Claude 3.5
421
- logger.warning(f"Model {model} requires URL format for images, falling back to claude-3.5-sonnet")
422
- model = "anthropic/claude-3.5-sonnet"
423
- model_config = MODEL_CONFIGS[model]
424
- formatted_image = format_image_for_model(primary_image, model_config)
425
- else:
426
- formatted_image = format_image_for_model(primary_image, model_config)
427
 
428
  # 构建 Replicate 格式的输入
429
  replicate_input = {}
@@ -459,7 +489,10 @@ async def transform_openai_to_replicate(openai_request: Dict[str, Any], model_ov
459
  # 处理图片
460
  if formatted_image:
461
  replicate_input["image"] = formatted_image
462
- logger.info(f"Added image to request for model {model}: {formatted_image[:100]}...")
 
 
 
463
 
464
  # 只在有 system_prompt 时才添加
465
  if system_prompt:
@@ -530,7 +563,10 @@ async def create_replicate_prediction(session: aiohttp.ClientSession, model: str
530
  if "input" in log_data:
531
  if "image" in log_data["input"]:
532
  image_data = log_data["input"]["image"]
533
- log_data["input"]["image"] = f"[IMAGE_DATA_{len(image_data)}]"
 
 
 
534
  if "prompt" in log_data["input"] and len(log_data["input"]["prompt"]) > 1000:
535
  log_data["input"]["prompt"] = log_data["input"]["prompt"][:1000] + "...[TRUNCATED]"
536
  logger.info(f"Request data: {json.dumps(log_data, indent=2)}")
@@ -615,13 +651,14 @@ async def root():
615
  "message": "Replicate API Proxy for LobeChat with Vision and File Support",
616
  "status": "running",
617
  "replicate_token_configured": bool(REPLICATE_API_TOKEN),
618
- "version": "1.1.2",
 
619
  "supported_models": list(MODEL_CONFIGS.keys()),
620
  "vision_support": True,
621
  "file_support": True,
622
  "supported_text_files": list(SUPPORTED_TEXT_EXTENSIONS),
623
  "supported_image_files": list(SUPPORTED_IMAGE_EXTENSIONS),
624
- "notes": "Claude 4 Sonnet image support temporarily falls back to Claude 3.5 Sonnet"
625
  }
626
 
627
  @app.get("/health")
@@ -630,6 +667,7 @@ async def health():
630
  return {
631
  "status": "healthy",
632
  "replicate_token": "configured" if REPLICATE_API_TOKEN else "missing",
 
633
  "timestamp": asyncio.get_event_loop().time(),
634
  "model_configs": MODEL_CONFIGS,
635
  "supported_file_types": {
@@ -665,13 +703,13 @@ async def chat_completions(request: Request):
665
  logger.info(f"Client parameters: max_tokens={body.get('max_tokens', 'not set')}, temperature={body.get('temperature', 'not set')}")
666
  logger.info(f"Message count: {len(body.get('messages', []))}")
667
 
668
- # 转换请求格式
669
- replicate_data, model = await transform_openai_to_replicate(body)
670
-
671
- if body.get("stream", False):
672
- # 流式响应
673
- async def generate_stream():
674
- async with aiohttp.ClientSession() as session:
675
  try:
676
  # 创建预测
677
  prediction = await create_replicate_prediction(session, model, replicate_data)
@@ -755,21 +793,20 @@ async def chat_completions(request: Request):
755
  }
756
  }
757
  yield f"data: {json.dumps(error_response)}\n\n"
 
 
 
 
 
 
 
 
 
 
 
758
 
759
- return StreamingResponse(
760
- generate_stream(),
761
- media_type="text/event-stream",
762
- headers={
763
- "Cache-Control": "no-cache",
764
- "Connection": "keep-alive",
765
- "Access-Control-Allow-Origin": "*",
766
- "X-Accel-Buffering": "no",
767
- }
768
- )
769
-
770
- else:
771
- # 非流式响应
772
- async with aiohttp.ClientSession() as session:
773
  # 创建预测
774
  prediction = await create_replicate_prediction(session, model, replicate_data)
775
  prediction_id = prediction.get('id')
 
39
  if not REPLICATE_API_TOKEN:
40
  logger.error("REPLICATE_API_TOKEN not found in environment variables")
41
 
42
# imgbb API configuration.
# SECURITY: the upload key is read from the environment, never hard-coded.
# The original commit embedded a live imgbb API key in source control; that
# key must be considered compromised and rotated on imgbb.
IMGBB_API_KEY = os.getenv("IMGBB_API_KEY", "")
IMGBB_API_URL = "https://api.imgbb.com/1/upload"
45
+
46
  # Replicate API配置
47
  REPLICATE_BASE_URL = "https://api.replicate.com/v1"
48
  DEFAULT_MODEL = "anthropic/claude-3.5-sonnet"
 
182
  logger.error(f"Failed to parse data URL: {e}")
183
  return None, None, None
184
 
185
+ async def upload_image_to_imgbb(session: aiohttp.ClientSession, base64_data: str) -> str:
186
  """
187
+ 将 base64 图片上传到 imgbb
188
+ 返回图片的 URL
189
  """
190
  try:
191
  # 从 base64 data URL 中提取纯 base64 数据
 
194
  else:
195
  base64_content = base64_data
196
 
197
+ # 准备上传数据
198
+ data = {
199
+ 'key': IMGBB_API_KEY,
200
+ 'image': base64_content,
201
+ 'expiration': 300 # 5分钟过期,避免永久占用存储
202
+ }
 
 
 
 
 
203
 
204
+ logger.info(f"Uploading image to imgbb, size: {len(base64_content)} chars")
205
+
206
+ # 上传到 imgbb
207
+ async with session.post(IMGBB_API_URL, data=data, timeout=30) as response:
208
+ if response.status == 200:
209
+ result = await response.json()
210
+ if result.get('success'):
211
+ image_url = result['data']['url']
212
+ logger.info(f"Image uploaded successfully: {image_url}")
213
+ return image_url
214
+ else:
215
+ logger.error(f"imgbb upload failed: {result}")
216
+ return None
217
+ else:
218
+ error_text = await response.text()
219
+ logger.error(f"imgbb upload error: {response.status} - {error_text}")
220
+ return None
221
+
222
+ except asyncio.TimeoutError:
223
+ logger.error("Timeout uploading image to imgbb")
224
+ return None
225
  except Exception as e:
226
+ logger.error(f"Failed to upload image to imgbb: {e}")
227
  return None
228
 
229
async def format_image_for_model(session: aiohttp.ClientSession, base64_data: str, model_config: Dict[str, Any]) -> str:
    """Prepare the image payload in the shape the target model expects.

    ``model_config["image_format"]`` selects the strategy:

    * ``"url"``      — upload to imgbb and return the hosted URL,
      degrading to an inline data URL if the upload fails;
    * ``"data_url"`` — wrap the base64 payload in a data URL;
    * anything else  — pass the raw base64 payload through unchanged.
    """
    target_format = model_config.get("image_format", "data_url")

    if target_format == "data_url":
        return format_image_as_data_url(base64_data)

    if target_format == "url":
        hosted_url = await upload_image_to_imgbb(session, base64_data)
        if hosted_url:
            return hosted_url
        # Degrade gracefully: an inline data URL still works for most models.
        logger.error("Failed to upload image, falling back to data URL")
        return format_image_as_data_url(base64_data)

    # Unknown format token: hand back the raw payload unchanged.
    return base64_data
249
 
250
def format_image_as_data_url(base64_data: str) -> str:
    """Wrap raw base64 image data in a ``data:`` URL, sniffing the MIME type.

    Input that already carries a ``data:`` prefix is returned untouched.
    Otherwise the leading bytes are decoded and matched against well-known
    magic numbers (JPEG, PNG, GIF, WebP); unrecognised or undecodable data
    falls back to ``image/jpeg``.
    """
    if base64_data.startswith("data:"):
        # Already a complete data URL — nothing to add.
        return base64_data

    try:
        # Decode only a short prefix; 100 base64 chars (a multiple of 4)
        # is more than enough to expose any image magic number.
        header = base64.b64decode(base64_data[:100])
    except Exception as e:
        logger.warning(f"Failed to detect image format: {e}, using JPEG as default")
        return f"data:image/jpeg;base64,{base64_data}"

    if header.startswith(b'\x89PNG\r\n\x1a\n'):
        mime = "image/png"
    elif header.startswith((b'GIF87a', b'GIF89a')):
        mime = "image/gif"
    elif header.startswith(b'RIFF') and b'WEBP' in header[:20]:
        mime = "image/webp"
    else:
        # Covers the JPEG magic (\xff\xd8\xff) and, matching the original
        # behaviour, anything unrecognised.
        mime = "image/jpeg"

    return f"data:{mime};base64,{base64_data}"
281
+
282
  def extract_content_from_message(message: Dict[str, Any]) -> tuple[str, List[str], List[Dict[str, str]]]:
283
  """
284
  从消息中提取文本内容、图片和文件
 
379
 
380
  return "\n".join(file_sections)
381
 
382
+ async def transform_openai_to_replicate(session: aiohttp.ClientSession, openai_request: Dict[str, Any], model_override: str = None) -> Dict[str, Any]:
383
  """将OpenAI格式的请求转换为Replicate格式"""
384
  try:
385
  messages = openai_request.get("messages", [])
 
448
  # 处理图片格式
449
  formatted_image = None
450
  if has_images and primary_image:
451
+ logger.info(f"Processing image for model {model} with format {model_config.get('image_format')}")
452
+ formatted_image = await format_image_for_model(session, primary_image, model_config)
453
+
454
+ if not formatted_image:
455
+ logger.error("Failed to format image for model")
456
+ raise HTTPException(status_code=500, detail="Failed to process image")
 
 
457
 
458
  # 构建 Replicate 格式的输入
459
  replicate_input = {}
 
489
  # 处理图片
490
  if formatted_image:
491
  replicate_input["image"] = formatted_image
492
+ if formatted_image.startswith("http"):
493
+ logger.info(f"Added image URL to request for model {model}: {formatted_image}")
494
+ else:
495
+ logger.info(f"Added image data to request for model {model}: {formatted_image[:100]}...")
496
 
497
  # 只在有 system_prompt 时才添加
498
  if system_prompt:
 
563
  if "input" in log_data:
564
  if "image" in log_data["input"]:
565
  image_data = log_data["input"]["image"]
566
+ if image_data.startswith("http"):
567
+ log_data["input"]["image"] = f"[IMAGE_URL: {image_data}]"
568
+ else:
569
+ log_data["input"]["image"] = f"[IMAGE_DATA_{len(image_data)}]"
570
  if "prompt" in log_data["input"] and len(log_data["input"]["prompt"]) > 1000:
571
  log_data["input"]["prompt"] = log_data["input"]["prompt"][:1000] + "...[TRUNCATED]"
572
  logger.info(f"Request data: {json.dumps(log_data, indent=2)}")
 
651
  "message": "Replicate API Proxy for LobeChat with Vision and File Support",
652
  "status": "running",
653
  "replicate_token_configured": bool(REPLICATE_API_TOKEN),
654
+ "imgbb_token_configured": bool(IMGBB_API_KEY),
655
+ "version": "1.2.0",
656
  "supported_models": list(MODEL_CONFIGS.keys()),
657
  "vision_support": True,
658
  "file_support": True,
659
  "supported_text_files": list(SUPPORTED_TEXT_EXTENSIONS),
660
  "supported_image_files": list(SUPPORTED_IMAGE_EXTENSIONS),
661
+ "claude4_vision_support": "Full support via imgbb image hosting"
662
  }
663
 
664
  @app.get("/health")
 
667
  return {
668
  "status": "healthy",
669
  "replicate_token": "configured" if REPLICATE_API_TOKEN else "missing",
670
+ "imgbb_token": "configured" if IMGBB_API_KEY else "missing",
671
  "timestamp": asyncio.get_event_loop().time(),
672
  "model_configs": MODEL_CONFIGS,
673
  "supported_file_types": {
 
703
  logger.info(f"Client parameters: max_tokens={body.get('max_tokens', 'not set')}, temperature={body.get('temperature', 'not set')}")
704
  logger.info(f"Message count: {len(body.get('messages', []))}")
705
 
706
+ async with aiohttp.ClientSession() as session:
707
+ # 转换请求格式
708
+ replicate_data, model = await transform_openai_to_replicate(session, body)
709
+
710
+ if body.get("stream", False):
711
+ # 流式响应
712
+ async def generate_stream():
713
  try:
714
  # 创建预测
715
  prediction = await create_replicate_prediction(session, model, replicate_data)
 
793
  }
794
  }
795
  yield f"data: {json.dumps(error_response)}\n\n"
796
+
797
+ return StreamingResponse(
798
+ generate_stream(),
799
+ media_type="text/event-stream",
800
+ headers={
801
+ "Cache-Control": "no-cache",
802
+ "Connection": "keep-alive",
803
+ "Access-Control-Allow-Origin": "*",
804
+ "X-Accel-Buffering": "no",
805
+ }
806
+ )
807
 
808
+ else:
809
+ # 非流式响应
 
 
 
 
 
 
 
 
 
 
 
 
810
  # 创建预测
811
  prediction = await create_replicate_prediction(session, model, replicate_data)
812
  prediction_id = prediction.get('id')