|
import os |
|
import json |
|
import time |
|
import uuid |
|
import httpx |
|
import re |
|
import asyncio |
|
import xml.etree.ElementTree as ET |
|
import logging |
|
import struct |
|
import base64 |
|
import copy |
|
from fastapi import FastAPI, HTTPException, Request, Header, Depends |
|
from fastapi.responses import StreamingResponse |
|
from pydantic import BaseModel, Field |
|
from typing import List, Optional, Dict, Any, Union |
|
from dotenv import load_dotenv |
|
from json_repair import repair_json |
|
|
|
|
|
|
|
# Root logging configuration: only WARNING and above are emitted, so the many
# logger.info(...) diagnostics in this module stay silent unless the level is lowered.
logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

# Populate os.environ from a local .env file before the configuration constants read it.
load_dotenv()

# ASGI application; the metadata below is what FastAPI publishes in its OpenAPI schema.
app = FastAPI(
    title="Ki2API - Claude Sonnet 4 OpenAI Compatible API",
    description="OpenAI-compatible API for Claude Sonnet 4 via AWS CodeWhisperer",
    version="3.0.1",
)
|
|
|
|
|
# Credentials, read from the environment.
API_KEY = os.getenv("API_KEY")  # key clients must present as "Authorization: Bearer <key>"
KIRO_ACCESS_TOKEN = os.getenv("KIRO_ACCESS_TOKEN")
KIRO_REFRESH_TOKEN = os.getenv("KIRO_REFRESH_TOKEN")

# Upstream endpoint and CodeWhisperer profile. These were hard-coded (the ARN is tied to a
# single AWS account); they can now be overridden via the environment, with the previous
# values as defaults. An empty env value falls back to the default instead of producing an
# unusable empty URL/ARN.
KIRO_BASE_URL = (
    os.getenv("KIRO_BASE_URL")
    or "https://codewhisperer.us-east-1.amazonaws.com/generateAssistantResponse"
)
PROFILE_ARN = (
    os.getenv("PROFILE_ARN")
    or "arn:aws:codewhisperer:us-east-1:699475941385:profile/EHGA3GRVQMUK"
)

# OpenAI-facing model name -> CodeWhisperer model id.
# NOTE(review): the "claude-3-5-haiku-20241022" alias resolves to Claude 3.7 Sonnet upstream;
# presumably deliberate (there is no Haiku id here) — confirm before relying on it.
MODEL_MAP = {
    "claude-sonnet-4-20250514": "CLAUDE_SONNET_4_20250514_V1_0",
    "claude-3-5-haiku-20241022": "CLAUDE_3_7_SONNET_20250219_V1_0",
}

# Fallback key used when a request names a model missing from MODEL_MAP.
DEFAULT_MODEL = "claude-sonnet-4-20250514"
|
|
|
|
|
# Payload of an OpenAI "image_url" content part: the image reference and an optional
# detail hint (defaults to "auto").
class ImageUrl(BaseModel):
    url: str
    detail: Union[str, None] = "auto"
|
|
|
# One element of a multi-part message "content" array. `type` names the part kind
# (e.g. "text" or "image_url"); the matching payload field is set, the other left None.
class ContentPart(BaseModel):
    type: str
    text: Union[str, None] = None
    image_url: Union[ImageUrl, None] = None
|
|
|
# An OpenAI-style tool call. In this module `function` is built as
# {"name": <str>, "arguments": <JSON-encoded str>}.
class ToolCall(BaseModel):
    id: str
    type: str = "function"
    function: Dict[str, Any]
|
# One entry of the OpenAI "messages" array.
class ChatMessage(BaseModel):
    role: str
    # Defaults to None so assistant messages that carry only tool_calls may omit "content"
    # entirely (the OpenAI API allows this); previously such requests failed validation.
    content: Union[str, List[ContentPart], None] = None
    tool_calls: Optional[List[ToolCall]] = None
    tool_call_id: Optional[str] = None

    def get_content_text(self) -> str:
        """Return the textual content of the message.

        - ``None`` content yields ``""``. A warning is logged, except for assistant
          messages, whose content may legitimately be null alongside tool_calls.
        - A string is returned unchanged.
        - For a list, the ``text`` of each part is concatenated. Raw dict parts also
          contribute the ``content`` of ``tool_result`` parts. Parts without text
          (such as images) are skipped.
        - Any other value is stringified; falsy values give ``""``.
        """
        if self.content is None:
            if self.role != "assistant":
                logger.warning(f"Message with role '{self.role}' has None content")
            return ""

        if isinstance(self.content, str):
            return self.content

        if isinstance(self.content, list):
            text_parts = []
            for part in self.content:
                if isinstance(part, dict):
                    part_type = part.get("type")
                    if part_type == "text" and "text" in part:
                        value = part.get("text", "")
                    elif part_type == "tool_result" and "content" in part:
                        value = part.get("content", "")
                    else:
                        continue
                elif getattr(part, "text", None):
                    value = part.text
                else:
                    continue
                # Raw dict parts are unvalidated: coerce so "".join cannot raise on None/lists.
                if value is None:
                    value = ""
                elif not isinstance(value, str):
                    value = str(value)
                text_parts.append(value)
            return "".join(text_parts)

        logger.warning(f"Unexpected content type: {type(self.content)}")
        return str(self.content) if self.content else ""
|
|
|
# Function definition inside a request tool. `parameters` is forwarded upstream as the
# tool's inputSchema; `description` may be omitted.
class Function(BaseModel):
    name: str
    description: Union[str, None] = None
    parameters: Union[Dict[str, Any], None] = None
|
|
|
# An entry of the request's "tools" array; only function tools are modelled.
class Tool(BaseModel):
    type: str = "function"
    function: Function
|
|
|
|
|
|
|
# Body of POST /v1/chat/completions, following the OpenAI chat-completions request shape.
# NOTE(review): in the visible code build_codewhisperer_request reads only model, messages
# and tools; the sampling fields appear to be accepted for client compatibility — confirm.
class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    temperature: Union[float, None] = 0.7
    max_tokens: Union[int, None] = 4000
    stream: Union[bool, None] = False
    top_p: Union[float, None] = 1.0
    frequency_penalty: Union[float, None] = 0.0
    presence_penalty: Union[float, None] = 0.0
    stop: Union[str, List[str], None] = None
    user: Union[str, None] = None
    tools: Union[List[Tool], None] = None
    tool_choice: Union[str, Dict[str, Any], None] = "auto"
|
|
|
# Token accounting block of a completion response. The detail dictionaries default to
# zeroed counters ({"cached_tokens": 0} / {"reasoning_tokens": 0}), built fresh per instance.
class Usage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    prompt_tokens_details: Union[Dict[str, int], None] = Field(
        default_factory=lambda: dict(cached_tokens=0)
    )
    completion_tokens_details: Union[Dict[str, int], None] = Field(
        default_factory=lambda: dict(reasoning_tokens=0)
    )
|
|
|
# Assistant message inside a non-streaming choice; either text content, tool calls, or both.
class ResponseMessage(BaseModel):
    role: str
    content: Union[str, None] = None
    tool_calls: Union[List[ToolCall], None] = None
|
|
|
# One choice of a non-streaming completion response.
class Choice(BaseModel):
    index: int
    message: ResponseMessage
    logprobs: Union[Any, None] = None
    finish_reason: str
|
|
|
# One choice of a streaming chunk; `delta` carries the incremental message fields and
# `finish_reason` stays None until the final chunk.
class StreamChoice(BaseModel):
    index: int
    delta: Dict[str, Any]
    logprobs: Union[Any, None] = None
    finish_reason: Union[str, None] = None
|
|
|
# Non-streaming "chat.completion" response. `id` ("chatcmpl-<uuid4>") and `created`
# (Unix seconds) are generated per instance.
class ChatCompletionResponse(BaseModel):
    id: str = Field(default_factory=lambda: "chatcmpl-" + str(uuid.uuid4()))
    object: str = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    system_fingerprint: Union[str, None] = "fp_ki2api_v3"
    choices: List[Choice]
    usage: Usage
|
|
|
# One "chat.completion.chunk" server-sent event; `usage` is optional and unset by default.
class ChatCompletionStreamResponse(BaseModel):
    id: str = Field(default_factory=lambda: "chatcmpl-" + str(uuid.uuid4()))
    object: str = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    system_fingerprint: Union[str, None] = "fp_ki2api_v3"
    choices: List[StreamChoice]
    usage: Union[Usage, None] = None
|
|
|
# OpenAI-style error envelope: {"error": {...}}.
class ErrorResponse(BaseModel):
    error: Dict[str, Any]
|
|
|
|
|
async def verify_api_key(authorization: str = Header(None)):
    """FastAPI dependency that authenticates a request against ``API_KEY``.

    Expects the header ``Authorization: Bearer <key>``.

    Returns:
        The presented key on success.

    Raises:
        HTTPException: 401 with an OpenAI-style ``invalid_api_key`` error body when the
            header is missing, is not a Bearer credential, or the key does not match
            (which includes the case where ``API_KEY`` is not configured).
    """
    import hmac  # local import; only this dependency needs it

    def unauthorized(message: str) -> HTTPException:
        return HTTPException(
            status_code=401,
            detail={
                "error": {
                    "message": message,
                    "type": "invalid_request_error",
                    "param": None,
                    "code": "invalid_api_key",
                }
            },
        )

    if not authorization:
        raise unauthorized("You didn't provide an API key.")

    prefix = "Bearer "
    if not authorization.startswith(prefix):
        raise unauthorized("Invalid API key format. Expected 'Bearer <key>'")

    # Strip only the leading scheme; str.replace would also delete "Bearer " inside the key.
    api_key = authorization[len(prefix):]

    # Constant-time comparison so response timing does not reveal how much of the key matched.
    # Comparing UTF-8 bytes keeps compare_digest from raising on non-ASCII input.
    if API_KEY is None or not hmac.compare_digest(
        api_key.encode("utf-8"), API_KEY.encode("utf-8")
    ):
        raise unauthorized("Invalid API key provided")

    return api_key
|
|
|
|
|
class TokenManager: |
|
def __init__(self): |
|
self.access_token = KIRO_ACCESS_TOKEN |
|
self.refresh_token = KIRO_REFRESH_TOKEN |
|
self.refresh_url = "https://prod.us-east-1.auth.desktop.kiro.dev/refreshToken" |
|
self.last_refresh_time = 0 |
|
self.refresh_lock = asyncio.Lock() |
|
|
|
async def refresh_tokens(self): |
|
"""刷新token,使用锁防止并发刷新请求""" |
|
if not self.refresh_token: |
|
logger.error("没有刷新token,无法刷新访问token") |
|
return None |
|
|
|
async with self.refresh_lock: |
|
|
|
current_time = time.time() |
|
if current_time - self.last_refresh_time < 5: |
|
logger.info("最近已刷新token,使用现有token") |
|
return self.access_token |
|
|
|
try: |
|
logger.info("开始刷新token...") |
|
async with httpx.AsyncClient() as client: |
|
response = await client.post( |
|
self.refresh_url, |
|
json={"refreshToken": self.refresh_token}, |
|
timeout=30 |
|
) |
|
response.raise_for_status() |
|
|
|
data = response.json() |
|
if "accessToken" not in data: |
|
logger.error(f"刷新token响应中没有accessToken: {data}") |
|
return None |
|
|
|
self.access_token = data.get("accessToken") |
|
self.last_refresh_time = current_time |
|
logger.info("token刷新成功") |
|
|
|
|
|
os.environ["KIRO_ACCESS_TOKEN"] = self.access_token |
|
|
|
return self.access_token |
|
except Exception as e: |
|
logger.error(f"token刷新失败: {str(e)}") |
|
return None |
|
|
|
def get_token(self): |
|
return self.access_token |
|
|
|
# Module-level token holder used by call_kiro_api for reads and refreshes.
token_manager: TokenManager = TokenManager()
|
|
|
|
|
def parse_xml_tool_calls(response_text: str) -> Optional[List[ToolCall]]:
    """Convert CodeWhisperer's XML-style tool-call markup into OpenAI ToolCall objects.

    Three patterns are tried in order, and a pattern is used only if the previous ones
    found nothing:

    1. ``<tool_use><tool_name>..</tool_name><tool_parameter_name>..</tool_parameter_name>
       <tool_parameter_value>..</tool_parameter_value></tool_use>``
    2. The same name/parameter triple without the ``<tool_use>`` wrapper.
    3. A bare ``<tool_name>..</tool_name>``, producing empty arguments ``"{}"``.

    Each parameterised match yields a single ``{param_name: param_value}`` argument object.

    Returns:
        The list of ToolCall objects, or None when *response_text* is empty or nothing matched.
    """
    if not response_text:
        return None

    logger.info(f"🔍 开始解析XML工具调用,响应文本长度: {len(response_text)}")

    def make_call(name: str, arguments_json: str) -> ToolCall:
        return ToolCall(
            id=f"call_{uuid.uuid4().hex[:8]}",
            type="function",
            function={"name": name, "arguments": arguments_json},
        )

    found: List[ToolCall] = []

    # Strategies 1 and 2: name + one parameter; stop at the first strategy that matches.
    parameterised_strategies = (
        (
            r'<tool_use>\s*<tool_name>([^<]+)</tool_name>\s*<tool_parameter_name>([^<]+)</tool_parameter_name>\s*<tool_parameter_value>([^<]*)</tool_parameter_value>\s*</tool_use>',
            "✅ 解析到工具调用",
        ),
        (
            r'<tool_name>([^<]+)</tool_name>\s*<tool_parameter_name>([^<]+)</tool_parameter_name>\s*<tool_parameter_value>([^<]*)</tool_parameter_value>',
            "✅ 解析到简单工具调用",
        ),
    )
    for pattern, label in parameterised_strategies:
        for hit in re.finditer(pattern, response_text, re.DOTALL | re.IGNORECASE):
            name, param_name, param_value = (group.strip() for group in hit.groups())
            found.append(make_call(name, json.dumps({param_name: param_value}, ensure_ascii=False)))
            logger.info(f"{label}: {name} with {param_name}={param_value}")
        if found:
            break

    # Strategy 3: a bare tool name with no parameters.
    if not found:
        for hit in re.finditer(r'<tool_name>([^<]+)</tool_name>', response_text, re.IGNORECASE):
            name = hit.group(1).strip()
            found.append(make_call(name, "{}"))
            logger.info(f"✅ 解析到无参数工具调用: {name}")

    if not found:
        logger.info("❌ 未发现任何XML格式的工具调用")
        return None

    logger.info(f"🎉 总共解析出 {len(found)} 个工具调用")
    return found
|
|
|
def find_matching_bracket(text: str, start_pos: int) -> int:
    """Return the index of the ']' that closes the '[' at ``text[start_pos]``.

    Brackets inside double-quoted strings are ignored, and inside a string a backslash
    escapes the next character. As a result, JSON such as ``{"re": "[a-z]"}`` nested in
    brackets is matched correctly.

    Args:
        text: Text to scan.
        start_pos: Index of the opening '['.

    Returns:
        Index of the matching ']'. Returns -1 when ``text`` is empty, ``start_pos`` is out
        of range (including negative), ``text[start_pos]`` is not '[', or the bracket is
        never closed.
    """
    # Reject negative indices explicitly; Python would otherwise index from the end.
    if not text or not 0 <= start_pos < len(text) or text[start_pos] != '[':
        logger.debug("find_matching_bracket: no '[' at position %s", start_pos)
        return -1

    depth = 1
    in_string = False
    escape_next = False
    # No per-character logging: this loop runs over whole model responses, and eagerly
    # formatted f-string log calls cost time even when INFO is disabled.
    for i in range(start_pos + 1, len(text)):
        char = text[i]
        if escape_next:
            escape_next = False
        elif in_string:
            if char == '\\':
                escape_next = True
            elif char == '"':
                in_string = False
        elif char == '"':
            in_string = True
        elif char == '[':
            depth += 1
        elif char == ']':
            depth -= 1
            if depth == 0:
                return i

    logger.debug("find_matching_bracket: unbalanced '[' at position %s", start_pos)
    return -1
|
|
|
# Header of a rendered tool call such as "[Called get_weather with args:". Tool names may
# contain '-' or '.' (OpenAI/Anthropic tool names allow '-'), which a plain \w+ rejected.
_CALLED_HEADER_RE = re.compile(r'\[Called\s+([\w.\-]+)\s+with\s+args:', re.IGNORECASE)


def _tool_call_from_arguments(function_name: str, arguments: Dict[str, Any]) -> ToolCall:
    """Wrap parsed *arguments* in an OpenAI-style ToolCall with a fresh short id."""
    return ToolCall(
        id=f"call_{uuid.uuid4().hex[:8]}",
        type="function",
        function={
            "name": function_name,
            "arguments": json.dumps(arguments, ensure_ascii=False),
        },
    )


def parse_single_tool_call_professional(tool_call_text: str) -> Optional[ToolCall]:
    """Parse one ``[Called <name> with args: <json>]`` snippet into a ToolCall.

    The arguments are the text between the header and the last ']' of *tool_call_text*.
    They are repaired with ``repair_json`` and then decoded with ``json.loads``. If that
    raises, the outermost ``{...}`` span is repaired and decoded instead.

    Returns:
        The ToolCall. Returns None when no header is found, there is no closing ']' after
        it, the arguments decode to something other than a JSON object, or both attempts fail.
    """
    logger.info(f"🔧 开始解析工具调用文本 (长度: {len(tool_call_text)})")

    name_match = _CALLED_HEADER_RE.search(tool_call_text)
    if not name_match:
        logger.warning("⚠️ 无法从文本中提取函数名")
        return None

    function_name = name_match.group(1).strip()
    logger.info(f"✅ 提取到函数名: {function_name}")

    # Arguments begin right after the matched header. Searching a lower()-ed copy for the
    # literal "with args:" failed on extra whitespace ("with  args:"). It could also point to
    # the wrong offset when lower() changed the string length for some Unicode characters.
    args_start = name_match.end()
    args_end = tool_call_text.rfind(']')
    if args_end <= args_start:
        logger.error("❌ 找不到结束的 ']'")
        return None

    json_candidate = tool_call_text[args_start:args_end].strip()
    logger.info(f"📝 提取的JSON候选文本长度: {len(json_candidate)}")

    try:
        repaired_json = repair_json(json_candidate)
        logger.info(f"🔧 JSON修复完成,修复后长度: {len(repaired_json)}")
        arguments = json.loads(repaired_json)

        if not isinstance(arguments, dict):
            logger.error(f"❌ 解析结果不是字典类型: {type(arguments)}")
            return None

        tool_call = _tool_call_from_arguments(function_name, arguments)
        logger.info(f"✅ 成功创建工具调用: {function_name} (参数键: {list(arguments.keys())})")
        return tool_call
    except Exception as e:
        logger.error(f"❌ JSON修复/解析失败: {type(e).__name__}: {str(e)}")

    # Fallback: keep only the outermost {...} span (drops stray text around the object).
    try:
        first_brace = json_candidate.find('{')
        last_brace = json_candidate.rfind('}')
        if first_brace != -1 and last_brace > first_brace:
            arguments = json.loads(repair_json(json_candidate[first_brace:last_brace + 1]))
            if isinstance(arguments, dict):
                tool_call = _tool_call_from_arguments(function_name, arguments)
                logger.info(f"✅ 备用方案成功: {function_name}")
                return tool_call
    except Exception as backup_error:
        logger.error(f"❌ 备用方案也失败了: {backup_error}")

    return None
|
|
|
def parse_bracket_tool_calls_professional(response_text: str) -> Optional[List[ToolCall]]:
    """Extract every ``[Called <name> with args: {...}]`` tool call from *response_text*.

    The closing bracket of each marker is located with ``find_matching_bracket``, which
    ignores brackets inside JSON strings, so arguments such as ``{"re": "[a-z]"}`` are not
    cut short. Markers that fall inside an already-extracted call (for example "[Called"
    quoted in an argument string) are skipped. When no balanced closer exists, the last
    ']' before the next marker is used instead.

    Returns:
        The parsed ToolCall objects in order of appearance, or None if none were parsed.
    """
    if not response_text or "[Called" not in response_text:
        logger.info("📭 响应文本中没有工具调用标记")
        return None

    tool_calls: List[ToolCall] = []
    errors: List[str] = []

    try:
        call_positions = [m.start() for m in re.finditer(r"\[Called", response_text)]
        logger.info(f"🔍 找到 {len(call_positions)} 个潜在的工具调用")

        consumed_until = -1  # index of the ']' closing the last extracted call
        for i, start_pos in enumerate(call_positions):
            if start_pos <= consumed_until:
                logger.debug("skipping '[Called' nested inside the previous call at %d", start_pos)
                continue

            end_pos = find_matching_bracket(response_text, start_pos)
            if end_pos == -1:
                # Unbalanced call: fall back to the last ']' before the next marker.
                next_pos = call_positions[i + 1] if i + 1 < len(call_positions) else len(response_text)
                end_pos = response_text.rfind(']', start_pos, next_pos)
                if end_pos == -1:
                    logger.warning(f"⚠️ 工具调用 {i+1} 没有找到结束括号")
                    continue

            consumed_until = end_pos
            tool_call_text = response_text[start_pos:end_pos + 1]
            logger.info(f"📋 提取工具调用 {i+1}, 长度: {len(tool_call_text)}")

            parsed_call = parse_single_tool_call_professional(tool_call_text)
            if parsed_call:
                tool_calls.append(parsed_call)
            else:
                errors.append(f"工具调用 {i+1} 解析失败")

    except Exception as e:
        # Best effort: keep whatever was parsed so far; log the traceback instead of printing it.
        logger.exception(f"❌ 批量解析过程出错: {type(e).__name__}: {str(e)}")

    if tool_calls:
        logger.info(f"🎉 成功解析 {len(tool_calls)} 个工具调用")
        for tc in tool_calls:
            logger.info(f"  ✓ {tc.function['name']} (ID: {tc.id})")

    if errors:
        logger.warning(f"⚠️ 有 {len(errors)} 个解析失败:")
        for error in errors:
            logger.warning(f"  ✗ {error}")

    return tool_calls if tool_calls else None
|
|
|
|
|
def parse_bracket_tool_calls(response_text: str) -> Optional[List[ToolCall]]:
    """Backward-compatible alias of ``parse_bracket_tool_calls_professional``.

    Returns the parsed tool calls, or None when none are found.
    """
    parser = parse_bracket_tool_calls_professional
    return parser(response_text)
|
|
|
def parse_single_tool_call(tool_call_text: str) -> Optional[ToolCall]:
    """Backward-compatible alias of ``parse_single_tool_call_professional``.

    Returns the parsed ToolCall, or None when the snippet cannot be parsed.
    """
    parser = parse_single_tool_call_professional
    return parser(tool_call_text)
|
|
|
|
|
def _tool_call_dedup_key(function: Dict[str, Any]) -> tuple:
    """Return a hashable ``(name, canonical_arguments)`` key for a tool call's function dict.

    JSON arguments are re-serialised with sorted keys and compact separators, so
    ``'{"a": 1, "b": 2}'`` and ``'{"b":2,"a":1}'`` compare equal. Non-JSON strings are
    compared verbatim. Non-string values (e.g. a dict) are serialised as well; they used to
    make the key unhashable and raise TypeError.
    """
    name = function.get("name", "")
    arguments = function.get("arguments", "")
    if isinstance(arguments, str):
        try:
            parsed = json.loads(arguments)
        except ValueError:
            return (name, arguments)
    else:
        parsed = arguments
    try:
        canonical = json.dumps(parsed, sort_keys=True, ensure_ascii=False, separators=(",", ":"))
    except (TypeError, ValueError):
        canonical = repr(parsed)
    return (name, canonical)


def deduplicate_tool_calls(tool_calls: List[Union[Dict, ToolCall]]) -> List[ToolCall]:
    """Drop repeated tool calls, keeping the first occurrence of each (name, arguments) pair.

    Dict entries are converted to ToolCall. A missing or empty ``id`` gets a fresh
    ``call_<8 hex>`` id, a missing ``type`` becomes "function", and a missing ``function``
    becomes ``{}``. Arguments are compared after JSON canonicalisation (see
    ``_tool_call_dedup_key``).

    Returns:
        The unique ToolCall objects in their original order.
    """
    seen = set()
    unique_tool_calls: List[ToolCall] = []

    for tool_call in tool_calls:
        if isinstance(tool_call, dict):
            tc = ToolCall(
                id=tool_call.get("id") or f"call_{uuid.uuid4().hex[:8]}",
                type=tool_call.get("type") or "function",
                function=tool_call.get("function") or {},
            )
        else:
            tc = tool_call

        key = _tool_call_dedup_key(tc.function)
        if key in seen:
            logger.info(f"🔄 Skipping duplicate tool call: {tc.function.get('name', 'unknown')}")
            continue
        seen.add(key)
        unique_tool_calls.append(tc)

    return unique_tool_calls
|
|
|
def _cw_user_entry(content: str, model_id: str) -> Dict[str, Any]:
    """Wrap *content* as a CodeWhisperer ``userInputMessage`` history entry."""
    return {
        "userInputMessage": {
            "content": content,
            "modelId": model_id,
            "origin": "AI_EDITOR",
        }
    }


def _cw_assistant_entry(content: str) -> Dict[str, Any]:
    """Wrap *content* as a CodeWhisperer ``assistantResponseMessage`` history entry."""
    return {"assistantResponseMessage": {"content": content}}


def _cw_flatten_history(history_messages: List[ChatMessage]) -> List[tuple]:
    """Reduce prior OpenAI messages to ``(role, text)`` pairs with roles "user"/"assistant".

    - user: its text, or "Continue" if empty.
    - assistant with tool_calls: its text, if any, then one
      ``[Called <name> with args: <arguments>]`` marker per call.
    - assistant without tool_calls: its text, or "I understand." if empty.
    - tool: ``[Tool result for <id>]: <text>``, sent as a user turn and merged with an
      immediately following user message.
    """
    processed = []
    i = 0
    while i < len(history_messages):
        msg = history_messages[i]

        if msg.role == "user":
            processed.append(("user", msg.get_content_text() or "Continue"))
            i += 1

        elif msg.role == "assistant":
            if msg.tool_calls:
                descriptions = []
                for tc in msg.tool_calls:
                    func = tc.function if isinstance(tc.function, dict) else {}
                    func_name = func.get("name", "unknown")
                    args = func.get("arguments", "{}")
                    descriptions.append(f"[Called {func_name} with args: {args}]")
                tool_text = " ".join(descriptions)
                # Keep prose the assistant emitted alongside its calls (it used to be dropped).
                # msg.content is checked first so a null content does not log a warning.
                text = msg.get_content_text() if msg.content else ""
                content = f"{text}\n{tool_text}" if text else tool_text
                logger.info(f"📌 Processing assistant message with tool calls: {content}")
            else:
                content = msg.get_content_text() or "I understand."
            processed.append(("assistant", content))
            i += 1

        elif msg.role == "tool":
            tool_content = msg.get_content_text() or "[Tool executed]"
            tool_call_id = msg.tool_call_id or "unknown"
            formatted_tool_result = f"[Tool result for {tool_call_id}]: {tool_content}"

            if i + 1 < len(history_messages) and history_messages[i + 1].role == "user":
                user_content = history_messages[i + 1].get_content_text() or ""
                processed.append(("user", f"{formatted_tool_result}\n{user_content}".strip()))
                i += 2
            else:
                processed.append(("user", formatted_tool_result))
                i += 1

        else:
            i += 1

    return processed


def _cw_pair_history(processed: List[tuple], model_id: str) -> List[Dict[str, Any]]:
    """Emit strictly alternating user/assistant history entries from ``(role, text)`` pairs.

    A user turn not followed by an assistant turn gets an "I understand." reply. An
    assistant turn not preceded by a user turn gets a "Continue" prompt.
    """
    history = []
    i = 0
    while i < len(processed):
        role, content = processed[i]
        if role == "user":
            history.append(_cw_user_entry(content, model_id))
            if i + 1 < len(processed) and processed[i + 1][0] == "assistant":
                history.append(_cw_assistant_entry(processed[i + 1][1]))
                i += 2
            else:
                history.append(_cw_assistant_entry("I understand."))
                i += 1
        elif role == "assistant":
            history.append(_cw_user_entry("Continue", model_id))
            history.append(_cw_assistant_entry(content))
            i += 1
        else:
            i += 1
    return history


def _cw_extract_images(message: ChatMessage) -> List[Dict[str, Any]]:
    """Collect base64 data-URL images from *message* as CodeWhisperer image dicts.

    Non-data URLs, unknown formats and invalid base64 are logged and skipped.
    """
    images = []
    if not isinstance(message.content, list):
        return images

    for part in message.content:
        if not (part.type == "image_url" and part.image_url):
            continue
        try:
            url = part.image_url.url
            logger.info(f"🔍 处理图片 URL: {url[:50]}...")

            if not url.startswith("data:image/"):
                logger.error(f"❌ 图片 URL 格式不正确,应该以 'data:image/' 开头")
                continue

            # data:image/<format>;base64,<data>
            header, encoded_data = url.split(",", 1)
            match = re.search(r'image/(\w+)', header)
            if not match:
                logger.warning(f"⚠️ 无法从头部确定图片格式: {header}")
                continue
            image_format = match.group(1)

            try:
                base64.b64decode(encoded_data)
                logger.info("✅ Base64 编码验证通过")
            except Exception as e:
                logger.error(f"❌ Base64 编码无效: {e}")
                continue

            images.append({
                "format": image_format,
                "source": {"bytes": encoded_data},
            })
            logger.info(f"🖼️ 成功处理图片,格式: {image_format}, 大小: {len(encoded_data)} 字符")
        except Exception as e:
            logger.error(f"❌ 处理图片 URL 失败: {str(e)}")

    return images


def _cw_current_content(conversation_messages: List[ChatMessage]) -> str:
    """Build the text of the outgoing ``currentMessage`` from the last conversation message."""
    current_message = conversation_messages[-1]
    current_content = current_message.get_content_text()

    if current_message.role == "tool":
        tool_result = current_content or '[Tool executed]'
        tool_call_id = current_message.tool_call_id
        current_content = f"[Tool execution completed for {tool_call_id or 'unknown'}]: {tool_result}"

        # Name the function when the preceding assistant message issued this call.
        if len(conversation_messages) > 1:
            prev_message = conversation_messages[-2]
            if prev_message.role == "assistant" and prev_message.tool_calls:
                for tc in prev_message.tool_calls:
                    if tc.id == tool_call_id:
                        func_name = tc.function.get("name", "unknown") if isinstance(tc.function, dict) else "unknown"
                        current_content = f"[Completed execution of {func_name}]: {tool_result}"
                        break

    elif current_message.role == "assistant":
        if current_message.tool_calls:
            current_content = "; ".join(
                f"Continue after calling {tc.function.get('name', 'unknown') if isinstance(tc.function, dict) else 'unknown'}"
                for tc in current_message.tool_calls
            )
        else:
            current_content = "Continue the conversation"

    return current_content or "Continue"


def _cw_log_request(codewhisperer_request: Dict[str, Any]) -> None:
    """Log the outgoing payload at INFO, with image bytes truncated to 20 characters.

    The deepcopy and json.dumps are skipped entirely when INFO logging is disabled.
    """
    if not logger.isEnabledFor(logging.INFO):
        return
    log_request = copy.deepcopy(codewhisperer_request)
    user_input = (
        log_request.get("conversationState", {})
        .get("currentMessage", {})
        .get("userInputMessage", {})
    )
    for img in user_input.get("images", []):
        source = img.get("source", {})
        if "bytes" in source:
            source["bytes"] = source["bytes"][:20] + "..."
    logger.info(f"🔄 COMPLETE CODEWHISPERER REQUEST: {json.dumps(log_request, indent=2)}")


def build_codewhisperer_request(request: ChatCompletionRequest):
    """Translate an OpenAI chat request into a CodeWhisperer generateAssistantResponse payload.

    - All system (and "developer") messages are joined with blank lines and prepended to
      the current message text. Previously only the last system message was kept.
    - The last user/assistant/tool message becomes ``currentMessage``. Earlier ones become
      alternating user/assistant ``history`` entries.
    - Data-URL images on the current message are attached as ``images``.
    - Request tools are attached as ``userInputMessageContext.tools``.

    Raises:
        HTTPException: 400 when the request has no user/assistant/tool messages.
    """
    codewhisperer_model = MODEL_MAP.get(request.model, MODEL_MAP[DEFAULT_MODEL])
    conversation_id = str(uuid.uuid4())

    system_parts = []
    conversation_messages = []
    for msg in request.messages:
        if msg.role in ("system", "developer"):
            system_parts.append(msg.get_content_text())
        elif msg.role in ("user", "assistant", "tool"):
            conversation_messages.append(msg)
    system_prompt = "\n\n".join(part for part in system_parts if part)

    if not conversation_messages:
        raise HTTPException(
            status_code=400,
            detail={
                "error": {
                    "message": "No conversation messages found",
                    "type": "invalid_request_error",
                    "param": "messages",
                    "code": "invalid_request"
                }
            }
        )

    history = []
    if len(conversation_messages) > 1:
        processed_messages = _cw_flatten_history(conversation_messages[:-1])
        history = _cw_pair_history(processed_messages, codewhisperer_model)

    images = _cw_extract_images(conversation_messages[-1])

    current_content = _cw_current_content(conversation_messages)
    if system_prompt:
        current_content = f"{system_prompt}\n\n{current_content}"

    user_input_message = {
        "content": current_content,
        "modelId": codewhisperer_model,
        "origin": "AI_EDITOR"
    }
    codewhisperer_request = {
        "profileArn": PROFILE_ARN,
        "conversationState": {
            "chatTriggerType": "MANUAL",
            "conversationId": conversation_id,
            "currentMessage": {"userInputMessage": user_input_message},
            "history": history
        }
    }

    if images:
        user_input_message["images"] = images
        logger.info(f"📊 添加了 {len(images)} 个图片到 userInputMessage 中")
        for i, img in enumerate(images):
            logger.info(f" - 图片 {i+1}: 格式={img['format']}, 大小={len(img['source']['bytes'])} 字符")
            logger.info(f" - 图片数据前20字符: {img['source']['bytes'][:20]}...")
        logger.info(f"✅ 成功添加 images 到 userInputMessage 中")

    if request.tools:
        user_input_message["userInputMessageContext"] = {
            "tools": [
                {
                    "toolSpecification": {
                        "name": tool.function.name,
                        "description": tool.function.description or "",
                        "inputSchema": {"json": tool.function.parameters or {}}
                    }
                } for tool in request.tools
            ]
        }
        logger.info(f"✅ 成功添加 userInputMessageContext 到请求中")

    _cw_log_request(codewhisperer_request)
    return codewhisperer_request
|
|
|
class CodeWhispererStreamParser:
    """Incremental decoder for AWS event-stream frames returned by CodeWhisperer.

    Frame layout (AWS event stream encoding, big-endian)::

        total_len:u32 | headers_len:u32 | prelude_crc:u32 | headers | payload | message_crc:u32

    Bytes are accumulated across ``parse`` calls until a full frame is available. The
    JSON object in each frame's payload is returned as an event. CRCs are not verified.
    """

    _PRELUDE_LEN = 12  # total_len + headers_len + prelude_crc
    _CRC_LEN = 4  # trailing message CRC
    _MAX_FIELD = 2000000  # larger length fields are treated as corruption

    def __init__(self):
        self.buffer = b''
        self.error_count = 0
        self.max_errors = 5

    def _record_error(self) -> None:
        """Count a framing error; drop the whole buffer once ``max_errors`` is exceeded."""
        self.error_count += 1
        if self.error_count > self.max_errors:
            logger.error("Too many parsing errors, clearing buffer")
            self.buffer = b''

    def parse(self, chunk: bytes) -> List[Dict[str, Any]]:
        """Append *chunk* to the buffer and return the JSON events of all complete frames.

        Incomplete trailing data stays buffered for the next call. Frames whose payload
        holds no JSON object, or invalid JSON, are consumed and skipped.
        """
        self.buffer += chunk
        logger.debug("Parser received %d bytes. Buffer size: %d", len(chunk), len(self.buffer))
        events = []

        while len(self.buffer) >= self._PRELUDE_LEN:
            try:
                total_len, header_len = struct.unpack('>II', self.buffer[0:8])

                # A frame shorter than prelude + CRC is corrupt. total_len == 0 used to
                # consume nothing and spin in this loop forever.
                if (total_len > self._MAX_FIELD or header_len > self._MAX_FIELD
                        or total_len < self._PRELUDE_LEN + self._CRC_LEN):
                    logger.error(f"Unreasonable header values: total_len={total_len}, header_len={header_len}")
                    self.buffer = self.buffer[8:]
                    self._record_error()
                    continue

                if len(self.buffer) < total_len:
                    break  # wait for the rest of the frame

                frame = self.buffer[:total_len]
                self.buffer = self.buffer[total_len:]

                # Headers start after the 12-byte prelude. Using 8 + header_len used to pull
                # the last 4 header bytes into the payload.
                payload_start = self._PRELUDE_LEN + header_len
                payload_end = total_len - self._CRC_LEN
                if payload_start > payload_end:
                    logger.error(f"Invalid payload bounds")
                    continue

                payload_str = frame[payload_start:payload_end].decode('utf-8', errors='ignore')
                # Defensive: decode from the first '{' in case of leading noise.
                json_start_index = payload_str.find('{')
                if json_start_index == -1:
                    continue
                try:
                    event_data = json.loads(payload_str[json_start_index:])
                except json.JSONDecodeError as e:
                    logger.error(f"JSON decode error: {e}")
                    continue

                events.append(event_data)
                logger.debug("Successfully parsed event: %s", event_data)

            except struct.error as e:
                logger.error(f"Struct unpack error: {e}")
                self.buffer = self.buffer[1:]
                self._record_error()
            except Exception as e:
                logger.error(f"Unexpected error during parsing: {str(e)}")
                self.buffer = self.buffer[1:]
                self._record_error()

        if events:
            self.error_count = 0

        return events
|
|
|
|
|
class SimpleResponseParser:
    # Heuristic fallback decoder for raw event-stream bytes.

    @staticmethod
    def parse_event_stream_to_json(raw_data: bytes) -> Dict[str, Any]:
        """Simple parser for fallback (from version 1).

        Returns ``{"content": str, "tokens": int}``, where tokens is a whitespace word count.
        Strategies, in order:

        1. Concatenate the non-empty ``content`` values of brace-free ``{...}`` JSON
           snippets that mention "content".
        2. Otherwise, strip control characters and event-stream header residue, keep only
           word/CJK/basic-punctuation characters, and collapse whitespace. The result is
           used if it is longer than 5 characters.
        3. Otherwise, report "No readable content found".

        Any exception is turned into an "Error parsing response: ..." content string.
        """
        try:
            if isinstance(raw_data, bytes):
                raw_str = raw_data.decode('utf-8', errors='ignore')
            else:
                raw_str = str(raw_data)

            # Strategy 1: embedded JSON snippets.
            pieces = []
            for snippet in re.findall(r'\{[^{}]*"content"[^{}]*\}', raw_str, re.DOTALL):
                try:
                    decoded = json.loads(snippet)
                except json.JSONDecodeError:
                    continue
                if 'content' in decoded and decoded['content']:
                    pieces.append(decoded['content'])
            if pieces:
                joined = ''.join(pieces)
                return {"content": joined, "tokens": len(joined.split())}

            # Strategy 2: scrub the raw text down to readable characters.
            text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]', '', raw_str)
            for header_residue in (r':event-type[^:]*:[^:]*:[^:]*:', r':content-type[^:]*:[^:]*:[^:]*:'):
                text = re.sub(header_residue, '', text)
            text = re.sub(r'[^\w\s\u4e00-\u9fff\u3040-\u309f\u30a0-\u30ff.,!?;:()"\'-]', '', text)
            text = re.sub(r'\s+', ' ', text).strip()

            if text and len(text) > 5:
                return {"content": text, "tokens": len(text.split())}

            return {"content": "No readable content found", "tokens": 0}

        except Exception as e:
            return {"content": f"Error parsing response: {str(e)}", "tokens": 0}
|
|
|
|
|
def _kiro_error_detail(message: str, error_type: str, code: str) -> Dict[str, Any]:
    """Build the OpenAI-style ``{"error": {...}}`` body used as HTTPException detail."""
    return {
        "error": {
            "message": message,
            "type": error_type,
            "param": None,
            "code": code,
        }
    }


async def call_kiro_api(request: ChatCompletionRequest):
    """Send *request* to CodeWhisperer, refreshing the access token once on a 403.

    Returns:
        The successful ``httpx.Response``.

    Raises:
        HTTPException:
            - 401 if no access token is available or the token refresh fails.
            - 429 when the upstream rate-limits.
            - 503 for other upstream or transport failures.
            - Anything ``build_codewhisperer_request`` raises (400 for no messages).
            Previously the generic handler turned the 401/429 errors into 503.
    """
    token = token_manager.get_token()
    if not token:
        raise HTTPException(
            status_code=401,
            detail=_kiro_error_detail("No access token available", "authentication_error", "invalid_api_key"),
        )

    request_data = build_codewhisperer_request(request)

    headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
        "Accept": "text/event-stream" if request.stream else "application/json",
    }

    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                KIRO_BASE_URL,
                headers=headers,
                json=request_data,
                timeout=120,
            )
            logger.info(f"📤 RESPONSE STATUS: {response.status_code}")

            if response.status_code == 403:
                logger.info("收到403响应,尝试刷新token...")
                new_token = await token_manager.refresh_tokens()
                if not new_token:
                    raise HTTPException(
                        status_code=401,
                        detail=_kiro_error_detail("Token refresh failed", "authentication_error", "invalid_api_key"),
                    )
                headers["Authorization"] = f"Bearer {new_token}"
                response = await client.post(
                    KIRO_BASE_URL,
                    headers=headers,
                    json=request_data,
                    timeout=120,
                )
                logger.info(f"📤 RETRY RESPONSE STATUS: {response.status_code}")

            if response.status_code == 429:
                raise HTTPException(
                    status_code=429,
                    detail=_kiro_error_detail("Rate limit exceeded", "rate_limit_error", "rate_limit_exceeded"),
                )

            response.raise_for_status()
            return response

    except HTTPException:
        # Already classified above (401/429): pass through unchanged.
        raise
    except httpx.HTTPStatusError as e:
        logger.error(f"HTTP ERROR: {e.response.status_code} - {e.response.text}")
        raise HTTPException(
            status_code=503,
            detail=_kiro_error_detail(f"API call failed: {str(e)}", "api_error", "api_error"),
        ) from e
    except Exception as e:
        logger.error(f"API call failed: {str(e)}")
        raise HTTPException(
            status_code=503,
            detail=_kiro_error_detail(f"API call failed: {str(e)}", "api_error", "api_error"),
        ) from e
|
|
|
|
|
def estimate_tokens(text: Optional[str]) -> int:
    """Roughly estimate the token count of *text* as one token per 4 characters.

    Always returns at least 1. ``None`` is treated as an empty string, so a message whose
    content is None does not crash usage accounting.
    """
    return max(1, len(text or "") // 4)
|
|
|
def create_usage_stats(prompt_text: str, completion_text: str) -> Usage:
    """Return a Usage record built from ``estimate_tokens`` of the prompt and the completion.

    ``total_tokens`` is the sum of the two estimates.
    """
    counts = {
        "prompt_tokens": estimate_tokens(prompt_text),
        "completion_tokens": estimate_tokens(completion_text),
    }
    return Usage(total_tokens=sum(counts.values()), **counts)
|
|
|
|
|
@app.get("/v1/models")
async def list_models(api_key: str = Depends(verify_api_key)):
    """List available models (the keys of MODEL_MAP) in OpenAI ``/v1/models`` format."""
    # One timestamp for the whole listing. Computing it per entry could give entries
    # different values if a second boundary passed mid-comprehension.
    created = int(time.time())
    return {
        "object": "list",
        "data": [
            {
                "id": model_id,
                "object": "model",
                "created": created,
                "owned_by": "ki2api",
            }
            for model_id in MODEL_MAP
        ],
    }
|
|
|
@app.post("/v1/chat/completions")
async def create_chat_completion(
    request: ChatCompletionRequest,
    api_key: str = Depends(verify_api_key)
):
    """Serve an OpenAI-style chat completion.

    The upstream call is always made through ``create_non_streaming_response``.
    When ``request.stream`` is set, the finished response is replayed as SSE
    chunks by ``convert_to_streaming_response``.

    Raises:
        HTTPException: 400 ``model_not_found`` when ``request.model`` is not a
            key of ``MODEL_MAP``; otherwise whatever
            ``create_non_streaming_response`` raises.
    """
    logger.info(f"📥 COMPLETE REQUEST: {request.model_dump_json(indent=2)}")

    # None content on non-assistant turns is tolerated here, only reported.
    for position, message in enumerate(request.messages):
        if message.content is None and message.role != "assistant":
            logger.warning(f"Message {position} with role '{message.role}' has None content")

    if request.model not in MODEL_MAP:
        raise HTTPException(
            status_code=400,
            detail={
                "error": {
                    "message": f"The model '{request.model}' does not exist or you do not have access to it.",
                    "type": "invalid_request_error",
                    "param": "model",
                    "code": "model_not_found"
                }
            }
        )

    completion = await create_non_streaming_response(request)
    if not request.stream:
        return completion
    return await convert_to_streaming_response(completion)
|
|
|
|
|
async def convert_to_streaming_response(response: ChatCompletionResponse):
    """Replay an already-complete ``ChatCompletionResponse`` as an SSE stream.

    Emitted frames, in order:
      1. a chunk whose delta only sets ``role="assistant"``;
      2. if the response has choices, either one chunk per tool call of the
         first choice, or its text content split into 50-character chunks with
         a 10 ms pause after each;
      3. (only when choices exist) an empty-delta chunk carrying the first
         choice's ``finish_reason``;
      4. the ``data: [DONE]`` terminator.
    """
    def frame(delta, finish_reason=None) -> str:
        # Serialize one single-choice stream chunk as an SSE data line.
        chunk = ChatCompletionStreamResponse(
            id=response.id,
            model=response.model,
            created=response.created,
            choices=[StreamChoice(index=0, delta=delta, finish_reason=finish_reason)],
        )
        return f"data: {chunk.model_dump_json(exclude_none=True)}\n\n"

    async def generate_stream():
        yield frame({"role": "assistant"})

        if response.choices:
            first_choice = response.choices[0]
            message = first_choice.message

            if message.tool_calls:
                for position, call in enumerate(message.tool_calls):
                    yield frame({
                        "tool_calls": [{
                            "index": position,
                            "id": call.id,
                            "type": call.type,
                            "function": call.function,
                        }]
                    })
            elif message.content:
                text = message.content
                step = 50
                for offset in range(0, len(text), step):
                    yield frame({"content": text[offset:offset + step]})
                    # Small pause so clients render the text progressively.
                    await asyncio.sleep(0.01)

            yield frame({}, first_choice.finish_reason)

        yield "data: [DONE]\n\n"

    return StreamingResponse(
        generate_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Content-Type": "text/event-stream"
        }
    )
|
|
|
def _finalize_structured_tool_call(tool_call_dict: Dict[str, Any]) -> ToolCall:
    """Convert an accumulated structured tool-call dict into a ``ToolCall``.

    ``tool_call_dict`` has the shape built in ``create_non_streaming_response``:
    ``{"id", "type", "function": {"name", "arguments"}}``, where ``arguments``
    is the concatenation of the events' ``input`` fragments.

    - Empty or whitespace-only arguments become ``"{}"``. The OpenAI format
      carries ``arguments`` as a JSON-encoded string, and ``""`` is not valid
      JSON, so clients that ``json.loads`` it would fail.
    - Valid JSON is re-serialized with ``json.dumps(..., ensure_ascii=False)``.
    - Invalid JSON is kept verbatim and a warning is logged.
    """
    function = tool_call_dict["function"]
    arguments = function.get("arguments") or ""
    if not arguments.strip():
        function["arguments"] = "{}"
    else:
        try:
            function["arguments"] = json.dumps(json.loads(arguments), ensure_ascii=False)
            logger.info("✅ 工具调用参数验证成功")
        except json.JSONDecodeError as e:
            logger.warning(f"⚠️ 工具调用的参数不是有效的JSON: {arguments}")
            logger.warning(f"⚠️ JSON错误: {e}")
    return ToolCall(**tool_call_dict)


async def create_non_streaming_response(request: ChatCompletionRequest):
    """
    Handle a chat completion by collecting CodeWhisperer's complete response.

    The response body is parsed with ``CodeWhispererStreamParser``:
    - Events with ``name`` and ``toolUseId`` are accumulated into structured
      tool calls. ``input`` fragments are concatenated until a ``stop`` event,
      or until an event with a different ``toolUseId`` starts a new call.
    - ``content`` events are concatenated into the reply text.

    ``[Called <name> with args: {...}]`` bracket calls are also extracted from
    the parsed text and from the raw decoded body. All tool calls then pass
    through ``deduplicate_tool_calls``.

    Returns:
        ChatCompletionResponse with one choice. If any tool call was found, the
        message holds the tool calls with ``content=None`` and
        ``finish_reason="tool_calls"``. Otherwise it holds the text with
        ``finish_reason="stop"``, or ``"I understand."`` when the text is empty.

    Raises:
        HTTPException: propagated unchanged from ``call_kiro_api``; any other
            failure is reported as a 500 ``internal_error``.
    """
    try:
        logger.info("🚀 开始非流式响应生成...")
        response = await call_kiro_api(request)

        logger.info(f"📤 CodeWhisperer响应状态码: {response.status_code}")
        logger.info(f"📤 响应头: {dict(response.headers)}")
        logger.info(f"📤 原始响应体长度: {len(response.content)} bytes")

        # Best-effort decode of the raw body. It is used below as a second
        # source of bracket-format tool calls; a decode failure is only logged.
        raw_response_text = ""
        try:
            raw_response_text = response.content.decode('utf-8', errors='ignore')
            logger.info(f"🔍 原始响应文本长度: {len(raw_response_text)}")
            logger.info(f"🔍 原始响应预览(前1000字符): {raw_response_text[:1000]}")

            if "[Called" in raw_response_text:
                logger.info("✅ 原始响应中发现 [Called 标记")
                called_positions = [m.start() for m in re.finditer(r'\[Called', raw_response_text)]
                logger.info(f"🎯 [Called 出现位置: {called_positions}")
            else:
                logger.info("❌ 原始响应中未发现 [Called 标记")
        except Exception as e:
            logger.error(f"❌ 解码原始响应失败: {e}")

        parser = CodeWhispererStreamParser()
        events = parser.parse(response.content)

        full_response_text = ""
        tool_calls = []
        current_tool_call_dict = None

        logger.info(f"🔄 解析到 {len(events)} 个事件,开始处理...")
        # Lazy %-args: the events are only formatted when INFO is enabled.
        for i, event in enumerate(events):
            logger.info("📋 事件 %d: %s", i, event)

        for event in events:
            if "name" in event and "toolUseId" in event:
                logger.info(f"🔧 发现结构化工具调用事件: {event}")

                # A different toolUseId means the open call never received its
                # "stop" event. Close it, so it does not absorb the next call's input.
                if (current_tool_call_dict is not None
                        and current_tool_call_dict["id"] != event.get("toolUseId")):
                    logger.warning(
                        f"⚠️ Tool call {current_tool_call_dict['id']} superseded by "
                        f"{event.get('toolUseId')} before its stop event; closing it"
                    )
                    tool_calls.append(_finalize_structured_tool_call(current_tool_call_dict))
                    current_tool_call_dict = None

                if current_tool_call_dict is None:
                    current_tool_call_dict = {
                        "id": event.get("toolUseId"),
                        "type": "function",
                        "function": {
                            "name": event.get("name"),
                            "arguments": ""
                        }
                    }
                    logger.info(f"🆕 开始解析工具调用: {current_tool_call_dict['function']['name']}")

                if "input" in event:
                    current_tool_call_dict["function"]["arguments"] += event.get("input", "")
                    logger.info(f"📝 累积参数: {event.get('input', '')}")

                if event.get("stop"):
                    logger.info(f"✅ 完成工具调用: {current_tool_call_dict['function']['name']}")
                    tool_calls.append(_finalize_structured_tool_call(current_tool_call_dict))
                    current_tool_call_dict = None

            elif "content" in event:
                content = event.get("content", "")
                full_response_text += content
                logger.info(f"📄 添加文本内容: {content[:100]}...")

        # The upstream body ended while a structured call was still open. Keep it
        # anyway, with its arguments normalized like a completed call.
        if current_tool_call_dict is not None:
            logger.warning("⚠️ 响应流在工具调用结束前终止,仍尝试添加。")
            tool_calls.append(_finalize_structured_tool_call(current_tool_call_dict))

        logger.info(f"📊 事件处理完成 - 文本长度: {len(full_response_text)}, 结构化工具调用: {len(tool_calls)}")

        # Bracket-format calls embedded in the parsed text.
        logger.info("🔍 开始检查解析后文本中的bracket格式工具调用...")
        bracket_tool_calls = parse_bracket_tool_calls(full_response_text)
        if bracket_tool_calls:
            logger.info(f"✅ 在解析后文本中发现 {len(bracket_tool_calls)} 个 bracket 格式工具调用")
            tool_calls.extend(bracket_tool_calls)

            # Strip each extracted "[Called name with args: {...}]" from the text.
            for tc in bracket_tool_calls:
                func_name = tc.function.get("name", "unknown")
                escaped_name = re.escape(func_name)
                pattern = r'\[Called\s+' + escaped_name + r'\s+with\s+args:\s*\{[^}]*(?:\{[^}]*\}[^}]*)*\}\s*\]'
                full_response_text = re.sub(pattern, '', full_response_text, flags=re.DOTALL)

            full_response_text = re.sub(r'\s+', ' ', full_response_text).strip()

        # Bracket-format calls found in the raw body. Duplicates of the calls
        # above are removed by deduplicate_tool_calls below.
        logger.info("🔍 开始检查原始响应中的bracket格式工具调用...")
        raw_bracket_tool_calls = parse_bracket_tool_calls(raw_response_text)
        if raw_bracket_tool_calls and isinstance(raw_bracket_tool_calls, list):
            logger.info(f"✅ 在原始响应中发现 {len(raw_bracket_tool_calls)} 个 bracket 格式工具调用")
            tool_calls.extend(raw_bracket_tool_calls)
        else:
            logger.info("❌ 原始响应中未发现bracket格式工具调用")

        logger.info(f"🔄 去重前工具调用数量: {len(tool_calls)}")
        unique_tool_calls = deduplicate_tool_calls(tool_calls)
        logger.info(f"🔄 去重后工具调用数量: {len(unique_tool_calls)}")

        if unique_tool_calls:
            logger.info(f"🔧 构建工具调用响应,包含 {len(unique_tool_calls)} 个工具调用")
            for i, tc in enumerate(unique_tool_calls):
                logger.info(f"🔧 工具调用 {i}: {tc.function.get('name', 'unknown')}")

            response_message = ResponseMessage(
                role="assistant",
                content=None,
                tool_calls=unique_tool_calls
            )
            finish_reason = "tool_calls"
        else:
            logger.info("📄 构建普通文本响应")
            content = full_response_text.strip() if full_response_text.strip() else "I understand."
            logger.info(f"📄 最终文本内容: {content[:200]}...")

            response_message = ResponseMessage(
                role="assistant",
                content=content
            )
            finish_reason = "stop"

        choice = Choice(
            index=0,
            message=response_message,
            finish_reason=finish_reason
        )

        usage = create_usage_stats(
            prompt_text=" ".join([msg.get_content_text() for msg in request.messages]),
            completion_text=full_response_text if not unique_tool_calls else ""
        )

        chat_response = ChatCompletionResponse(
            model=request.model,
            choices=[choice],
            usage=usage
        )

        logger.info("📤 最终非流式响应构建完成")
        logger.info(f"📤 响应类型: {'工具调用' if unique_tool_calls else '文本内容'}")
        logger.info(f"📤 完整响应: {chat_response.model_dump_json(indent=2, exclude_none=True)}")
        return chat_response

    except HTTPException:
        raise
    except Exception as e:
        # logger.exception records the traceback through the logging system
        # instead of printing it to stderr.
        logger.exception(f"❌ 非流式响应处理出错: {e}")
        raise HTTPException(
            status_code=500,
            detail={
                "error": {
                    "message": f"Internal server error: {str(e)}",
                    "type": "internal_server_error",
                    "param": None,
                    "code": "internal_error"
                }
            }
        )
|
|
|
def _trailing_partial_marker_length(text: str, marker: str) -> int:
    """Return how many trailing characters of *text* could be the start of *marker*.

    The result is the length of the longest proper prefix of *marker* that
    *text* ends with, or 0 if there is none.
    For example, ``("see [Cal", "[Called")`` returns 4.
    The streaming path holds those characters back, so a marker split across
    two upstream chunks is still recognised.
    """
    for size in range(min(len(marker) - 1, len(text)), 0, -1):
        if text.endswith(marker[:size]):
            return size
    return 0


async def create_streaming_response(request: ChatCompletionRequest):
    """
    Handle a streaming chat completion request.

    Calls CodeWhisperer through ``call_kiro_api`` and re-emits its event stream
    as OpenAI-compatible SSE chunks. Two tool-call encodings are translated
    into ``delta.tool_calls`` entries:

    - Structured events carrying ``name``/``toolUseId`` (with ``input`` argument
      fragments and a ``stop`` flag) are streamed incrementally. A call that
      ends without any arguments gets ``"{}"``. A call is also closed when a
      different ``toolUseId`` appears or the stream ends.
    - ``[Called <name> with args: {...}]`` text inside content events is
      buffered until ``find_matching_bracket`` finds its end, then parsed with
      ``parse_single_tool_call``.

    Other text is streamed as ``delta.content``. The first emitted delta also
    carries ``role="assistant"``. The final chunk has
    ``finish_reason="tool_calls"`` if any tool call was emitted, else ``"stop"``.

    Raises:
        HTTPException: raised by ``call_kiro_api`` is propagated unchanged;
            any other failure before streaming starts becomes a 500
            ``stream_error``.
    """
    try:
        logger.info("开始流式响应生成...")
        response = await call_kiro_api(request)

        async def generate_stream():
            response_id = f"chatcmpl-{uuid.uuid4()}"
            created = int(time.time())
            parser = CodeWhispererStreamParser()
            marker = "[Called"

            sent_role = False
            current_tool_call_index = 0
            streamed_tool_calls_count = 0
            # Structured tool-call state.
            is_in_tool_call = False
            current_tool_use_id = None
            tool_args_sent = False
            # Bracket-format state: text not yet emitted (may end in a partial
            # marker), and an unclosed "[Called ..." fragment awaiting more text.
            content_buffer = ""
            incomplete_tool_call = ""

            def sse(delta: Dict[str, Any]) -> str:
                """Serialize one delta as an SSE data line; the first delta also gets the role."""
                nonlocal sent_role
                if not sent_role:
                    delta["role"] = "assistant"
                    sent_role = True
                chunk = ChatCompletionStreamResponse(
                    id=response_id, model=request.model, created=created,
                    choices=[StreamChoice(index=0, delta=delta)]
                )
                return f"data: {chunk.model_dump_json(exclude_none=True)}\n\n"

            def tool_arguments_sse(arguments: str) -> str:
                """Delta appending *arguments* to the tool call at the current index."""
                return sse({
                    "tool_calls": [{
                        "index": current_tool_call_index,
                        "function": {"arguments": arguments}
                    }]
                })

            def bracket_tool_call_sse(parsed_call) -> str:
                """Delta emitting one fully parsed bracket-format tool call at the current index."""
                return sse({
                    "tool_calls": [{
                        "index": current_tool_call_index,
                        "id": parsed_call.id,
                        "type": "function",
                        "function": {
                            "name": parsed_call.function["name"],
                            "arguments": parsed_call.function["arguments"]
                        }
                    }]
                })

            def close_structured_tool_call() -> Optional[str]:
                """Finish the open structured call.

                Returns a ``"{}"`` arguments delta if the call received no
                arguments, otherwise ``None``.
                """
                nonlocal is_in_tool_call, current_tool_use_id, tool_args_sent
                nonlocal current_tool_call_index, streamed_tool_calls_count
                pending = None if tool_args_sent else tool_arguments_sse("{}")
                is_in_tool_call = False
                current_tool_use_id = None
                tool_args_sent = False
                current_tool_call_index += 1
                streamed_tool_calls_count += 1
                return pending

            async for raw_chunk in response.aiter_bytes():
                for event in parser.parse(raw_chunk):

                    if "name" in event and "toolUseId" in event:
                        logger.info(f"🎯 STREAM: Found structured tool call event: {event}")
                        tool_use_id = event.get("toolUseId")

                        # A new toolUseId while a call is open means the previous
                        # call never sent "stop"; close it first.
                        if is_in_tool_call and tool_use_id != current_tool_use_id:
                            closing = close_structured_tool_call()
                            if closing:
                                yield closing

                        if not is_in_tool_call:
                            # Flush any held-back text so it keeps its place before the call.
                            if content_buffer and not incomplete_tool_call:
                                yield sse({"content": content_buffer})
                                content_buffer = ""
                            is_in_tool_call = True
                            current_tool_use_id = tool_use_id
                            tool_args_sent = False
                            yield sse({
                                "tool_calls": [{
                                    "index": current_tool_call_index,
                                    "id": tool_use_id,
                                    "type": "function",
                                    "function": {"name": event.get("name"), "arguments": ""}
                                }]
                            })

                        arg_chunk_str = event.get("input", "")
                        if arg_chunk_str:
                            tool_args_sent = True
                            yield tool_arguments_sse(arg_chunk_str)

                        if event.get("stop"):
                            closing = close_structured_tool_call()
                            if closing:
                                yield closing

                    elif "content" in event and not is_in_tool_call:
                        content_text = event.get("content", "")
                        if not content_text:
                            continue

                        # Re-join an unclosed "[Called ..." fragment with the new
                        # text. Otherwise its continuation would be searched on its
                        # own and could leak out as plain content.
                        content_buffer = incomplete_tool_call + content_buffer + content_text
                        incomplete_tool_call = ""
                        logger.debug("📝 STREAM DEBUG: buffer length=%d", len(content_buffer))

                        while True:
                            called_start = content_buffer.find(marker)

                            if called_start == -1:
                                # Emit everything except a trailing partial marker
                                # (e.g. "[Cal"), which may complete in the next chunk.
                                held = _trailing_partial_marker_length(content_buffer, marker)
                                ready = content_buffer[:len(content_buffer) - held]
                                if ready:
                                    yield sse({"content": ready})
                                content_buffer = content_buffer[len(ready):]
                                break

                            text_before = content_buffer[:called_start]
                            if text_before.strip():
                                yield sse({"content": text_before})

                            remaining_text = content_buffer[called_start:]
                            bracket_end = find_matching_bracket(remaining_text, 0)
                            if bracket_end == -1:
                                logger.debug(
                                    "🔍 BRACKET DEBUG: Tool call incomplete, buffering %d chars",
                                    len(remaining_text),
                                )
                                incomplete_tool_call = remaining_text
                                content_buffer = ""
                                break

                            parsed_call = parse_single_tool_call(remaining_text[:bracket_end + 1])
                            if parsed_call:
                                logger.info(f"📤 STREAM: Sending tool call chunk: {parsed_call.function.get('name')}")
                                yield bracket_tool_call_sse(parsed_call)
                                current_tool_call_index += 1
                                streamed_tool_calls_count += 1
                            else:
                                logger.error("❌ BRACKET DEBUG: Failed to parse tool call")

                            content_buffer = remaining_text[bracket_end + 1:]

            logger.debug(
                "📊 STREAM END: incomplete_tool_call=%r content_buffer=%r",
                incomplete_tool_call, content_buffer,
            )

            # A structured call still open at end of stream still counts as
            # emitted; close it so the finish_reason reflects it.
            if is_in_tool_call:
                logger.warning("⚠️ STREAM END: structured tool call ended without a stop event")
                closing = close_structured_tool_call()
                if closing:
                    yield closing

            if incomplete_tool_call:
                # Last parse attempt for an unclosed "[Called ..." fragment;
                # whatever cannot be parsed is sent as plain text below.
                logger.info("🔄 STREAM END: Attempting to parse incomplete tool call")
                content_buffer = incomplete_tool_call + content_buffer
                incomplete_tool_call = ""
                if content_buffer.startswith(marker):
                    bracket_end = find_matching_bracket(content_buffer, 0)
                    if bracket_end != -1:
                        parsed_call = parse_single_tool_call(content_buffer[:bracket_end + 1])
                        if parsed_call:
                            yield bracket_tool_call_sse(parsed_call)
                            current_tool_call_index += 1
                            streamed_tool_calls_count += 1
                        content_buffer = content_buffer[bracket_end + 1:]

            if content_buffer.strip():
                yield sse({"content": content_buffer})

            finish_reason = "tool_calls" if streamed_tool_calls_count > 0 else "stop"
            logger.info(f"🏁 STREAM FINISH: streamed_tool_calls_count={streamed_tool_calls_count}, finish_reason={finish_reason}")
            end_chunk = ChatCompletionStreamResponse(
                id=response_id, model=request.model, created=created,
                choices=[StreamChoice(index=0, delta={}, finish_reason=finish_reason)]
            )
            yield f"data: {end_chunk.model_dump_json(exclude_none=True)}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(
            generate_stream(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "Content-Type": "text/event-stream"
            }
        )

    except HTTPException:
        # Keep upstream status codes instead of masking them as a generic 500.
        raise
    except Exception as e:
        logger.exception(f"❌ 流式响应生成失败: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail={
                "error": {
                    "message": f"Stream generation failed: {str(e)}",
                    "type": "internal_server_error",
                    "param": None,
                    "code": "stream_error"
                }
            }
        )
|
|
|
@app.get("/health")
async def health_check():
    """Report liveness with the service name and version.

    The version is read from the FastAPI application's own metadata instead of
    being repeated as a literal, so it always matches the declared app version.
    """
    return {"status": "healthy", "service": "Ki2API", "version": app.version}
|
|
|
@app.get("/")
async def root():
    """Describe the service: name, description, version, endpoints and feature flags.

    Description and version come from the FastAPI application's metadata, so
    this payload cannot drift from what the app declares (e.g. in its OpenAPI schema).
    """
    return {
        "service": "Ki2API",
        "description": app.description,
        "version": app.version,
        "endpoints": {
            "models": "/v1/models",
            "chat": "/v1/chat/completions",
            "health": "/health"
        },
        "features": {
            "streaming": True,
            "tools": True,
            "multiple_models": True,
            "xml_tool_parsing": True,
            "auto_token_refresh": True,
            "null_content_handling": True,
            "tool_call_deduplication": True
        }
    }
|
|
|
if __name__ == "__main__":
    import uvicorn

    # Serve on all interfaces; the port comes from $PORT (default 7860).
    # `os` is already imported at module level.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", 7860)))