| import os |
| import json |
| import time |
| import uuid |
| import httpx |
| import re |
| import asyncio |
| import xml.etree.ElementTree as ET |
| import logging |
| import struct |
| import base64 |
| import copy |
| from fastapi import FastAPI, HTTPException, Request, Header, Depends |
| from fastapi.responses import StreamingResponse |
| from pydantic import BaseModel, Field |
| from typing import List, Optional, Dict, Any, Union |
| from dotenv import load_dotenv |
| from json_repair import repair_json |
|
|
| |
| |
# Default to WARNING so the many info-level trace messages in this module
# stay quiet unless logging is reconfigured.
logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)


# Load API_KEY / KIRO_* credentials from a local .env file.
load_dotenv()


app = FastAPI(
    title="Ki2API - Claude Sonnet 4 OpenAI Compatible API",
    description="OpenAI-compatible API for Claude Sonnet 4 via AWS CodeWhisperer",
    version="3.0.1"
)


# Client-facing API key plus the CodeWhisperer endpoint/credentials.
API_KEY = os.getenv("API_KEY")
KIRO_ACCESS_TOKEN = os.getenv("KIRO_ACCESS_TOKEN")
KIRO_REFRESH_TOKEN = os.getenv("KIRO_REFRESH_TOKEN")
KIRO_BASE_URL = "https://codewhisperer.us-east-1.amazonaws.com/generateAssistantResponse"
PROFILE_ARN = "arn:aws:codewhisperer:us-east-1:699475941385:profile/EHGA3GRVQMUK"


# OpenAI-style model ids -> CodeWhisperer model ids.
# NOTE(review): the 3.5-haiku id maps to a Sonnet 3.7 backend — looks like a
# deliberate alias, but confirm.
MODEL_MAP = {
    "claude-sonnet-4-20250514": "CLAUDE_SONNET_4_20250514_V1_0",
    "claude-3-5-haiku-20241022": "CLAUDE_3_7_SONNET_20250219_V1_0",
}
DEFAULT_MODEL = "claude-sonnet-4-20250514"
|
|
| |
class ImageUrl(BaseModel):
    """OpenAI image_url content payload (expected to be a data URL here)."""
    url: str
    # OpenAI "detail" hint; accepted for compatibility, not used in this module.
    detail: Optional[str] = "auto"
|
|
class ContentPart(BaseModel):
    """One element of a multimodal message: a text part or an image_url part."""
    type: str
    text: Optional[str] = None
    image_url: Optional[ImageUrl] = None
|
|
class ToolCall(BaseModel):
    """An OpenAI-style tool call; `function` holds {"name": ..., "arguments": <json str>}."""
    id: str
    type: str = "function"
    function: Dict[str, Any]
class ChatMessage(BaseModel):
    """A single OpenAI-style chat message (system / user / assistant / tool)."""

    role: str
    # Plain string, list of multimodal parts, or None (assistant messages
    # that only carry tool_calls have no content).
    content: Union[str, List[ContentPart], None]
    tool_calls: Optional[List[ToolCall]] = None
    tool_call_id: Optional[str] = None

    def get_content_text(self) -> str:
        """Extract text content from either string or content parts.

        Returns "" for None content. Non-string part payloads are coerced
        to str so the final join can never raise TypeError (tool_result
        parts may carry list/dict content).
        """
        if self.content is None:
            logger.warning(f"Message with role '{self.role}' has None content")
            return ""

        if isinstance(self.content, str):
            return self.content
        elif isinstance(self.content, list):
            text_parts = []
            for part in self.content:
                if isinstance(part, dict):
                    if part.get("type") == "text" and "text" in part:
                        text_parts.append(str(part.get("text", "")))
                    elif part.get("type") == "tool_result" and "content" in part:
                        # Coerce: tool_result content is not guaranteed to be a str.
                        text_parts.append(str(part.get("content", "")))
                elif hasattr(part, 'text') and part.text:
                    text_parts.append(part.text)
            return "".join(text_parts)
        else:
            logger.warning(f"Unexpected content type: {type(self.content)}")
            return str(self.content) if self.content else ""
|
|
class Function(BaseModel):
    """Declaration of a callable tool function (name + JSON-schema parameters)."""
    name: str
    description: Optional[str] = None
    parameters: Optional[Dict[str, Any]] = None
|
|
class Tool(BaseModel):
    """OpenAI tool wrapper around a Function declaration."""
    type: str = "function"
    function: Function
|
|
|
|
|
|
class ChatCompletionRequest(BaseModel):
    """OpenAI /v1/chat/completions request body.

    Sampling fields (temperature, top_p, penalties, stop, user) are accepted
    for OpenAI compatibility; build_codewhisperer_request forwards only
    model, messages and tools, and `stream` selects the response mode.
    """
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = 0.7
    max_tokens: Optional[int] = 4000
    stream: Optional[bool] = False
    top_p: Optional[float] = 1.0
    frequency_penalty: Optional[float] = 0.0
    presence_penalty: Optional[float] = 0.0
    stop: Optional[Union[str, List[str]]] = None
    user: Optional[str] = None
    tools: Optional[List[Tool]] = None
    tool_choice: Optional[Union[str, Dict[str, Any]]] = "auto"
|
|
class Usage(BaseModel):
    """OpenAI-style token accounting (values are estimates, see estimate_tokens)."""
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    prompt_tokens_details: Optional[Dict[str, int]] = Field(default_factory=lambda: {"cached_tokens": 0})
    completion_tokens_details: Optional[Dict[str, int]] = Field(default_factory=lambda: {"reasoning_tokens": 0})
|
|
class ResponseMessage(BaseModel):
    """Assistant message returned to the client (text content and/or tool calls)."""
    role: str
    content: Optional[str] = None
    tool_calls: Optional[List[ToolCall]] = None
|
|
class Choice(BaseModel):
    """One completed choice in a non-streaming response."""
    index: int
    message: ResponseMessage
    logprobs: Optional[Any] = None
    finish_reason: str
|
|
class StreamChoice(BaseModel):
    """One delta choice inside a streaming chunk."""
    index: int
    delta: Dict[str, Any]
    logprobs: Optional[Any] = None
    finish_reason: Optional[str] = None
|
|
class ChatCompletionResponse(BaseModel):
    """Top-level envelope for a non-streaming chat completion."""
    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4()}")
    object: str = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    system_fingerprint: Optional[str] = "fp_ki2api_v3"
    choices: List[Choice]
    usage: Usage
|
|
class ChatCompletionStreamResponse(BaseModel):
    """Top-level envelope for one SSE chunk (object == chat.completion.chunk)."""
    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4()}")
    object: str = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    system_fingerprint: Optional[str] = "fp_ki2api_v3"
    choices: List[StreamChoice]
    usage: Optional[Usage] = None
|
|
class ErrorResponse(BaseModel):
    """OpenAI-style error envelope: {"error": {...}}."""
    error: Dict[str, Any]
|
|
| |
async def verify_api_key(authorization: str = Header(None)):
    """FastAPI dependency validating the `Authorization: Bearer <key>` header.

    Raises an OpenAI-style 401 HTTPException when the header is missing,
    malformed, or carries a key different from API_KEY; returns the key
    on success.
    """
    if not authorization:
        raise HTTPException(
            status_code=401,
            detail={
                "error": {
                    "message": "You didn't provide an API key.",
                    "type": "invalid_request_error",
                    "param": None,
                    "code": "invalid_api_key"
                }
            }
        )

    if not authorization.startswith("Bearer "):
        raise HTTPException(
            status_code=401,
            detail={
                "error": {
                    "message": "Invalid API key format. Expected 'Bearer <key>'",
                    "type": "invalid_request_error",
                    "param": None,
                    "code": "invalid_api_key"
                }
            }
        )

    # Strip only the leading "Bearer " prefix. The previous
    # str.replace("Bearer ", "") removed that substring *anywhere* in the
    # key, corrupting keys that happen to contain it.
    api_key = authorization[len("Bearer "):]
    if api_key != API_KEY:
        raise HTTPException(
            status_code=401,
            detail={
                "error": {
                    "message": "Invalid API key provided",
                    "type": "invalid_request_error",
                    "param": None,
                    "code": "invalid_api_key"
                }
            }
        )
    return api_key
|
|
| |
class TokenManager:
    """Holds the CodeWhisperer access/refresh token pair and refreshes it on demand."""

    def __init__(self):
        self.access_token = KIRO_ACCESS_TOKEN
        self.refresh_token = KIRO_REFRESH_TOKEN
        self.refresh_url = "https://prod.us-east-1.auth.desktop.kiro.dev/refreshToken"
        # Timestamp of the last successful refresh; 0 means "never".
        self.last_refresh_time = 0
        # Serialises concurrent refresh attempts.
        self.refresh_lock = asyncio.Lock()

    async def refresh_tokens(self):
        """Refresh the access token, using a lock to prevent concurrent refreshes.

        Returns the new access token, or None when no refresh token is
        configured or the refresh request fails.
        """
        if not self.refresh_token:
            logger.error("没有刷新token,无法刷新访问token")
            return None

        async with self.refresh_lock:
            now = time.time()
            # Another coroutine may have refreshed while we waited for the lock.
            if now - self.last_refresh_time < 5:
                logger.info("最近已刷新token,使用现有token")
                return self.access_token

            try:
                logger.info("开始刷新token...")
                async with httpx.AsyncClient() as client:
                    resp = await client.post(
                        self.refresh_url,
                        json={"refreshToken": self.refresh_token},
                        timeout=30
                    )
                    resp.raise_for_status()
                    payload = resp.json()

                if "accessToken" not in payload:
                    logger.error(f"刷新token响应中没有accessToken: {payload}")
                    return None

                self.access_token = payload.get("accessToken")
                self.last_refresh_time = now
                logger.info("token刷新成功")

                # Keep the process environment in sync for any later readers.
                os.environ["KIRO_ACCESS_TOKEN"] = self.access_token

                return self.access_token
            except Exception as exc:
                logger.error(f"token刷新失败: {str(exc)}")
                return None

    def get_token(self):
        """Return the current access token (may be None)."""
        return self.access_token
|
|
# Process-wide singleton shared by all request handlers.
token_manager = TokenManager()
|
|
| |
def parse_xml_tool_calls(response_text: str) -> Optional[List[ToolCall]]:
    """Parse CodeWhisperer XML-style tool calls into OpenAI ToolCall objects.

    Tries three layouts, from most to least structured:
      1. <tool_use> blocks — now collects *every* parameter pair in a block
         (the previous pattern required exactly one pair, silently dropping
         all parameters after the first);
      2. bare <tool_name>/<tool_parameter_name>/<tool_parameter_value> triples;
      3. lone <tool_name> tags with no parameters.
    Returns None when no tool calls are found.
    """
    if not response_text:
        return None

    tool_calls = []

    logger.info(f"🔍 开始解析XML工具调用,响应文本长度: {len(response_text)}")

    param_pattern = r'<tool_parameter_name>([^<]+)</tool_parameter_name>\s*<tool_parameter_value>([^<]*)</tool_parameter_value>'

    # Layout 1: complete <tool_use> blocks (any number of parameter pairs).
    for block_match in re.finditer(r'<tool_use>(.*?)</tool_use>', response_text, re.DOTALL | re.IGNORECASE):
        block = block_match.group(1)
        name_match = re.search(r'<tool_name>([^<]+)</tool_name>', block, re.IGNORECASE)
        if not name_match:
            continue

        function_name = name_match.group(1).strip()
        arguments = {
            m.group(1).strip(): m.group(2).strip()
            for m in re.finditer(param_pattern, block, re.DOTALL | re.IGNORECASE)
        }
        tool_call_id = f"call_{uuid.uuid4().hex[:8]}"

        tool_call = ToolCall(
            id=tool_call_id,
            type="function",
            function={
                "name": function_name,
                "arguments": json.dumps(arguments, ensure_ascii=False)
            }
        )
        tool_calls.append(tool_call)
        logger.info(f"✅ 解析到工具调用: {function_name} with {arguments}")

    # Layout 2: name + one parameter triple without an enclosing <tool_use>.
    if not tool_calls:
        simple_pattern = r'<tool_name>([^<]+)</tool_name>\s*' + param_pattern
        matches = re.finditer(simple_pattern, response_text, re.DOTALL | re.IGNORECASE)

        for match in matches:
            function_name = match.group(1).strip()
            param_name = match.group(2).strip()
            param_value = match.group(3).strip()

            arguments = {param_name: param_value}
            tool_call_id = f"call_{uuid.uuid4().hex[:8]}"

            tool_call = ToolCall(
                id=tool_call_id,
                type="function",
                function={
                    "name": function_name,
                    "arguments": json.dumps(arguments, ensure_ascii=False)
                }
            )
            tool_calls.append(tool_call)
            logger.info(f"✅ 解析到简单工具调用: {function_name} with {param_name}={param_value}")

    # Layout 3: tool name only, no parameters.
    if not tool_calls:
        name_only_pattern = r'<tool_name>([^<]+)</tool_name>'
        matches = re.finditer(name_only_pattern, response_text, re.IGNORECASE)

        for match in matches:
            function_name = match.group(1).strip()
            tool_call_id = f"call_{uuid.uuid4().hex[:8]}"

            tool_call = ToolCall(
                id=tool_call_id,
                type="function",
                function={
                    "name": function_name,
                    "arguments": "{}"
                }
            )
            tool_calls.append(tool_call)
            logger.info(f"✅ 解析到无参数工具调用: {function_name}")

    if tool_calls:
        logger.info(f"🎉 总共解析出 {len(tool_calls)} 个工具调用")
        return tool_calls
    else:
        logger.info("❌ 未发现任何XML格式的工具调用")
        return None
|
|
def find_matching_bracket(text: str, start_pos: int) -> int:
    """Return the index of the ']' matching the '[' at start_pos.

    Scans forward while tracking nesting depth, ignoring brackets that occur
    inside double-quoted JSON strings (with backslash escapes honoured).
    Returns -1 when start_pos is out of range, does not point at '[', or no
    matching bracket exists.

    The previous version logged every scanned character (and the payload
    text itself) at INFO level — debug scaffolding that flooded the logs and
    leaked request content; that logging has been removed.
    """
    if not text or start_pos >= len(text) or text[start_pos] != '[':
        return -1

    depth = 1
    in_string = False
    escape_next = False

    for i in range(start_pos + 1, len(text)):
        char = text[i]

        if escape_next:
            # The previous character was a backslash inside a string.
            escape_next = False
            continue

        if char == '\\' and in_string:
            escape_next = True
            continue

        if char == '"':
            # Toggle string mode; brackets inside strings don't count.
            in_string = not in_string
            continue

        if not in_string:
            if char == '[':
                depth += 1
            elif char == ']':
                depth -= 1
                if depth == 0:
                    return i

    return -1
|
|
def parse_single_tool_call_professional(tool_call_text: str) -> Optional[ToolCall]:
    """Parse one `[Called <name> with args: {...}]` span using json_repair.

    Extracts the function name and the JSON argument blob, repairs malformed
    JSON with the json_repair library, and falls back to the outermost
    {...} span if the first attempt fails. Returns None on any failure.
    """
    logger.info(f"🔧 开始解析工具调用文本 (长度: {len(tool_call_text)})")

    # Step 1: pull the function name out of the "[Called <name> with args:" header.
    # NOTE(review): \w+ will not match tool names containing '.' or '-' —
    # confirm whether such names can occur upstream.
    name_pattern = r'\[Called\s+(\w+)\s+with\s+args:'
    name_match = re.search(name_pattern, tool_call_text, re.IGNORECASE)

    if not name_match:
        logger.warning("⚠️ 无法从文本中提取函数名")
        return None

    function_name = name_match.group(1).strip()
    logger.info(f"✅ 提取到函数名: {function_name}")

    # Step 2: the JSON candidate sits between "with args:" and the *last*
    # ']' in the text (rfind tolerates brackets inside the JSON).
    args_start_marker = "with args:"
    args_start_pos = tool_call_text.lower().find(args_start_marker.lower())
    if args_start_pos == -1:
        logger.error("❌ 找不到 'with args:' 标记")
        return None

    args_start = args_start_pos + len(args_start_marker)

    args_end = tool_call_text.rfind(']')
    if args_end <= args_start:
        logger.error("❌ 找不到结束的 ']'")
        return None

    json_candidate = tool_call_text[args_start:args_end].strip()
    logger.info(f"📝 提取的JSON候选文本长度: {len(json_candidate)}")

    # Step 3: repair then parse the candidate; only dict results are accepted.
    try:
        repaired_json = repair_json(json_candidate)
        logger.info(f"🔧 JSON修复完成,修复后长度: {len(repaired_json)}")

        arguments = json.loads(repaired_json)

        if not isinstance(arguments, dict):
            logger.error(f"❌ 解析结果不是字典类型: {type(arguments)}")
            return None

        tool_call_id = f"call_{uuid.uuid4().hex[:8]}"
        tool_call = ToolCall(
            id=tool_call_id,
            type="function",
            function={
                "name": function_name,
                "arguments": json.dumps(arguments, ensure_ascii=False)
            }
        )

        logger.info(f"✅ 成功创建工具调用: {function_name} (参数键: {list(arguments.keys())})")
        return tool_call

    except Exception as e:
        logger.error(f"❌ JSON修复/解析失败: {type(e).__name__}: {str(e)}")

        # Fallback: retry on just the outermost {...} span, in case the
        # candidate carried leading/trailing junk.
        try:
            first_brace = json_candidate.find('{')
            last_brace = json_candidate.rfind('}')

            if first_brace != -1 and last_brace > first_brace:
                core_json = json_candidate[first_brace:last_brace + 1]

                repaired_core = repair_json(core_json)
                arguments = json.loads(repaired_core)

                if isinstance(arguments, dict):
                    tool_call_id = f"call_{uuid.uuid4().hex[:8]}"
                    tool_call = ToolCall(
                        id=tool_call_id,
                        type="function",
                        function={
                            "name": function_name,
                            "arguments": json.dumps(arguments, ensure_ascii=False)
                        }
                    )
                    logger.info(f"✅ 备用方案成功: {function_name}")
                    return tool_call

        except Exception as backup_error:
            logger.error(f"❌ 备用方案也失败了: {backup_error}")

        return None
|
|
def parse_bracket_tool_calls_professional(response_text: str) -> Optional[List[ToolCall]]:
    """Batch parser for `[Called ... with args: {...}]` tool-call markers.

    Locates every "[Called" marker, finds each call's closing bracket by
    depth counting (bounded by the next marker), and delegates each span to
    parse_single_tool_call_professional. Returns None when nothing parses.
    """
    if not response_text or "[Called" not in response_text:
        logger.info("📭 响应文本中没有工具调用标记")
        return None

    tool_calls = []
    errors = []

    try:
        # Collect the start offset of every "[Called" marker.
        call_positions = []
        start = 0
        while True:
            pos = response_text.find("[Called", start)
            if pos == -1:
                break
            call_positions.append(pos)
            start = pos + 1

        logger.info(f"🔍 找到 {len(call_positions)} 个潜在的工具调用")

        for i, start_pos in enumerate(call_positions):
            # Never scan past the next "[Called" marker.
            if i + 1 < len(call_positions):
                end_search_limit = call_positions[i + 1]
            else:
                end_search_limit = len(response_text)

            segment = response_text[start_pos:end_search_limit]

            # Depth-count brackets to find this call's closing ']'.
            # NOTE(review): unlike find_matching_bracket, this does NOT skip
            # brackets inside JSON strings — it relies on the rfind fallback
            # below plus json_repair downstream; confirm that is acceptable.
            bracket_count = 0
            end_pos = -1

            for j, char in enumerate(segment):
                if char == '[':
                    bracket_count += 1
                elif char == ']':
                    bracket_count -= 1
                    if bracket_count == 0:
                        end_pos = start_pos + j
                        break

            if end_pos == -1:
                # Unbalanced brackets: fall back to the last ']' in the segment.
                last_bracket = segment.rfind(']')
                if last_bracket != -1:
                    end_pos = start_pos + last_bracket
                else:
                    logger.warning(f"⚠️ 工具调用 {i+1} 没有找到结束括号")
                    continue

            tool_call_text = response_text[start_pos:end_pos + 1]
            logger.info(f"📋 提取工具调用 {i+1}, 长度: {len(tool_call_text)}")

            parsed_call = parse_single_tool_call_professional(tool_call_text)
            if parsed_call:
                tool_calls.append(parsed_call)
            else:
                errors.append(f"工具调用 {i+1} 解析失败")

    except Exception as e:
        logger.error(f"❌ 批量解析过程出错: {type(e).__name__}: {str(e)}")
        import traceback
        traceback.print_exc()

    if tool_calls:
        logger.info(f"🎉 成功解析 {len(tool_calls)} 个工具调用")
        for tc in tool_calls:
            logger.info(f" ✓ {tc.function['name']} (ID: {tc.id})")

    if errors:
        logger.warning(f"⚠️ 有 {len(errors)} 个解析失败:")
        for error in errors:
            logger.warning(f" ✗ {error}")

    return tool_calls if tool_calls else None
|
|
| |
def parse_bracket_tool_calls(response_text: str) -> Optional[List[ToolCall]]:
    """Backward-compatible alias for parse_bracket_tool_calls_professional."""
    return parse_bracket_tool_calls_professional(response_text)
|
|
def parse_single_tool_call(tool_call_text: str) -> Optional[ToolCall]:
    """Backward-compatible alias for parse_single_tool_call_professional."""
    return parse_single_tool_call_professional(tool_call_text)
|
|
| |
def deduplicate_tool_calls(tool_calls: List[Union[Dict, ToolCall]]) -> List[ToolCall]:
    """Deduplicate tool calls based on function name and arguments."""
    unique_calls: List[ToolCall] = []
    seen_keys = set()

    for raw in tool_calls:
        # Normalise plain dicts into ToolCall models.
        if isinstance(raw, dict):
            call = ToolCall(
                id=raw.get("id", f"call_{uuid.uuid4().hex[:8]}"),
                type=raw.get("type", "function"),
                function=raw.get("function", {}),
            )
        else:
            call = raw

        # Identity is (function name, serialized arguments).
        key = (call.function.get("name", ""), call.function.get("arguments", ""))
        if key in seen_keys:
            logger.info(f"🔄 Skipping duplicate tool call: {call.function.get('name', 'unknown')}")
            continue

        seen_keys.add(key)
        unique_calls.append(call)

    return unique_calls
|
|
def build_codewhisperer_request(request: ChatCompletionRequest):
    """Translate an OpenAI chat request into a CodeWhisperer
    generateAssistantResponse payload.

    Phases:
      1. split out the system prompt and collect conversation turns;
      2. flatten all-but-last messages into strictly alternating
         user/assistant "history" entries (tool messages become synthetic
         user turns, tool calls become "[Called ...]" text markers);
      3. build the current message, attaching data-URL images and the tool
         specifications, and prepend the system prompt to its content.
    Raises an OpenAI-style 400 HTTPException when no conversation messages
    are present.
    """
    codewhisperer_model = MODEL_MAP.get(request.model, MODEL_MAP[DEFAULT_MODEL])
    conversation_id = str(uuid.uuid4())

    # Phase 1: separate system prompt from conversation turns.
    system_prompt = ""
    conversation_messages = []

    for msg in request.messages:
        if msg.role == "system":
            system_prompt = msg.get_content_text()
        elif msg.role in ["user", "assistant", "tool"]:
            conversation_messages.append(msg)

    if not conversation_messages:
        raise HTTPException(
            status_code=400,
            detail={
                "error": {
                    "message": "No conversation messages found",
                    "type": "invalid_request_error",
                    "param": "messages",
                    "code": "invalid_request"
                }
            }
        )

    history = []

    # Phase 2: everything before the last message becomes "history".
    if len(conversation_messages) > 1:
        history_messages = conversation_messages[:-1]

        # Pass 2a: flatten to (role, content) pairs, folding tool results
        # into synthetic user turns.
        processed_messages = []
        i = 0
        while i < len(history_messages):
            msg = history_messages[i]

            if msg.role == "user":
                content = msg.get_content_text() or "Continue"
                processed_messages.append(("user", content))
                i += 1
            elif msg.role == "assistant":
                if hasattr(msg, 'tool_calls') and msg.tool_calls:
                    # Encode tool calls as "[Called ...]" text markers so they
                    # survive the round trip as plain content.
                    tool_descriptions = []
                    for tc in msg.tool_calls:
                        func_name = tc.function.get("name", "unknown") if isinstance(tc.function, dict) else "unknown"
                        args = tc.function.get("arguments", "{}") if isinstance(tc.function, dict) else "{}"
                        tool_descriptions.append(f"[Called {func_name} with args: {args}]")
                    content = " ".join(tool_descriptions)
                    logger.info(f"📌 Processing assistant message with tool calls: {content}")
                else:
                    content = msg.get_content_text() or "I understand."
                processed_messages.append(("assistant", content))
                i += 1
            elif msg.role == "tool":
                tool_content = msg.get_content_text() or "[Tool executed]"
                tool_call_id = getattr(msg, 'tool_call_id', 'unknown')

                formatted_tool_result = f"[Tool result for {tool_call_id}]: {tool_content}"

                # Merge the tool result into the following user turn when
                # there is one; otherwise emit it as its own user turn.
                if i + 1 < len(history_messages) and history_messages[i + 1].role == "user":
                    user_content = history_messages[i + 1].get_content_text() or ""
                    combined_content = f"{formatted_tool_result}\n{user_content}".strip()
                    processed_messages.append(("user", combined_content))
                    i += 2
                else:
                    processed_messages.append(("user", formatted_tool_result))
                    i += 1
            else:
                i += 1

        # Pass 2b: force strict user/assistant alternation, padding with
        # "Continue" / "I understand." placeholders where needed.
        i = 0
        while i < len(processed_messages):
            role, content = processed_messages[i]

            if role == "user":
                history.append({
                    "userInputMessage": {
                        "content": content,
                        "modelId": codewhisperer_model,
                        "origin": "AI_EDITOR"
                    }
                })

                if i + 1 < len(processed_messages) and processed_messages[i + 1][0] == "assistant":
                    _, assistant_content = processed_messages[i + 1]
                    history.append({
                        "assistantResponseMessage": {
                            "content": assistant_content
                        }
                    })
                    i += 2
                else:
                    # User turn with no assistant reply: pad one in.
                    history.append({
                        "assistantResponseMessage": {
                            "content": "I understand."
                        }
                    })
                    i += 1
            elif role == "assistant":
                # Assistant turn with no preceding user turn: synthesise one.
                history.append({
                    "userInputMessage": {
                        "content": "Continue",
                        "modelId": codewhisperer_model,
                        "origin": "AI_EDITOR"
                    }
                })
                history.append({
                    "assistantResponseMessage": {
                        "content": content
                    }
                })
                i += 1
            else:
                i += 1

    # Phase 3: build the current message.
    current_message = conversation_messages[-1]

    # Collect data-URL images attached to the current message.
    images = []
    if isinstance(current_message.content, list):
        for part in current_message.content:
            if part.type == "image_url" and part.image_url:
                try:
                    logger.info(f"🔍 处理图片 URL: {part.image_url.url[:50]}...")

                    # Only base64 data URLs are supported, not remote URLs.
                    if not part.image_url.url.startswith("data:image/"):
                        logger.error(f"❌ 图片 URL 格式不正确,应该以 'data:image/' 开头")
                        continue

                    header, encoded_data = part.image_url.url.split(",", 1)

                    match = re.search(r'image/(\w+)', header)
                    if match:
                        image_format = match.group(1)

                        # Validate the base64 payload before forwarding it.
                        try:
                            base64.b64decode(encoded_data)
                            logger.info("✅ Base64 编码验证通过")
                        except Exception as e:
                            logger.error(f"❌ Base64 编码无效: {e}")
                            continue

                        images.append({
                            "format": image_format,
                            "source": {"bytes": encoded_data}
                        })
                        logger.info(f"🖼️ 成功处理图片,格式: {image_format}, 大小: {len(encoded_data)} 字符")
                    else:
                        logger.warning(f"⚠️ 无法从头部确定图片格式: {header}")
                except Exception as e:
                    logger.error(f"❌ 处理图片 URL 失败: {str(e)}")

    current_content = current_message.get_content_text()

    # Rewrite the current turn's content depending on its role.
    if current_message.role == "tool":
        tool_result = current_content or '[Tool executed]'
        tool_call_id = getattr(current_message, 'tool_call_id', 'unknown')
        current_content = f"[Tool execution completed for {tool_call_id}]: {tool_result}"

        # If the previous assistant message issued this call, name the
        # function instead of the opaque call id.
        if len(conversation_messages) > 1:
            prev_message = conversation_messages[-2]
            if prev_message.role == "assistant" and hasattr(prev_message, 'tool_calls') and prev_message.tool_calls:
                for tc in prev_message.tool_calls:
                    if tc.id == tool_call_id:
                        func_name = tc.function.get("name", "unknown") if isinstance(tc.function, dict) else "unknown"
                        current_content = f"[Completed execution of {func_name}]: {tool_result}"
                        break
    elif current_message.role == "assistant":
        if hasattr(current_message, 'tool_calls') and current_message.tool_calls:
            tool_descriptions = []
            for tc in current_message.tool_calls:
                func_name = tc.function.get("name", "unknown") if isinstance(tc.function, dict) else "unknown"
                tool_descriptions.append(f"Continue after calling {func_name}")
            current_content = "; ".join(tool_descriptions)
        else:
            current_content = "Continue the conversation"

    if not current_content:
        current_content = "Continue"

    # The upstream API has no dedicated system slot, so the system prompt is
    # prepended to the current user content.
    if system_prompt:
        current_content = f"{system_prompt}\n\n{current_content}"

    codewhisperer_request = {
        "profileArn": PROFILE_ARN,
        "conversationState": {
            "chatTriggerType": "MANUAL",
            "conversationId": conversation_id,
            "currentMessage": {
                "userInputMessage": {
                    "content": current_content,
                    "modelId": codewhisperer_model,
                    "origin": "AI_EDITOR"
                }
            },
            "history": history
        }
    }

    # Declare any OpenAI tools as CodeWhisperer tool specifications.
    user_input_message_context = {}
    if request.tools:
        user_input_message_context["tools"] = [
            {
                "toolSpecification": {
                    "name": tool.function.name,
                    "description": tool.function.description or "",
                    "inputSchema": {"json": tool.function.parameters or {}}
                }
            } for tool in request.tools
        ]

    if images:
        codewhisperer_request["conversationState"]["currentMessage"]["userInputMessage"]["images"] = images
        logger.info(f"📊 添加了 {len(images)} 个图片到 userInputMessage 中")
        for i, img in enumerate(images):
            logger.info(f" - 图片 {i+1}: 格式={img['format']}, 大小={len(img['source']['bytes'])} 字符")
            logger.info(f" - 图片数据前20字符: {img['source']['bytes'][:20]}...")
        logger.info(f"✅ 成功添加 images 到 userInputMessage 中")

    if user_input_message_context:
        codewhisperer_request["conversationState"]["currentMessage"]["userInputMessage"]["userInputMessageContext"] = user_input_message_context
        logger.info(f"✅ 成功添加 userInputMessageContext 到请求中")

    # Log a deep copy with image payloads truncated (avoid megabyte log lines).
    log_request = copy.deepcopy(codewhisperer_request)
    if "images" in log_request.get("conversationState", {}).get("currentMessage", {}).get("userInputMessage", {}):
        for img in log_request["conversationState"]["currentMessage"]["userInputMessage"]["images"]:
            if "bytes" in img.get("source", {}):
                img["source"]["bytes"] = img["source"]["bytes"][:20] + "..."

    logger.info(f"🔄 COMPLETE CODEWHISPERER REQUEST: {json.dumps(log_request, indent=2)}")
    return codewhisperer_request
| |
class CodeWhispererStreamParser:
    """Incremental parser for AWS event-stream frames from CodeWhisperer.

    Feed raw bytes through parse(); each completed frame's JSON payload is
    returned as a dict. Partial frames stay buffered between calls.
    """

    def __init__(self):
        # Unconsumed bytes carried over between parse() calls.
        self.buffer = b''
        # Consecutive-error counter; the buffer is dropped once it exceeds
        # max_errors to avoid looping forever on corrupt data.
        self.error_count = 0
        self.max_errors = 5

    def parse(self, chunk: bytes) -> List[Dict[str, Any]]:
        """Parse one chunk of AWS event-stream data; return completed JSON events."""
        self.buffer += chunk
        logger.debug(f"Parser received {len(chunk)} bytes. Buffer size: {len(self.buffer)}")
        events = []

        # Need at least 12 bytes before attempting to parse a frame.
        if len(self.buffer) < 12:
            return []

        while len(self.buffer) >= 12:
            try:
                # Frame prelude: big-endian total length and header-section length.
                header_bytes = self.buffer[0:8]
                total_len, header_len = struct.unpack('>II', header_bytes)

                # Sanity-check the lengths before trusting them.
                if total_len > 2000000 or header_len > 2000000:
                    logger.error(f"Unreasonable header values: total_len={total_len}, header_len={header_len}")
                    self.buffer = self.buffer[8:]
                    self.error_count += 1
                    if self.error_count > self.max_errors:
                        logger.error("Too many parsing errors, clearing buffer")
                        self.buffer = b''
                    continue

                # Wait for the rest of the frame to arrive.
                if len(self.buffer) < total_len:
                    break

                frame = self.buffer[:total_len]
                self.buffer = self.buffer[total_len:]

                # NOTE(review): the event-stream format places a 4-byte
                # prelude CRC before the headers, so the payload really
                # starts at 12 + header_len; starting at 8 + header_len is
                # compensated for by the find('{') scan below — confirm.
                payload_start = 8 + header_len
                payload_end = total_len - 4  # trailing 4 bytes: message CRC

                if payload_start >= payload_end or payload_end > len(frame):
                    logger.error(f"Invalid payload bounds")
                    continue

                payload = frame[payload_start:payload_end]

                try:
                    payload_str = payload.decode('utf-8', errors='ignore')

                    # Skip any residual header bytes before the JSON body.
                    json_start_index = payload_str.find('{')
                    if json_start_index != -1:
                        json_payload = payload_str[json_start_index:]
                        event_data = json.loads(json_payload)
                        events.append(event_data)
                        logger.debug(f"Successfully parsed event: {event_data}")
                except json.JSONDecodeError as e:
                    logger.error(f"JSON decode error: {e}")
                    continue

            except struct.error as e:
                logger.error(f"Struct unpack error: {e}")
                # Resynchronise by sliding the window one byte forward.
                self.buffer = self.buffer[1:]
                self.error_count += 1
                if self.error_count > self.max_errors:
                    logger.error("Too many parsing errors, clearing buffer")
                    self.buffer = b''
            except Exception as e:
                logger.error(f"Unexpected error during parsing: {str(e)}")
                self.buffer = self.buffer[1:]
                self.error_count += 1
                if self.error_count > self.max_errors:
                    logger.error("Too many parsing errors, clearing buffer")
                    self.buffer = b''

        # A pass that produced events resets the error streak.
        if events:
            self.error_count = 0

        return events
|
|
| |
class SimpleResponseParser:
    """Regex-based fallback parser for raw event-stream bytes."""

    @staticmethod
    def parse_event_stream_to_json(raw_data: bytes) -> Dict[str, Any]:
        """Simple parser for fallback (from version 1)."""
        try:
            if isinstance(raw_data, bytes):
                raw_str = raw_data.decode('utf-8', errors='ignore')
            else:
                raw_str = str(raw_data)

            # First choice: pull every flat {..."content"...} JSON object out
            # of the stream and concatenate the content fields.
            pieces = []
            for candidate in re.findall(r'\{[^{}]*"content"[^{}]*\}', raw_str, re.DOTALL):
                try:
                    obj = json.loads(candidate)
                except json.JSONDecodeError:
                    continue
                if 'content' in obj and obj['content']:
                    pieces.append(obj['content'])

            if pieces:
                combined = ''.join(pieces)
                return {
                    "content": combined,
                    "tokens": len(combined.split())
                }

            # Fallback: strip control bytes and event-stream framing noise,
            # then keep only plausibly human-readable characters.
            cleaned = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]', '', raw_str)
            cleaned = re.sub(r':event-type[^:]*:[^:]*:[^:]*:', '', cleaned)
            cleaned = re.sub(r':content-type[^:]*:[^:]*:[^:]*:', '', cleaned)

            readable = re.sub(r'[^\w\s\u4e00-\u9fff\u3040-\u309f\u30a0-\u30ff.,!?;:()"\'-]', '', cleaned)
            readable = re.sub(r'\s+', ' ', readable).strip()

            if readable and len(readable) > 5:
                return {
                    "content": readable,
                    "tokens": len(readable.split())
                }

            return {"content": "No readable content found", "tokens": 0}

        except Exception as e:
            return {"content": f"Error parsing response: {str(e)}", "tokens": 0}
|
|
| |
async def call_kiro_api(request: ChatCompletionRequest):
    """Make the upstream Kiro/CodeWhisperer call, refreshing the token on 403.

    Returns the httpx.Response on success. Raises OpenAI-style
    HTTPExceptions for missing credentials (401), upstream throttling (429),
    and transport/HTTP failures (503).
    """
    token = token_manager.get_token()
    if not token:
        raise HTTPException(
            status_code=401,
            detail={
                "error": {
                    "message": "No access token available",
                    "type": "authentication_error",
                    "param": None,
                    "code": "invalid_api_key"
                }
            }
        )

    request_data = build_codewhisperer_request(request)

    headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
        "Accept": "text/event-stream" if request.stream else "application/json"
    }

    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                KIRO_BASE_URL,
                headers=headers,
                json=request_data,
                timeout=120
            )

            logger.info(f"📤 RESPONSE STATUS: {response.status_code}")

            # 403 usually means the access token expired: refresh and retry once.
            if response.status_code == 403:
                logger.info("收到403响应,尝试刷新token...")
                new_token = await token_manager.refresh_tokens()
                if new_token:
                    headers["Authorization"] = f"Bearer {new_token}"
                    response = await client.post(
                        KIRO_BASE_URL,
                        headers=headers,
                        json=request_data,
                        timeout=120
                    )
                    logger.info(f"📤 RETRY RESPONSE STATUS: {response.status_code}")
                else:
                    raise HTTPException(status_code=401, detail="Token refresh failed")

            # Map upstream throttling onto an OpenAI-style 429.
            if response.status_code == 429:
                raise HTTPException(
                    status_code=429,
                    detail={
                        "error": {
                            "message": "Rate limit exceeded",
                            "type": "rate_limit_error",
                            "param": None,
                            "code": "rate_limit_exceeded"
                        }
                    }
                )

            response.raise_for_status()
            return response

    except HTTPException:
        # BUG FIX: previously the broad `except Exception` below swallowed
        # the deliberately-raised 429/401 HTTPExceptions and re-wrapped them
        # as 503 api_error. Re-raise them untouched.
        raise
    except httpx.HTTPStatusError as e:
        logger.error(f"HTTP ERROR: {e.response.status_code} - {e.response.text}")
        raise HTTPException(
            status_code=503,
            detail={
                "error": {
                    "message": f"API call failed: {str(e)}",
                    "type": "api_error",
                    "param": None,
                    "code": "api_error"
                }
            }
        )
    except Exception as e:
        logger.error(f"API call failed: {str(e)}")
        raise HTTPException(
            status_code=503,
            detail={
                "error": {
                    "message": f"API call failed: {str(e)}",
                    "type": "api_error",
                    "param": None,
                    "code": "api_error"
                }
            }
        )
|
|
| |
def estimate_tokens(text: str) -> int:
    """Rough token estimation: ~4 characters per token, minimum 1."""
    approx = len(text) // 4
    return approx if approx >= 1 else 1
|
|
def create_usage_stats(prompt_text: str, completion_text: str) -> Usage:
    """Create usage statistics from estimated prompt/completion token counts."""
    # Both counts come from the same rough 4-chars-per-token heuristic.
    counts = (estimate_tokens(prompt_text), estimate_tokens(completion_text))
    return Usage(
        prompt_tokens=counts[0],
        completion_tokens=counts[1],
        total_tokens=sum(counts),
    )
|
|
| |
@app.get("/v1/models")
async def list_models(api_key: str = Depends(verify_api_key)):
    """List available models"""
    # Advertise every model id in MODEL_MAP as a model card.
    model_cards = [
        {
            "id": model_id,
            "object": "model",
            "created": int(time.time()),
            "owned_by": "ki2api",
        }
        for model_id in MODEL_MAP
    ]
    return {"object": "list", "data": model_cards}
|
|
@app.post("/v1/chat/completions")
async def create_chat_completion(
    request: ChatCompletionRequest,
    api_key: str = Depends(verify_api_key)
):
    """Create a chat completion (streaming or non-streaming).

    Validates the requested model, always builds the full response first,
    and replays it as SSE chunks when the client asked for streaming.
    """
    logger.info(f"📥 COMPLETE REQUEST: {request.model_dump_json(indent=2)}")

    # Flag suspicious input early: only assistant messages may omit content.
    for idx, message in enumerate(request.messages):
        if message.role != "assistant" and message.content is None:
            logger.warning(f"Message {idx} with role '{message.role}' has None content")

    if request.model not in MODEL_MAP:
        raise HTTPException(
            status_code=400,
            detail={
                "error": {
                    "message": f"The model '{request.model}' does not exist or you do not have access to it.",
                    "type": "invalid_request_error",
                    "param": "model",
                    "code": "model_not_found"
                }
            }
        )

    # The upstream call is made non-streaming in both modes; stream mode
    # only changes how the finished response is delivered to the client.
    response = await create_non_streaming_response(request)
    if not request.stream:
        return response
    return await convert_to_streaming_response(response)
|
|
|
|
async def convert_to_streaming_response(response: ChatCompletionResponse):
    """Replay a completed ChatCompletionResponse as an OpenAI SSE stream.

    The upstream request is always performed non-streaming; this adapter
    re-emits the finished response as stream chunks: an initial role delta,
    then either per-tool-call deltas or ~50-character content deltas, a
    terminal chunk carrying finish_reason, and the "[DONE]" sentinel.
    """
    async def generate_stream():
        response_id = response.id
        created = response.created
        model = response.model

        # 1) Opening chunk: announce the assistant role.
        initial_chunk = ChatCompletionStreamResponse(
            id=response_id,
            model=model,
            created=created,
            choices=[StreamChoice(
                index=0,
                delta={"role": "assistant"},
                finish_reason=None
            )]
        )
        yield f"data: {initial_chunk.model_dump_json(exclude_none=True)}\n\n"

        # Fix: the original read response.choices[0].finish_reason
        # unconditionally after this guard, raising IndexError when the
        # choices list was empty. Default to "stop" in that case.
        finish_reason = "stop"

        if response.choices and len(response.choices) > 0:
            message = response.choices[0].message
            finish_reason = response.choices[0].finish_reason

            # 2a) Tool calls: one delta chunk per call, each carrying the
            # complete id/type/function payload.
            if message.tool_calls:
                for i, tool_call in enumerate(message.tool_calls):
                    tool_chunk = ChatCompletionStreamResponse(
                        id=response_id,
                        model=model,
                        created=created,
                        choices=[StreamChoice(
                            index=0,
                            delta={
                                "tool_calls": [{
                                    "index": i,
                                    "id": tool_call.id,
                                    "type": tool_call.type,
                                    "function": tool_call.function
                                }]
                            },
                            finish_reason=None
                        )]
                    )
                    yield f"data: {tool_chunk.model_dump_json(exclude_none=True)}\n\n"

            # 2b) Text: slice into small chunks to emulate incremental output.
            elif message.content:
                content = message.content
                chunk_size = 50

                for i in range(0, len(content), chunk_size):
                    chunk_text = content[i:i + chunk_size]
                    content_chunk = ChatCompletionStreamResponse(
                        id=response_id,
                        model=model,
                        created=created,
                        choices=[StreamChoice(
                            index=0,
                            delta={"content": chunk_text},
                            finish_reason=None
                        )]
                    )
                    yield f"data: {content_chunk.model_dump_json(exclude_none=True)}\n\n"
                    # Small pause so clients perceive incremental delivery.
                    await asyncio.sleep(0.01)

        # 3) Terminal chunk with the finish reason, then the SSE sentinel.
        end_chunk = ChatCompletionStreamResponse(
            id=response_id,
            model=model,
            created=created,
            choices=[StreamChoice(
                index=0,
                delta={},
                finish_reason=finish_reason
            )]
        )
        yield f"data: {end_chunk.model_dump_json(exclude_none=True)}\n\n"

        yield "data: [DONE]\n\n"

    return StreamingResponse(
        generate_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Content-Type": "text/event-stream"
        }
    )
|
|
async def create_non_streaming_response(request: ChatCompletionRequest):
    """
    Handles non-streaming chat completion requests.
    It fetches the complete response from CodeWhisperer, parses it using
    CodeWhispererStreamParser, and constructs a single OpenAI-compatible
    ChatCompletionResponse. This version correctly handles tool calls by
    parsing both structured event data and bracket format in text.

    Raises:
        HTTPException: upstream HTTPExceptions are re-raised unchanged;
            any other failure becomes a 500 with an OpenAI-style error body.
    """
    try:
        logger.info("🚀 开始非流式响应生成...")
        response = await call_kiro_api(request)

        # Upstream response metadata, logged for debugging.
        logger.info(f"📤 CodeWhisperer响应状态码: {response.status_code}")
        logger.info(f"📤 响应头: {dict(response.headers)}")
        logger.info(f"📤 原始响应体长度: {len(response.content)} bytes")

        # Best-effort decode of the raw body so bracket-format tool calls
        # ("[Called <name> with args: {...}]") can also be detected in it later.
        raw_response_text = ""
        try:
            raw_response_text = response.content.decode('utf-8', errors='ignore')
            logger.info(f"🔍 原始响应文本长度: {len(raw_response_text)}")
            logger.info(f"🔍 原始响应预览(前1000字符): {raw_response_text[:1000]}")

            # Diagnostics only: record where "[Called" markers appear, if any.
            if "[Called" in raw_response_text:
                logger.info("✅ 原始响应中发现 [Called 标记")
                called_positions = [m.start() for m in re.finditer(r'\[Called', raw_response_text)]
                logger.info(f"🎯 [Called 出现位置: {called_positions}")
            else:
                logger.info("❌ 原始响应中未发现 [Called 标记")

        except Exception as e:
            logger.error(f"❌ 解码原始响应失败: {e}")

        # Parse the binary AWS event stream into a list of event dicts.
        parser = CodeWhispererStreamParser()
        events = parser.parse(response.content)

        full_response_text = ""          # accumulated assistant text
        tool_calls = []                  # completed ToolCall objects
        current_tool_call_dict = None    # tool call currently being assembled

        logger.info(f"🔄 解析到 {len(events)} 个事件,开始处理...")

        # First pass: dump every event for debugging.
        for i, event in enumerate(events):
            logger.info(f"📋 事件 {i}: {event}")

        # Second pass: fold events into text content and structured tool calls.
        for event in events:
            # A structured tool-call event carries "name" + "toolUseId";
            # its JSON arguments may arrive incrementally via "input".
            if "name" in event and "toolUseId" in event:
                logger.info(f"🔧 发现结构化工具调用事件: {event}")

                if not current_tool_call_dict:
                    current_tool_call_dict = {
                        "id": event.get("toolUseId"),
                        "type": "function",
                        "function": {
                            "name": event.get("name"),
                            "arguments": ""
                        }
                    }
                    logger.info(f"🆕 开始解析工具调用: {current_tool_call_dict['function']['name']}")

                # Accumulate the (possibly partial) JSON argument string.
                if "input" in event:
                    current_tool_call_dict["function"]["arguments"] += event.get("input", "")
                    logger.info(f"📝 累积参数: {event.get('input', '')}")

                # "stop" closes the current tool call: validate and store it.
                if event.get("stop"):
                    logger.info(f"✅ 完成工具调用: {current_tool_call_dict['function']['name']}")

                    # Round-trip through json to normalize the argument string;
                    # on failure, keep the raw accumulated string as-is.
                    try:
                        args = json.loads(current_tool_call_dict["function"]["arguments"])
                        current_tool_call_dict["function"]["arguments"] = json.dumps(args, ensure_ascii=False)
                        logger.info(f"✅ 工具调用参数验证成功")
                    except json.JSONDecodeError as e:
                        logger.warning(f"⚠️ 工具调用的参数不是有效的JSON: {current_tool_call_dict['function']['arguments']}")
                        logger.warning(f"⚠️ JSON错误: {e}")

                    tool_calls.append(ToolCall(**current_tool_call_dict))
                    current_tool_call_dict = None

            # Plain assistant text content.
            elif "content" in event:
                content = event.get("content", "")
                full_response_text += content
                logger.info(f"📄 添加文本内容: {content[:100]}...")

        # Stream ended mid tool call: keep the partial call rather than drop it.
        if current_tool_call_dict:
            logger.warning("⚠️ 响应流在工具调用结束前终止,仍尝试添加。")
            tool_calls.append(ToolCall(**current_tool_call_dict))

        logger.info(f"📊 事件处理完成 - 文本长度: {len(full_response_text)}, 结构化工具调用: {len(tool_calls)}")

        # Also extract bracket-format tool calls embedded in the parsed text,
        # then strip those markers out of the user-visible content.
        logger.info("🔍 开始检查解析后文本中的bracket格式工具调用...")
        bracket_tool_calls = parse_bracket_tool_calls(full_response_text)
        if bracket_tool_calls:
            logger.info(f"✅ 在解析后文本中发现 {len(bracket_tool_calls)} 个 bracket 格式工具调用")
            tool_calls.extend(bracket_tool_calls)

            for tc in bracket_tool_calls:
                func_name = tc.function.get("name", "unknown")
                escaped_name = re.escape(func_name)
                # Matches "[Called <name> with args: {...}]", tolerating one
                # level of nested braces inside the argument object.
                pattern = r'\[Called\s+' + escaped_name + r'\s+with\s+args:\s*\{[^}]*(?:\{[^}]*\}[^}]*)*\}\s*\]'
                full_response_text = re.sub(pattern, '', full_response_text, flags=re.DOTALL)

            # Collapse the whitespace gaps left behind by the removals.
            full_response_text = re.sub(r'\s+', ' ', full_response_text).strip()

        # Scan the raw body too; any duplicates are removed below.
        logger.info("🔍 开始检查原始响应中的bracket格式工具调用...")
        raw_bracket_tool_calls = parse_bracket_tool_calls(raw_response_text)
        if raw_bracket_tool_calls and isinstance(raw_bracket_tool_calls, list):
            logger.info(f"✅ 在原始响应中发现 {len(raw_bracket_tool_calls)} 个 bracket 格式工具调用")
            tool_calls.extend(raw_bracket_tool_calls)
        else:
            logger.info("❌ 原始响应中未发现bracket格式工具调用")

        # The same call can surface via events, parsed text AND raw text.
        logger.info(f"🔄 去重前工具调用数量: {len(tool_calls)}")
        unique_tool_calls = deduplicate_tool_calls(tool_calls)
        logger.info(f"🔄 去重后工具调用数量: {len(unique_tool_calls)}")

        # Tool calls take precedence; OpenAI semantics set content=None then.
        if unique_tool_calls:
            logger.info(f"🔧 构建工具调用响应,包含 {len(unique_tool_calls)} 个工具调用")
            for i, tc in enumerate(unique_tool_calls):
                logger.info(f"🔧 工具调用 {i}: {tc.function.get('name', 'unknown')}")

            response_message = ResponseMessage(
                role="assistant",
                content=None,
                tool_calls=unique_tool_calls
            )
            finish_reason = "tool_calls"
        else:
            logger.info("📄 构建普通文本响应")

            # Never return empty content; fall back to a stock acknowledgement.
            content = full_response_text.strip() if full_response_text.strip() else "I understand."
            logger.info(f"📄 最终文本内容: {content[:200]}...")

            response_message = ResponseMessage(
                role="assistant",
                content=content
            )
            finish_reason = "stop"

        choice = Choice(
            index=0,
            message=response_message,
            finish_reason=finish_reason
        )

        # Token counts are local estimates (see estimate_tokens), not
        # provider-reported figures.
        usage = create_usage_stats(
            prompt_text=" ".join([msg.get_content_text() for msg in request.messages]),
            completion_text=full_response_text if not unique_tool_calls else ""
        )

        chat_response = ChatCompletionResponse(
            model=request.model,
            choices=[choice],
            usage=usage
        )

        logger.info(f"📤 最终非流式响应构建完成")
        logger.info(f"📤 响应类型: {'工具调用' if unique_tool_calls else '文本内容'}")
        logger.info(f"📤 完整响应: {chat_response.model_dump_json(indent=2, exclude_none=True)}")
        return chat_response

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"❌ 非流式响应处理出错: {e}")
        import traceback
        traceback.print_exc()
        raise HTTPException(
            status_code=500,
            detail={
                "error": {
                    "message": f"Internal server error: {str(e)}",
                    "type": "internal_server_error",
                    "param": None,
                    "code": "internal_error"
                }
            }
        )
|
|
async def create_streaming_response(request: ChatCompletionRequest):
    """
    Handles streaming chat completion requests.
    This function iteratively processes the binary event stream from CodeWhisperer,
    parsing events on the fly. It maintains state to correctly identify and
    stream text content or tool calls in the OpenAI-compatible format.

    NOTE(review): the /v1/chat/completions endpoint currently streams via
    convert_to_streaming_response instead, so this true-streaming path looks
    unused from the visible code — confirm before relying on or removing it.
    """
    try:
        logger.info("开始流式响应生成...")
        response = await call_kiro_api(request)

        async def generate_stream():
            response_id = f"chatcmpl-{uuid.uuid4()}"
            created = int(time.time())
            parser = CodeWhispererStreamParser()

            # Stream state:
            #   is_in_tool_call           - inside a structured tool call right now
            #   sent_role                 - initial role delta already emitted
            #   current_tool_call_index   - OpenAI tool_calls[].index counter
            #   streamed_tool_calls_count - total calls emitted (drives finish_reason)
            #   content_buffer            - text held back while scanning for "[Called"
            #   incomplete_tool_call      - partial bracket call awaiting more bytes
            is_in_tool_call = False
            sent_role = False
            current_tool_call_index = 0
            streamed_tool_calls_count = 0
            content_buffer = ""
            incomplete_tool_call = ""

            async for chunk in response.aiter_bytes():
                events = parser.parse(chunk)

                for event in events:
                    # Structured tool-call event ("name" + "toolUseId").
                    if "name" in event and "toolUseId" in event:
                        logger.info(f"🎯 STREAM: Found structured tool call event: {event}")

                        if not is_in_tool_call:
                            is_in_tool_call = True

                            # Opening delta announces id/name with empty arguments.
                            delta_start = {
                                "tool_calls": [{
                                    "index": current_tool_call_index,
                                    "id": event.get("toolUseId"),
                                    "type": "function",
                                    "function": {"name": event.get("name"), "arguments": ""}
                                }]
                            }
                            # The first delta of the whole stream must carry the role.
                            if not sent_role:
                                delta_start["role"] = "assistant"
                                sent_role = True

                            start_chunk = ChatCompletionStreamResponse(
                                id=response_id, model=request.model, created=created,
                                choices=[StreamChoice(index=0, delta=delta_start)]
                            )
                            yield f"data: {start_chunk.model_dump_json(exclude_none=True)}\n\n"

                        # Forward argument fragments as-is, unparsed.
                        if "input" in event:
                            arg_chunk_str = event.get("input", "")
                            if arg_chunk_str:
                                arg_chunk_delta = {
                                    "tool_calls": [{
                                        "index": current_tool_call_index,
                                        "function": {"arguments": arg_chunk_str}
                                    }]
                                }
                                arg_chunk_resp = ChatCompletionStreamResponse(
                                    id=response_id, model=request.model, created=created,
                                    choices=[StreamChoice(index=0, delta=arg_chunk_delta)]
                                )
                                yield f"data: {arg_chunk_resp.model_dump_json(exclude_none=True)}\n\n"

                        # "stop" closes this tool call.
                        if event.get("stop"):
                            is_in_tool_call = False
                            current_tool_call_index += 1
                            streamed_tool_calls_count += 1

                    # Plain text (ignored while a structured tool call is open).
                    elif "content" in event and not is_in_tool_call:
                        content_text = event.get("content", "")
                        if content_text:
                            content_buffer += content_text
                            logger.info(f"📝 STREAM DEBUG: Buffer updated. Length: {len(content_buffer)}. Content: >>>{content_buffer}<<<")
                            logger.info(f"📝 STREAM DEBUG: incomplete_tool_call: >>>{incomplete_tool_call}<<<")

                        # Scan the buffer for bracket-format tool calls, emitting
                        # complete text/tool-call segments as they are found.
                        while True:
                            called_start = content_buffer.find("[Called")
                            logger.info(f"🔍 BRACKET DEBUG: Searching for [Called in buffer (length={len(content_buffer)})")
                            logger.info(f"🔍 BRACKET DEBUG: called_start={called_start}")
                            logger.info(f"🔍 BRACKET DEBUG: Full buffer content: >>>{content_buffer}<<<")

                            if called_start == -1:
                                # No marker: flush the buffer as ordinary content.
                                logger.info(f"🔍 BRACKET DEBUG: No [Called found, sending buffer as content")
                                logger.info(f"🔍 BRACKET DEBUG: incomplete_tool_call status: {bool(incomplete_tool_call)}")
                                if content_buffer and not incomplete_tool_call:
                                    delta_content = {"content": content_buffer}
                                    if not sent_role:
                                        delta_content["role"] = "assistant"
                                        sent_role = True

                                    logger.info(f"📤 STREAM: Sending content chunk: {delta_content}")
                                    content_chunk = ChatCompletionStreamResponse(
                                        id=response_id, model=request.model, created=created,
                                        choices=[StreamChoice(index=0, delta=delta_content)]
                                    )
                                    yield f"data: {content_chunk.model_dump_json(exclude_none=True)}\n\n"
                                    content_buffer = ""
                                break

                            logger.info(f"🔍 BRACKET DEBUG: Found [Called at position {called_start}")

                            # Emit any text preceding the marker.
                            if called_start > 0:
                                text_before = content_buffer[:called_start]
                                logger.info(f"🔍 BRACKET DEBUG: Text before [Called: >>>{text_before}<<<")
                                if text_before.strip():
                                    delta_content = {"content": text_before}
                                    if not sent_role:
                                        delta_content["role"] = "assistant"
                                        sent_role = True

                                    content_chunk = ChatCompletionStreamResponse(
                                        id=response_id, model=request.model, created=created,
                                        choices=[StreamChoice(index=0, delta=delta_content)]
                                    )
                                    yield f"data: {content_chunk.model_dump_json(exclude_none=True)}\n\n"

                            # Locate the bracket that closes this call.
                            remaining_text = content_buffer[called_start:]
                            logger.info(f"🔍 BRACKET DEBUG: Looking for matching ] in: >>>{remaining_text[:100]}...<<<")
                            bracket_end = find_matching_bracket(remaining_text, 0)
                            logger.info(f"🔍 BRACKET DEBUG: bracket_end={bracket_end}")

                            if bracket_end == -1:
                                # Call continues in a later chunk; stash it and wait.
                                logger.info(f"🔍 BRACKET DEBUG: Tool call incomplete, saving to incomplete_tool_call")
                                logger.info(f"🔍 BRACKET DEBUG: Incomplete content: >>>{remaining_text}<<<")
                                incomplete_tool_call = remaining_text
                                content_buffer = ""
                                break

                            # Complete bracket call: parse and emit it as one delta.
                            tool_call_text = remaining_text[:bracket_end + 1]
                            logger.info(f"🔍 BRACKET DEBUG: Extracting tool call: >>>{tool_call_text}<<<")
                            parsed_call = parse_single_tool_call(tool_call_text)
                            logger.info(f"🔍 BRACKET DEBUG: Parsed call result: {parsed_call}")

                            if parsed_call:
                                delta_tool = {
                                    "tool_calls": [{
                                        "index": current_tool_call_index,
                                        "id": parsed_call.id,
                                        "type": "function",
                                        "function": {
                                            "name": parsed_call.function["name"],
                                            "arguments": parsed_call.function["arguments"]
                                        }
                                    }]
                                }
                                if not sent_role:
                                    delta_tool["role"] = "assistant"
                                    sent_role = True

                                logger.info(f"📤 STREAM: Sending tool call chunk: {delta_tool}")
                                tool_chunk = ChatCompletionStreamResponse(
                                    id=response_id, model=request.model, created=created,
                                    choices=[StreamChoice(index=0, delta=delta_tool)]
                                )
                                yield f"data: {tool_chunk.model_dump_json(exclude_none=True)}\n\n"
                                current_tool_call_index += 1
                                streamed_tool_calls_count += 1
                            else:
                                logger.error(f"❌ BRACKET DEBUG: Failed to parse tool call")

                            # Keep scanning whatever follows the closed call.
                            content_buffer = remaining_text[bracket_end + 1:]
                            incomplete_tool_call = ""
                            logger.info(f"🔍 BRACKET DEBUG: Updated buffer after tool call: >>>{content_buffer}<<<")

            # Stream exhausted: resolve any leftover partial call / text.
            logger.info(f"📊 STREAM END: Processing remaining content")
            logger.info(f"📊 STREAM END: incomplete_tool_call: >>>{incomplete_tool_call}<<<")
            logger.info(f"📊 STREAM END: content_buffer: >>>{content_buffer}<<<")

            if incomplete_tool_call:
                # Retry parsing now that all bytes have arrived.
                logger.info(f"🔄 STREAM END: Attempting to parse incomplete tool call")
                content_buffer = incomplete_tool_call + content_buffer
                incomplete_tool_call = ""

                called_start = content_buffer.find("[Called")
                if called_start == 0:
                    bracket_end = find_matching_bracket(content_buffer, 0)
                    logger.info(f"🔄 STREAM END: bracket_end for incomplete={bracket_end}")
                    if bracket_end != -1:
                        tool_call_text = content_buffer[:bracket_end + 1]
                        parsed_call = parse_single_tool_call(tool_call_text)

                        if parsed_call:
                            delta_tool = {
                                "tool_calls": [{
                                    "index": current_tool_call_index,
                                    "id": parsed_call.id,
                                    "type": "function",
                                    "function": {
                                        "name": parsed_call.function["name"],
                                        "arguments": parsed_call.function["arguments"]
                                    }
                                }]
                            }
                            if not sent_role:
                                delta_tool["role"] = "assistant"
                                sent_role = True

                            logger.info(f"📤 STREAM END: Sending final tool call: {delta_tool}")
                            tool_chunk = ChatCompletionStreamResponse(
                                id=response_id, model=request.model, created=created,
                                choices=[StreamChoice(index=0, delta=delta_tool)]
                            )
                            yield f"data: {tool_chunk.model_dump_json(exclude_none=True)}\n\n"
                            current_tool_call_index += 1
                            streamed_tool_calls_count += 1

                        content_buffer = content_buffer[bracket_end + 1:]

            # Flush any trailing plain text.
            if content_buffer.strip():
                logger.info(f"📤 STREAM END: Sending remaining content: >>>{content_buffer}<<<")
                delta_content = {"content": content_buffer}
                if not sent_role:
                    delta_content["role"] = "assistant"
                    sent_role = True

                content_chunk = ChatCompletionStreamResponse(
                    id=response_id, model=request.model, created=created,
                    choices=[StreamChoice(index=0, delta=delta_content)]
                )
                yield f"data: {content_chunk.model_dump_json(exclude_none=True)}\n\n"

            # Terminal chunk with the finish reason, then the SSE sentinel.
            finish_reason = "tool_calls" if streamed_tool_calls_count > 0 else "stop"
            logger.info(f"🏁 STREAM FINISH: streamed_tool_calls_count={streamed_tool_calls_count}, finish_reason={finish_reason}")
            end_chunk = ChatCompletionStreamResponse(
                id=response_id, model=request.model, created=created,
                choices=[StreamChoice(index=0, delta={}, finish_reason=finish_reason)]
            )
            yield f"data: {end_chunk.model_dump_json(exclude_none=True)}\n\n"

            yield "data: [DONE]\n\n"

        return StreamingResponse(
            generate_stream(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "Content-Type": "text/event-stream"
            }
        )

    except Exception as e:
        logger.error(f"❌ 流式响应生成失败: {str(e)}")
        import traceback
        traceback.print_exc()
        raise HTTPException(
            status_code=500,
            detail={
                "error": {
                    "message": f"Stream generation failed: {str(e)}",
                    "type": "internal_server_error",
                    "param": None,
                    "code": "stream_error"
                }
            }
        )
|
|
@app.get("/health")
async def health_check():
    """Liveness probe: report service name, status and version."""
    return {
        "status": "healthy",
        "service": "Ki2API",
        "version": "3.0.1",
    }
|
|
@app.get("/")
async def root():
    """Landing endpoint: service metadata, available routes and feature flags."""
    endpoints = {
        "models": "/v1/models",
        "chat": "/v1/chat/completions",
        "health": "/health",
    }
    features = {
        "streaming": True,
        "tools": True,
        "multiple_models": True,
        "xml_tool_parsing": True,
        "auto_token_refresh": True,
        "null_content_handling": True,
        "tool_call_deduplication": True,
    }
    return {
        "service": "Ki2API",
        "description": "OpenAI-compatible API for Claude Sonnet 4 via AWS CodeWhisperer",
        "version": "3.0.1",
        "endpoints": endpoints,
        "features": features,
    }
|
|
if __name__ == "__main__":
    # Local/container entry point; honor PORT for platform deploys (e.g. HF Spaces).
    import uvicorn

    # Fix: removed the redundant `import os` — os is already imported at module level.
    port = int(os.getenv("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)