|
|
import binascii |
|
|
import json |
|
|
from typing import List, Optional, AsyncGenerator |
|
|
from time import time |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from api.chat import ChatConfig |
|
|
from api.telemetry import TelemetryAPI |
|
|
from auth.jwt_handler import JWTHandler |
|
|
from config.constants import ENCRYPTION_KEY |
|
|
from utils.encrypt import encrypt |
|
|
from utils.http import HTTPClient |
|
|
from config.constants import ( |
|
|
APP_LANGUAGE, |
|
|
APP_NAME, |
|
|
APP_VERSION, |
|
|
DISPLAY_NAME, |
|
|
HADWARE_INFO, |
|
|
INFERENCE_URL, |
|
|
SYSTEM_INFO, |
|
|
) |
|
|
from utils.compression import decompress_chunks |
|
|
|
|
|
from protos import request_pb2, response_pb2 |
|
|
|
|
|
|
|
|
class ChatAPI:
    """Client for the chat inference endpoint.

    Builds protobuf chat requests, keeps a short-lived JWT fresh (renewing
    it lazily before each request), and yields `(message, count)` tuples
    decoded from streamed or buffered, optionally gzip-compressed,
    Connect-protocol responses.
    """

    # Seconds a JWT is considered fresh before renewal is attempted.
    _JWT_TTL_SECONDS = 2500

    def __init__(self, api_key: str, http_client: Optional[HTTPClient] = None):
        """
        Args:
            api_key: API key used for authentication and JWT renewal.
            http_client: HTTP client to use. When omitted, a fresh client is
                created per instance. (The previous ``http_client=HTTPClient()``
                default was evaluated once at import time and silently shared
                across every instance — the mutable-default pitfall.)
        """
        self.api_key = api_key
        self.jwt_token: Optional[str] = None
        # Unix timestamp (seconds) of the last successful token renewal.
        self.jwt_token_timestamp: float = 0
        self.http_client = http_client if http_client is not None else HTTPClient()

    async def renew_jwt_token(self):
        """Renew JWT token asynchronously if it's expired or missing.

        Also fires a telemetry call on every renewal attempt, mirroring the
        original client behaviour. A renewal that yields no token leaves the
        previous token (possibly ``None``) in place.
        """
        current_time = time()

        # Token still fresh — nothing to do.
        if self.jwt_token and current_time - self.jwt_token_timestamp < self._JWT_TTL_SECONDS:
            return

        jwt_handler = JWTHandler(api_key=self.api_key, http_client=self.http_client)
        jwt_token = await jwt_handler.get_jwt_token()

        tele = TelemetryAPI(api_key=self.api_key)
        await tele.do_telemetry()

        if jwt_token:
            self.jwt_token = jwt_token
            self.jwt_token_timestamp = current_time

    async def _create_chat_request(
        self,
        messages: List[dict],
        config: ChatConfig,
        system_prompt: str = "You are a helpful assistant.",
    ) -> request_pb2.ChatRequestMessage:
        """Build the protobuf request message for a chat call.

        JWT renewal is best-effort: a failure is printed and the request is
        still built with whatever token state we have.
        """
        try:
            await self.renew_jwt_token()
        except Exception as e:
            # Best-effort renewal — log and continue with the current token.
            print(e)

        msg = request_pb2.ChatRequestMessage()

        self._set_client_info(msg)

        msg.system_prompt = system_prompt
        msg.model_id = config.model_id.value
        # Opaque wire-protocol fields observed from the reference client;
        # semantics unknown — TODO confirm against the proto definition.
        msg.idk13.idk13nn = 1
        msg.idk_id = 5
        self._set_model_config(msg, config)

        self._add_messages(msg, messages)

        return msg

    def _set_client_info(self, msg: request_pb2.ChatRequestMessage) -> None:
        """Set client information in the request message.

        NOTE(review): ``user_jwt`` is assigned from ``self.jwt_token``, which
        may still be ``None`` if renewal failed — confirm the proto field
        tolerates that.
        """
        msg.client_info.api_key = self.api_key
        msg.client_info.user_jwt = self.jwt_token
        msg.client_info.locale = APP_LANGUAGE
        msg.client_info.extension_name = APP_NAME
        msg.client_info.ide_name = APP_NAME
        msg.client_info.extension_version = APP_VERSION
        msg.client_info.os = SYSTEM_INFO
        msg.client_info.ide_version = DISPLAY_NAME
        msg.client_info.hardware = HADWARE_INFO

    def _set_model_config(
        self, msg: request_pb2.ChatRequestMessage, config: ChatConfig
    ) -> None:
        """Copy sampling parameters from *config* into the request."""
        msg.model_config.parallel_stream = 1
        msg.model_config.max_tokens = config.max_tokens
        msg.model_config.temperature = config.temperature
        msg.model_config.top_k = config.top_k
        # Field is capitalised ``top_P`` in the generated proto — presumably
        # intentional; verify against request.proto.
        msg.model_config.top_P = config.top_p

    def _set_special_tokens(self, msg: request_pb2.ChatRequestMessage) -> None:
        """Append the model's special/control tokens to the model config."""
        msg.model_config.special_tokens.extend(
            [
                "<|user|>",
                "<|bot|>",
                "<|context_request|>",
                "<|endoftext|>",
                "<|end_of_turn|>",
            ]
        )

    def _set_tool_config(self, msg: request_pb2.ChatRequestMessage) -> None:
        """Set tool configuration (a deliberately uncallable placeholder tool)."""
        msg.tool_use.mode = "auto"
        msg.tool_config.tool_name = "do_not_call"
        msg.tool_config.description = "Do not call this tool."
        msg.tool_config.schema = '{"$schema":"https://json-schema.org/draft/2020-12/schema","properties":{},"additionalProperties":false,"type":"object"}'

    def _add_messages(
        self, msg: request_pb2.ChatRequestMessage, messages: List[dict]
    ) -> None:
        """Add chat messages to the request.

        System-role messages are not appended to the chat history; instead the
        request's ``system_prompt`` is overwritten (last system message wins).
        Unknown roles default to "user".
        """
        role_map = {"user": 1, "assistant": 2, "system": 3}

        for chat_msg in messages:
            role = role_map.get(chat_msg["role"], 1)
            content = chat_msg["content"]

            if role == 3:
                if isinstance(content, str):
                    msg.system_prompt = content
                elif isinstance(content, list):
                    # Multipart system message: use its first text part.
                    for item in content:
                        if item.get("type", "") == "text" and "text" in item and isinstance(item["text"], str):
                            msg.system_prompt = item["text"]
                            break
                continue

            if isinstance(content, list):
                pb_msg = self._create_multipart_message(role, content)
            else:
                pb_msg = self._create_text_message(role, content)

            msg.chat_messages.append(pb_msg)

    def _create_multipart_message(
        self, role: int, content: List[dict]
    ) -> request_pb2.ChatMessage:
        """Create a message with multiple parts (text and images).

        Only base64 data-URL images are accepted; other image URLs are
        silently skipped. Text parts are joined with a single space.
        """
        text_parts = []
        image_parts = []

        for item in content:
            if item["type"] == "text":
                text_parts.append(item["text"])
            elif item["type"] == "image_url":
                image_url = item["image_url"]["url"]
                if image_url.startswith("data:image/") and "base64," in image_url:
                    prefix, image_data = image_url.split("base64,", 1)
                    # "data:image/png;base64," -> "image/png"
                    mime_type = prefix.split("data:")[1].split(";")[0]
                    image_parts.append(
                        request_pb2.ImagePart(
                            image_data=image_data, image_mime_type=mime_type
                        )
                    )

        return self._create_message(role, " ".join(text_parts), image_parts)

    def _create_text_message(self, role: int, content: str) -> request_pb2.ChatMessage:
        """Create a simple text message"""
        return self._create_message(role, content)

    def _create_message(
        self,
        role: int,
        content: str,
        image_parts: Optional[List[request_pb2.ImagePart]] = None,
    ) -> request_pb2.ChatMessage:
        """Create a chat message with common attributes"""
        pb_msg = request_pb2.ChatMessage(
            role=role, content=content
        )
        # Opaque flag set only on user (role 1) messages — semantics unknown.
        if role == 1:
            pb_msg.idk2 = 1

        if image_parts:
            pb_msg.image_parts.extend(image_parts)
        return pb_msg

    async def _process_chat_response(self, msg_type: int, data: bytes) -> tuple[str, int]:
        """Process a single chat response chunk and return (message, count).

        Type-3 chunks are JSON trailers/errors and are returned encrypted;
        anything else is parsed as a ``ChatResponse`` protobuf. Unparseable
        protobuf chunks yield ``("", 0)``.
        """
        if msg_type == 3:
            # JSON payload: encrypt its string form before surfacing it.
            response = json.loads(data)
            return (encrypt(str(response), ENCRYPTION_KEY), 0) if response else ("", 0)

        try:
            search_response = response_pb2.ChatResponse()
            search_response.ParseFromString(data)
            return (search_response.message, search_response.count) if search_response.message else ("", 0)
        except Exception:
            # Malformed protobuf chunk — treat as empty rather than aborting
            # the whole stream. (Was a bare ``except:``.)
            return ("", 0)

    async def _handle_stream_response(self, chunk_iterator) -> AsyncGenerator[tuple[str, int], None]:
        """Handle streaming response chunks"""
        async for chunk in chunk_iterator:
            for msg_type, data in decompress_chunks(chunk):
                result = await self._process_chat_response(msg_type, data)
                yield result

    async def _handle_response(self, chunk) -> AsyncGenerator[tuple[str, int], None]:
        """Handle non-streaming response chunks"""
        for msg_type, data in decompress_chunks(chunk):
            result = await self._process_chat_response(msg_type, data)
            yield result

    async def send_message(
        self,
        messages: List[dict],
        config: Optional[ChatConfig] = None,
        system_prompt: str = "You are a helpful assistant.",
        stream: bool = False,
    ) -> AsyncGenerator[tuple[str, int], None]:
        """Send chat messages and yield response chunks.

        Args:
            messages: OpenAI-style message dicts (``role``/``content``).
            config: Sampling configuration; defaults to ``ChatConfig()``.
            system_prompt: Default system prompt, overridden by any
                system-role message in *messages*.
            stream: When true, yield chunks as they arrive; otherwise buffer
                the full response first.

        Raises:
            Exception: If the non-streaming request returns a non-200 status.
        """
        if config is None:
            config = ChatConfig()

        request = await self._create_chat_request(messages, config, system_prompt)

        headers = {
            "User-Agent": "connect-go/1.16.2 (go1.23.2 X:nocoverageredesign)",
            "Connect-Accept-Encoding": "gzip",
            "Connect-Content-Encoding": "gzip",
            "Connect-Protocol-Version": "1",
            "Content-Type": "application/connect+proto",
        }

        url = f"{INFERENCE_URL}/exa.api_server_pb.ApiServerService/GetChatMessage"
        request_data = request.SerializeToString()

        if stream:
            stream_iterator = self.http_client.stream_post(
                url=url,
                data=request_data,
                headers=headers,
                compress=True,
            )
            async for result in self._handle_stream_response(stream_iterator):
                yield result
        else:
            response = await self.http_client.post(
                url=url,
                data=request_data,
                headers=headers,
                compress=True,
            )

            if response.status_code != 200:
                raise Exception(f"Chat request failed: {response.status_code}")

            # Compressed bodies are framed chunks; plain bodies are a single
            # raw ChatResponse protobuf.
            if response.headers.get("connect-content-encoding") == "gzip":
                async for result in self._handle_response(response.content):
                    yield result
            else:
                search_response = response_pb2.ChatResponse()
                search_response.ParseFromString(response.content)
                if search_response.message:
                    yield (search_response.message, search_response.count)
|
|
|