# Source path: test/api/chat/chat_api.py
# Repository page residue kept for provenance:
#   "gaoqilan's picture" / "Upload 103 files" / commit 1f1b4db (verified)
import binascii
import json
from typing import List, Optional, AsyncGenerator
from time import time
# Remove aioredis import
# from aioredis import Redis
# import tiktoken
from api.chat import ChatConfig
from api.telemetry import TelemetryAPI
from auth.jwt_handler import JWTHandler
from config.constants import ENCRYPTION_KEY
from utils.encrypt import encrypt
from utils.http import HTTPClient
from config.constants import (
APP_LANGUAGE,
APP_NAME,
APP_VERSION,
DISPLAY_NAME,
HADWARE_INFO,
INFERENCE_URL,
SYSTEM_INFO,
)
from utils.compression import decompress_chunks
from protos import request_pb2, response_pb2
class ChatAPI:
    """Client that builds protobuf chat requests and reads the replies.

    A request is assembled from OpenAI-style message dicts (``role`` /
    ``content``), serialized, and POSTed to the ``GetChatMessage`` endpoint
    under ``INFERENCE_URL``.  Replies are yielded as ``(message, count)``
    tuples, either chunk by chunk (streaming) or from a single response.
    """

    # Seconds a cached JWT is reused before a new one is requested.
    _JWT_TTL_SECONDS = 2500

    def __init__(self, api_key: str, http_client: Optional[HTTPClient] = None):
        """Store the credentials and the HTTP client used for all calls.

        Args:
            api_key: Key placed in the request's client info and handed to
                the JWT and telemetry helpers.
            http_client: Client used for POSTs.  When omitted, a new
                ``HTTPClient`` is created for this instance (a call in the
                parameter default would be evaluated once at definition time
                and shared by every instance).
        """
        self.api_key = api_key
        self.jwt_token = None
        self.jwt_token_timestamp = 0
        self.http_client = http_client if http_client is not None else HTTPClient()

    async def renew_jwt_token(self):
        """Fetch a new JWT unless the cached one is younger than the TTL.

        After requesting the token, a telemetry call is made; the token and
        its timestamp are stored only if the handler returned a truthy token.
        """
        current_time = time()
        token_age = current_time - self.jwt_token_timestamp
        if self.jwt_token and token_age < self._JWT_TTL_SECONDS:
            return
        jwt_handler = JWTHandler(api_key=self.api_key, http_client=self.http_client)
        jwt_token = await jwt_handler.get_jwt_token()
        tele = TelemetryAPI(api_key=self.api_key)
        await tele.do_telemetry()
        if jwt_token:
            self.jwt_token = jwt_token
            # The timestamp is taken before the request, so the cached token
            # is treated as slightly older than it really is.
            self.jwt_token_timestamp = current_time

    async def _create_chat_request(
        self,
        messages: List[dict],
        config: ChatConfig,
        system_prompt: str = "You are a helpful assistant.",
    ) -> request_pb2.ChatRequestMessage:
        """Build the protobuf request for one chat call.

        Args:
            messages: Chat messages as dicts with ``role`` and ``content``.
            config: Model selection and sampling settings.
            system_prompt: Default system prompt; a ``system`` message in
                ``messages`` overrides it (see ``_add_messages``).

        Returns:
            The populated ``ChatRequestMessage``.
        """
        try:
            await self.renew_jwt_token()
        except Exception as e:
            # Best effort: a failed renewal is reported and the request is
            # built with whatever token is already cached.
            # NOTE(review): if no token was ever obtained, ``self.jwt_token``
            # is still None when ``_set_client_info`` assigns it - confirm
            # the protobuf field tolerates that.
            print(e)
        msg = request_pb2.ChatRequestMessage()
        # Set client info
        self._set_client_info(msg)
        # Set system prompt and model config
        msg.system_prompt = system_prompt
        msg.model_id = config.model_id.value
        # Opaque protocol fields; their meaning is not established in this file.
        msg.idk13.idk13nn = 1
        msg.idk_id = 5
        self._set_model_config(msg, config)
        # Tool config is intentionally not applied:
        # self._set_tool_config(msg)
        # Convert messages
        self._add_messages(msg, messages)
        return msg

    def _set_client_info(self, msg: request_pb2.ChatRequestMessage) -> None:
        """Copy credentials and application constants into ``msg.client_info``."""
        info = msg.client_info
        info.api_key = self.api_key
        info.user_jwt = self.jwt_token
        info.locale = APP_LANGUAGE
        info.extension_name = APP_NAME
        info.ide_name = APP_NAME
        info.extension_version = APP_VERSION
        info.os = SYSTEM_INFO
        info.ide_version = DISPLAY_NAME
        info.hardware = HADWARE_INFO

    def _set_model_config(
        self, msg: request_pb2.ChatRequestMessage, config: ChatConfig
    ) -> None:
        """Copy sampling settings from ``config`` into ``msg.model_config``."""
        model_config = msg.model_config
        model_config.parallel_stream = 1
        model_config.max_tokens = config.max_tokens
        model_config.temperature = config.temperature
        model_config.top_k = config.top_k
        model_config.top_P = config.top_p

    def _set_special_tokens(self, msg: request_pb2.ChatRequestMessage) -> None:
        """Append the fixed list of special tokens to ``msg.model_config``."""
        msg.model_config.special_tokens.extend(
            [
                "<|user|>",
                "<|bot|>",
                "<|context_request|>",
                "<|endoftext|>",
                "<|end_of_turn|>",
            ]
        )

    def _set_tool_config(self, msg: request_pb2.ChatRequestMessage) -> None:
        """Set tool configuration to a single placeholder tool."""
        msg.tool_use.mode = "auto"
        msg.tool_config.tool_name = "do_not_call"
        msg.tool_config.description = "Do not call this tool."
        msg.tool_config.schema = '{"$schema":"https://json-schema.org/draft/2020-12/schema","properties":{},"additionalProperties":false,"type":"object"}'

    def _add_messages(
        self, msg: request_pb2.ChatRequestMessage, messages: List[dict]
    ) -> None:
        """Append chat messages to the request.

        Roles map to ``user`` -> 1, ``assistant`` -> 2, ``system`` -> 3;
        unknown roles are treated as user.  A system message is not appended:
        it replaces ``msg.system_prompt`` instead (for list content, the
        first string ``text`` part is used).
        """
        role_map = {"user": 1, "assistant": 2, "system": 3}
        for chat_msg in messages:
            role = role_map.get(chat_msg["role"], 1)
            content = chat_msg["content"]
            # Override system prompt
            if role == 3:
                if isinstance(content, str):
                    msg.system_prompt = content
                elif isinstance(content, list):
                    for item in content:
                        if (
                            item.get("type", "") == "text"
                            and "text" in item
                            and isinstance(item["text"], str)
                        ):
                            msg.system_prompt = item["text"]
                            break
                continue
            if isinstance(content, list):
                pb_msg = self._create_multipart_message(role, content)
            else:
                pb_msg = self._create_text_message(role, content)
            msg.chat_messages.append(pb_msg)

    def _create_multipart_message(
        self, role: int, content: List[dict]
    ) -> request_pb2.ChatMessage:
        """Create a message from a list of ``text`` and ``image_url`` parts.

        Text parts are joined with single spaces.  Only ``data:image/...``
        URLs containing ``base64,`` become image parts; any other image URL
        or part type is ignored.
        """
        text_parts = []
        image_parts = []
        for item in content:
            if item["type"] == "text":
                text_parts.append(item["text"])
            elif item["type"] == "image_url":
                image_url = item["image_url"]["url"]
                if image_url.startswith("data:image/") and "base64," in image_url:
                    # "data:image/png;base64,AAAA" -> prefix "data:image/png;",
                    # payload "AAAA", mime type "image/png".
                    prefix, image_data = image_url.split("base64,", 1)
                    mime_type = prefix.split("data:")[1].split(";")[0]
                    image_parts.append(
                        request_pb2.ImagePart(
                            image_data=image_data, image_mime_type=mime_type
                        )
                    )
        return self._create_message(role, " ".join(text_parts), image_parts)

    def _create_text_message(self, role: int, content: str) -> request_pb2.ChatMessage:
        """Create a simple text message"""
        return self._create_message(role, content)

    def _create_message(
        self,
        role: int,
        content: str,
        image_parts: Optional[List[request_pb2.ImagePart]] = None,
    ) -> request_pb2.ChatMessage:
        """Create a ``ChatMessage`` with the given role, text and images.

        Messages with role 1 (user) additionally get ``idk2 = 1``, an opaque
        protocol field whose meaning is not established in this file.
        """
        pb_msg = request_pb2.ChatMessage(role=role, content=content)
        if role == 1:
            pb_msg.idk2 = 1
        # pb_msg.cache_control.prompt_caching = 1
        if image_parts:
            pb_msg.image_parts.extend(image_parts)
        return pb_msg

    async def _process_chat_response(self, type: int, data: bytes) -> tuple[str, int]:
        """Decode one response frame into ``(message, count)``.

        Args:
            type: Frame type as produced by ``decompress_chunks``.  Type 3
                (marked "end of message" by the original author) carries JSON.
            data: Raw frame payload.

        Returns:
            For type 3: the encrypted ``str()`` of the parsed JSON with count
            0, or ``("", 0)`` when the JSON value is falsy.  For other types:
            the parsed ``ChatResponse`` message and count, or ``("", 0)``
            when the message is empty or the payload cannot be parsed.

        Raises:
            Exception: Whatever ``json.loads`` or ``encrypt`` raise for a
                type 3 frame; these are deliberately not suppressed.
        """
        if type == 3:  # end of message
            response = json.loads(data)
            return (encrypt(str(response), ENCRYPTION_KEY), 0) if response else ("", 0)
        search_response = response_pb2.ChatResponse()
        try:
            search_response.ParseFromString(data)
        except Exception:
            # Unparseable frames are skipped; ``Exception`` (not a bare
            # ``except``) keeps cancellation and interrupts propagating.
            return ("", 0)
        if search_response.message:
            return (search_response.message, search_response.count)
        return ("", 0)

    async def _handle_stream_response(self, chunk_iterator) -> AsyncGenerator[tuple[str, int], None]:
        """Yield a decoded result for every frame of every streamed chunk."""
        async for chunk in chunk_iterator:
            for frame_type, frame_data in decompress_chunks(chunk):
                yield await self._process_chat_response(frame_type, frame_data)

    async def _handle_response(self, chunk) -> AsyncGenerator[tuple[str, int], None]:
        """Yield a decoded result for every frame of a single response body."""
        for frame_type, frame_data in decompress_chunks(chunk):
            yield await self._process_chat_response(frame_type, frame_data)

    async def send_message(
        self,
        messages: List[dict],
        config: Optional[ChatConfig] = None,
        system_prompt: str = "You are a helpful assistant.",
        stream: bool = False,
    ) -> AsyncGenerator[tuple[str, int], None]:
        """Send chat messages and yield ``(message, count)`` results.

        Args:
            messages: Chat messages as dicts with ``role`` and ``content``.
            config: Model settings; a default ``ChatConfig()`` is used when
                omitted.
            system_prompt: Default system prompt for the request.
            stream: When true, use ``http_client.stream_post`` and yield
                results as chunks arrive; otherwise issue one POST and yield
                from the complete response.

        Raises:
            Exception: In non-streaming mode, when the HTTP status is not 200.
        """
        if config is None:
            config = ChatConfig()
        request = await self._create_chat_request(messages, config, system_prompt)
        headers = {
            "User-Agent": "connect-go/1.16.2 (go1.23.2 X:nocoverageredesign)",
            "Connect-Accept-Encoding": "gzip",
            "Connect-Content-Encoding": "gzip",
            "Connect-Protocol-Version": "1",
            "Content-Type": "application/connect+proto",
        }
        url = f"{INFERENCE_URL}/exa.api_server_pb.ApiServerService/GetChatMessage"
        request_data = request.SerializeToString()
        if stream:
            stream_iterator = self.http_client.stream_post(
                url=url,
                data=request_data,
                headers=headers,
                compress=True,
            )
            async for result in self._handle_stream_response(stream_iterator):
                yield result
        else:
            response = await self.http_client.post(
                url=url,
                data=request_data,
                headers=headers,
                compress=True,
            )
            if response.status_code != 200:
                raise Exception(f"Chat request failed: {response.status_code}")
            if response.headers.get("connect-content-encoding") == "gzip":
                # Framed, compressed body: decode it frame by frame.
                async for result in self._handle_response(response.content):
                    yield result
            else:
                # Unframed body: the content is a single ChatResponse.
                search_response = response_pb2.ChatResponse()
                search_response.ParseFromString(response.content)
                if search_response.message:
                    yield (search_response.message, search_response.count)