from datetime import datetime
import json
import logging
import os
import time
import traceback
from typing import List, Union, AsyncGenerator
from uuid import uuid4
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from sse_starlette.sse import EventSourceResponse, AsyncContentStream
from openai import AsyncClient, APIStatusError, APIResponseValidationError, APIError, OpenAIError
from openai.types.chat import ChatCompletion
import tiktoken
from .proxy import ProxyBase, RequestFilterBase, ResponseFilterBase, RequestFilterException, ResponseFilterException
from .accesslog import _AccessLogBase, RequestItemBase, ResponseItemBase, StreamChunkItemBase, ErrorItemBase
from .queueclient import QueueClientBase

logger = logging.getLogger(__name__)


class ChatGPTRequestItem(RequestItemBase):
    def to_accesslog(self, accesslog_cls: _AccessLogBase) -> _AccessLogBase:
        # Mask the API key in the authorization header before logging
        request_headers_copy = self.request_headers.copy()
        if auth := request_headers_copy.get("authorization"):
            request_headers_copy["authorization"] = auth[:12] + "*****" + auth[-2:]

        # Extract the text part of the latest message (multimodal content arrives as a list)
        content = self.request_json["messages"][-1]["content"]
        if isinstance(content, list):
            for c in content:
                if c["type"] == "text":
                    content = c["text"]
                    break
            else:
                # No text part found; log the whole list as JSON
                content = json.dumps(content)

        accesslog = accesslog_cls(
            request_id=self.request_id,
            created_at=datetime.utcnow(),
            direction="request",
            content=content,
            raw_body=json.dumps(self.request_json, ensure_ascii=False),
            raw_headers=json.dumps(request_headers_copy, ensure_ascii=False),
            model=self.request_json.get("model")
        )

        return accesslog
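
# Illustration of the masking above (values invented for the example): an
# incoming "authorization" header of "Bearer sk-abcdefghijklmnop" is logged as
# "Bearer sk-ab*****op" -- the first 12 and last 2 characters are kept, the
# rest is redacted, so access logs never contain a full API key.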


class ChatGPTResponseItem(ResponseItemBase):
    def to_accesslog(self, accesslog_cls: _AccessLogBase) -> _AccessLogBase:
        message = self.response_json["choices"][0]["message"]
        content = message.get("content")
        function_call = message.get("function_call")
        tool_calls = message.get("tool_calls")
        response_headers = json.dumps(
            dict(self.response_headers.items()), ensure_ascii=False
        ) if self.response_headers is not None else None

        # Non-streaming responses report usage directly
        model = self.response_json["model"]
        prompt_tokens = self.response_json["usage"]["prompt_tokens"]
        completion_tokens = self.response_json["usage"]["completion_tokens"]

        return accesslog_cls(
            request_id=self.request_id,
            created_at=datetime.utcnow(),
            direction="response",
            status_code=self.status_code,
            content=content,
            function_call=json.dumps(function_call, ensure_ascii=False) if function_call is not None else None,
            tool_calls=json.dumps(tool_calls, ensure_ascii=False) if tool_calls is not None else None,
            raw_body=json.dumps(self.response_json, ensure_ascii=False),
            raw_headers=response_headers,
            model=model,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            request_time=self.duration,
            request_time_api=self.duration_api
        )


token_encoder = tiktoken.get_encoding("cl100k_base")


def count_token(content: str) -> int:
    return len(token_encoder.encode(content))


def count_request_token(request_json: dict) -> int:
    # Follows OpenAI's published estimation recipe for cl100k_base chat models:
    # a fixed per-message overhead plus the tokens of each value
    tokens_per_message = 3
    tokens_per_name = 1
    token_count = 0

    # messages
    for m in request_json["messages"]:
        token_count += tokens_per_message
        for k, v in m.items():
            if isinstance(v, list):
                # Multimodal content: count only the text parts
                for c in v:
                    if c.get("type") == "text":
                        token_count += count_token(c["text"])
            else:
                token_count += count_token(v)
            if k == "name":
                token_count += tokens_per_name

    # functions
    if functions := request_json.get("functions"):
        for f in functions:
            token_count += count_token(json.dumps(f))

    # function_call
    if function_call := request_json.get("function_call"):
        if isinstance(function_call, dict):
            token_count += count_token(json.dumps(function_call))
        else:
            token_count += count_token(function_call)

    # tools
    if tools := request_json.get("tools"):
        for t in tools:
            token_count += count_token(json.dumps(t))
    if tool_choice := request_json.get("tool_choice"):
        token_count += count_token(json.dumps(tool_choice))

    # Every reply is primed with <|start|>assistant<|message|>
    token_count += 3
    return token_count
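
# A worked example of the heuristic above (a sketch; exact counts depend on the
# tiktoken vocabulary): for {"messages": [{"role": "user", "content": "Hello"}]}
# the estimate is 3 (message overhead) + tokens("user") + tokens("Hello")
# + 3 (reply priming). With cl100k_base both "user" and "Hello" encode to a
# single token each, giving 8 in total.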


class ChatGPTStreamResponseItem(StreamChunkItemBase):
    def to_accesslog(self, chunks: list, accesslog_cls: _AccessLogBase) -> _AccessLogBase:
        chunk_jsons = []
        response_content = ""
        function_call = None
        tool_calls = None
        prompt_tokens = 0
        completion_tokens = 0

        # Parse info from chunks
        for chunk in chunks:
            chunk_jsons.append(chunk.chunk_json)
            if len(chunk.chunk_json["choices"]) == 0:
                # Azure returns the first delta with empty choices
                continue
            delta = chunk.chunk_json["choices"][0]["delta"]

            # Make tool_calls
            if delta.get("tool_calls"):
                if tool_calls is None:
                    tool_calls = []
                if delta["tool_calls"][0]["function"].get("name"):
                    # A delta carrying a name starts a new tool call
                    tool_calls.append({
                        "type": "function",
                        "function": {
                            "name": delta["tool_calls"][0]["function"]["name"],
                            "arguments": ""
                        }
                    })
                elif delta["tool_calls"][0]["function"].get("arguments"):
                    # Subsequent deltas stream the arguments piecewise
                    tool_calls[-1]["function"]["arguments"] += delta["tool_calls"][0]["function"].get("arguments") or ""

            # Make function_call
            elif delta.get("function_call"):
                if function_call is None:
                    function_call = {}
                if delta["function_call"].get("name"):
                    function_call["name"] = delta["function_call"]["name"]
                    function_call["arguments"] = ""
                elif delta["function_call"].get("arguments"):
                    function_call["arguments"] += delta["function_call"]["arguments"]

            # Text content
            else:
                response_content += delta.get("content") or ""

        # Serialize
        function_call_str = json.dumps(function_call, ensure_ascii=False) if function_call is not None else None
        tool_calls_str = json.dumps(tool_calls, ensure_ascii=False) if tool_calls is not None else None
        response_headers = json.dumps(
            dict(self.response_headers.items()), ensure_ascii=False
        ) if self.response_headers is not None else None

        # Count tokens locally: streamed responses do not include a usage object
        prompt_tokens = count_request_token(self.request_json)
        if tool_calls_str:
            completion_tokens = count_token(tool_calls_str)
        elif function_call_str:
            completion_tokens = count_token(function_call_str)
        else:
            completion_tokens = count_token(response_content)

        return accesslog_cls(
            request_id=self.request_id,
            created_at=datetime.utcnow(),
            direction="response",
            status_code=self.status_code,
            content=response_content,
            function_call=function_call_str,
            tool_calls=tool_calls_str,
            raw_body=json.dumps(chunk_jsons, ensure_ascii=False),
            raw_headers=response_headers,
            model=chunk_jsons[0]["model"],
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            request_time=self.duration,
            request_time_api=self.duration_api
        )
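
# For reference (an illustrative sequence, not captured output): a streamed
# tool call typically arrives as deltas like
#   {"tool_calls": [{"function": {"name": "get_weather", "arguments": ""}}]}
#   {"tool_calls": [{"function": {"arguments": "{\"city\""}}]}
#   {"tool_calls": [{"function": {"arguments": ": \"Tokyo\"}"}}]}
# which the loop above folds into a single entry whose arguments end up as
# '{"city": "Tokyo"}'. The function name "get_weather" is made up for the
# example.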


class ChatGPTErrorItem(ErrorItemBase):
    ...


queue_item_types = [ChatGPTRequestItem, ChatGPTResponseItem, ChatGPTStreamResponseItem, ChatGPTErrorItem]


# Reverse proxy application for ChatGPT
class ChatGPTProxy(ProxyBase):
    _empty_openai_api_key = "OPENAI_API_KEY_IS_NOT_SET"

    def __init__(
        self,
        *,
        base_url: str = None,
        api_key: str = None,
        async_client: AsyncClient = None,
        max_retries: int = 0,
        timeout: float = 60.0,
        request_filters: List[RequestFilterBase] = None,
        response_filters: List[ResponseFilterBase] = None,
        request_item_class: type = ChatGPTRequestItem,
        response_item_class: type = ChatGPTResponseItem,
        stream_response_item_class: type = ChatGPTStreamResponseItem,
        error_item_class: type = ChatGPTErrorItem,
        access_logger_queue: QueueClientBase,
    ):
        super().__init__(
            request_filters=request_filters,
            response_filters=response_filters,
            access_logger_queue=access_logger_queue
        )

        # Log items
        self.request_item_class = request_item_class
        self.response_item_class = response_item_class
        self.stream_response_item_class = stream_response_item_class
        self.error_item_class = error_item_class

        # ChatGPT client config
        self.base_url = base_url
        self.api_key = api_key or os.getenv("OPENAI_API_KEY") or self._empty_openai_api_key
        self.max_retries = max_retries
        self.timeout = timeout
        self.async_client = async_client

    async def filter_request(
        self, request_id: str, request_json: dict, request_headers: dict
    ) -> Union[dict, JSONResponse, EventSourceResponse]:
        for f in self.request_filters:
            if json_resp := await f.filter(request_id, request_json, request_headers):
                # If a filter returns a string, short-circuit with a canned completion
                resp_for_log = {
                    "id": "-",
                    "choices": [{"message": {"role": "assistant", "content": json_resp}, "finish_reason": "stop", "index": 0}],
                    "created": 0,
                    "model": "request_filter",
                    "object": "chat.completion",
                    "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
                }

                # Response log
                self.access_logger_queue.put(self.response_item_class(
                    request_id=request_id,
                    response_json=resp_for_log,
                    status_code=200
                ))

                if request_json.get("stream"):
                    # Stream: emit the canned content as two SSE deltas
                    async def filter_response_stream(content: str):
                        # First delta
                        resp = {
                            "id": "-",
                            "choices": [{"delta": {"role": "assistant", "content": ""}, "finish_reason": None, "index": 0}],
                            "created": 0,
                            "model": "request_filter",
                            "object": "chat.completion.chunk",
                            "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
                        }
                        yield json.dumps(resp)
                        # Last delta
                        resp["choices"][0] = {"delta": {"content": content}, "finish_reason": "stop", "index": 0}
                        yield json.dumps(resp)

                    return self.return_response_with_headers(EventSourceResponse(
                        filter_response_stream(json_resp)
                    ), request_id)

                else:
                    # Non-stream
                    return self.return_response_with_headers(JSONResponse(resp_for_log), request_id)

        return request_json
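
    # A minimal filter sketch (hypothetical, not part of this module): a
    # RequestFilterBase subclass that blocks certain requests by returning a
    # string, which the loop above turns into the canned completion:
    #
    #     class BannedWordFilter(RequestFilterBase):
    #         async def filter(self, request_id, request_json, request_headers):
    #             if "forbidden" in str(request_json["messages"][-1]["content"]):
    #                 return "I can't help with that."
    #
    # RequestFilterBase comes from .proxy; the subclass name and rule are made
    # up for illustration.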

    def get_client(self):
        # Use the injected shared client if present; otherwise create a new client per request
        return self.async_client or AsyncClient(
            base_url=self.base_url,
            api_key=self.api_key,
            max_retries=self.max_retries,
            timeout=self.timeout
        )

    async def filter_response(self, request_id: str, response: ChatCompletion) -> ChatCompletion:
        response_json = response.model_dump()
        for f in self.response_filters:
            if json_resp := await f.filter(request_id, response_json):
                return response.model_validate(json_resp)
        return response.model_validate(response_json)

    def return_response_with_headers(self, resp: JSONResponse, request_id: str):
        self.add_response_headers(response=resp, request_id=request_id)
        return resp

    def add_route(self, app: FastAPI, base_url: str):
        async def handle_request(request: Request):
            request_id = str(uuid4())
            async_client = None

            try:
                start_time = time.time()
                request_json = await request.json()
                request_headers = dict(request.headers.items())

                # Log request
                self.access_logger_queue.put(self.request_item_class(
                    request_id=request_id,
                    request_json=request_json,
                    request_headers=request_headers
                ))

                # Filter request
                request_json = await self.filter_request(request_id, request_json, request_headers)
                if isinstance(request_json, (JSONResponse, EventSourceResponse)):
                    # A filter short-circuited; return its response directly
                    return request_json

                # Call API
                async_client = self.get_client()
                start_time_api = time.time()
                if self.api_key != self._empty_openai_api_key:
                    # Always use the server-side API key if it is set on the client
                    raw_response = await async_client.chat.completions.with_raw_response.create(**request_json)
                elif user_auth_header := request_headers.get("authorization"):  # Header names from the client are lower-cased
                    raw_response = await async_client.chat.completions.with_raw_response.create(
                        **request_json, extra_headers={"Authorization": user_auth_header}  # Canonical capitalization to the server
                    )
                else:
                    # Call API anyway ;)
                    raw_response = await async_client.chat.completions.with_raw_response.create(**request_json)

                completion_response = raw_response.parse()
                completion_response_headers = raw_response.headers
                completion_status_code = raw_response.status_code
                if "content-encoding" in completion_response_headers:
                    # Remove "br" etc.: the body is re-encoded by this proxy
                    completion_response_headers.pop("content-encoding")

                # Handling response from API
                if request_json.get("stream"):
                    async def process_stream(stream: AsyncContentStream) -> AsyncGenerator[str, None]:
                        # Async content generator: log every chunk as it passes through
                        try:
                            async for chunk in stream:
                                self.access_logger_queue.put(self.stream_response_item_class(
                                    request_id=request_id,
                                    chunk_json=chunk.model_dump()
                                ))
                                if chunk:
                                    yield chunk.model_dump_json()
                        finally:
                            # Close client after reading stream
                            await async_client.close()
                            # Response log, written once the stream is fully consumed
                            # so that duration covers the whole stream
                            now = time.time()
                            self.access_logger_queue.put(self.stream_response_item_class(
                                request_id=request_id,
                                response_headers=completion_response_headers,
                                duration=now - start_time,
                                duration_api=now - start_time_api,
                                request_json=request_json,
                                status_code=completion_status_code
                            ))

                    return self.return_response_with_headers(EventSourceResponse(
                        process_stream(completion_response),
                        headers=completion_response_headers
                    ), request_id)

                else:
                    # Close client immediately
                    await async_client.close()
                    duration_api = time.time() - start_time_api

                    # Filter response
                    completion_response = await self.filter_response(request_id, completion_response)

                    # Response log
                    self.access_logger_queue.put(self.response_item_class(
                        request_id=request_id,
                        response_json=completion_response.model_dump(),
                        response_headers=completion_response_headers,
                        duration=time.time() - start_time,
                        duration_api=duration_api,
                        status_code=completion_status_code
                    ))

                    return self.return_response_with_headers(JSONResponse(
                        content=completion_response.model_dump(),
                        headers=completion_response_headers
                    ), request_id)
            # Error handlers
            except RequestFilterException as rfex:
                logger.error(f"Request filter error: {rfex}\n{traceback.format_exc()}")
                resp_json = {"error": {"message": rfex.message, "type": "request_filter_error", "param": None, "code": None}}

                # Error log
                self.access_logger_queue.put(self.error_item_class(
                    request_id=request_id,
                    exception=rfex,
                    traceback_info=traceback.format_exc(),
                    response_json=resp_json,
                    status_code=rfex.status_code
                ))

                return self.return_response_with_headers(JSONResponse(resp_json, status_code=rfex.status_code), request_id)

            except ResponseFilterException as rfex:
                logger.error(f"Response filter error: {rfex}\n{traceback.format_exc()}")
                resp_json = {"error": {"message": rfex.message, "type": "response_filter_error", "param": None, "code": None}}

                # Error log
                self.access_logger_queue.put(self.error_item_class(
                    request_id=request_id,
                    exception=rfex,
                    traceback_info=traceback.format_exc(),
                    response_json=resp_json,
                    status_code=rfex.status_code
                ))

                return self.return_response_with_headers(JSONResponse(resp_json, status_code=rfex.status_code), request_id)

            except (APIStatusError, APIResponseValidationError) as status_err:
                logger.error(f"APIStatusError from ChatGPT: {status_err}\n{traceback.format_exc()}")

                # Error log
                try:
                    resp_json = status_err.response.json()
                except Exception:
                    resp_json = str(status_err.response.content)

                self.access_logger_queue.put(self.error_item_class(
                    request_id=request_id,
                    exception=status_err,
                    traceback_info=traceback.format_exc(),
                    response_json=resp_json,
                    status_code=status_err.status_code
                ))

                return self.return_response_with_headers(JSONResponse(resp_json, status_code=status_err.status_code), request_id)

            except APIError as api_err:
                logger.error(f"APIError from ChatGPT: {api_err}\n{traceback.format_exc()}")
                resp_json = {"error": {"message": api_err.message, "type": api_err.type, "param": api_err.param, "code": api_err.code}}

                # Error log
                self.access_logger_queue.put(self.error_item_class(
                    request_id=request_id,
                    exception=api_err,
                    traceback_info=traceback.format_exc(),
                    response_json=resp_json,
                    status_code=502
                ))

                return self.return_response_with_headers(JSONResponse(resp_json, status_code=502), request_id)

            except OpenAIError as oai_err:
                logger.error(f"OpenAIError: {oai_err}\n{traceback.format_exc()}")
                resp_json = {"error": {"message": str(oai_err), "type": "openai_error", "param": None, "code": None}}

                # Error log
                self.access_logger_queue.put(self.error_item_class(
                    request_id=request_id,
                    exception=oai_err,
                    traceback_info=traceback.format_exc(),
                    response_json=resp_json,
                    status_code=502
                ))

                return self.return_response_with_headers(JSONResponse(resp_json, status_code=502), request_id)

            except Exception as ex:
                logger.error(f"Error at server: {ex}\n{traceback.format_exc()}")
                resp_json = {"error": {"message": "Proxy error", "type": "proxy_error", "param": None, "code": None}}

                # Error log
                self.access_logger_queue.put(self.error_item_class(
                    request_id=request_id,
                    exception=ex,
                    traceback_info=traceback.format_exc(),
                    response_json=resp_json,
                    status_code=502
                ))

                return self.return_response_with_headers(JSONResponse(resp_json, status_code=502), request_id)

        # Register the handler for POST requests at the configured path
        app.post(base_url)(handle_request)
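

# Minimal usage sketch (assumptions flagged): mounting the proxy on a FastAPI
# app. "SQLiteQueueClient" and the import path "aiproxy.chatgpt" are
# hypothetical; substitute the concrete QueueClientBase implementation and
# module layout of your deployment.
#
#     from fastapi import FastAPI
#     from aiproxy.chatgpt import ChatGPTProxy  # assumed module path
#
#     app = FastAPI()
#     proxy = ChatGPTProxy(
#         api_key="sk-...",                        # or rely on OPENAI_API_KEY
#         access_logger_queue=SQLiteQueueClient()  # hypothetical queue client
#     )
#     proxy.add_route(app, "/chat/completions")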