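# ZhipuAI (GLM) chat client for this project, built on top of OpenAIClient.
# It streams completions through the official `zhipuai` SDK and adapts them to
# the same interface the rest of the app expects from OpenAI-style models.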
import base64
import datetime
import json
import os
from io import BytesIO

import gradio as gr
import requests
from PIL import Image
from loguru import logger
from zhipuai import ZhipuAI

from src import shared, config
from src.base_model import BaseLLMModel
from src.presets import (
    INITIAL_SYSTEM_PROMPT,
    TIMEOUT_ALL,
    TIMEOUT_STREAMING,
    STANDARD_ERROR_MSG,
    CONNECTION_TIMEOUT_MSG,
    READ_TIMEOUT_MSG,
    ERROR_RETRIEVE_MSG,
    GENERAL_ERROR_MSG,
    CHAT_COMPLETION_URL,
    SUMMARY_CHAT_SYSTEM_PROMPT,
)
from src.openai_client import OpenAIClient
from src.utils import (
    count_token,
    construct_system,
    construct_user,
    get_last_day_of_month,
    i18n,
    replace_special_symbols,
)
def decode_chat_response(response):
    try:
        error_msg = ""
        for chunk in response:
            if not chunk:
                continue
            delta = chunk.choices[0].delta
            content = getattr(delta, "content", None)
            if not content:
                # Skip deltas that carry no text (e.g. role-only or empty chunks)
                continue
            try:
                yield content
            except Exception as e:
                logger.error(f"Error while yielding chunk content: {e}")
                continue
        if error_msg and not error_msg.endswith("[DONE]"):
            raise Exception(error_msg)
    except GeneratorExit:
        # The consumer closed the generator; re-raise so cleanup proceeds normally
        raise
    except Exception as e:
        raise Exception(f"Error in generate: {str(e)}")
class ZhipuAIClient(OpenAIClient):
    def __init__(
        self,
        model_name,
        api_key,
        system_prompt=INITIAL_SYSTEM_PROMPT,
        temperature=1.0,
        top_p=1.0,
        user_name="",
    ) -> None:
        super().__init__(
            api_key=api_key,
            model_name=model_name,
            temperature=temperature,
            top_p=top_p,
            system_prompt=system_prompt,
            # user=user_name,
        )
        self.api_key = api_key
        self.need_api_key = True
        self._refresh_header()
        self.client = None
        # self.user_name = user_name
        logger.info(f"user name: {user_name}")
    def get_answer_stream_iter(self):
        if not self.api_key:
            raise ValueError("API key is not set")
        response = self._get_response(stream=True)
        if response is not None:
            stream_iter = decode_chat_response(response)
            partial_text = ""
            for chunk in stream_iter:
                partial_text += chunk
                yield partial_text
        else:
            yield STANDARD_ERROR_MSG + GENERAL_ERROR_MSG
    # def get_answer_at_once(self):
    #     if not self.api_key:
    #         raise ValueError("API key is not set")
    #     response = self._get_response()
    #     response = json.loads(response.text)
    #     content = response["choices"][0]["message"]["content"]
    #     total_token_count = response["usage"]["total_tokens"]
    #     return content, total_token_count

    # This decorator has no effect when multi-account mode is disabled
    def _get_response(self, stream=False):
        zhipuai_api_key = self.api_key
        system_prompt = self.system_prompt
        history = self.history
        # logger.debug(f"{history}")
        # headers = {
        #     "Authorization": f"Bearer {zhipuai_api_key}",
        #     "Content-Type": "application/json",
        # }
        if system_prompt is not None:
            history = [construct_system(system_prompt), *history]
        # The payload and timeout below are kept from the earlier REST-based
        # code path; the active SDK call further down builds its own arguments.
        payload = {
            "model": self.model_name,
            "messages": history,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "n": self.n_choices,
            "stream": stream,
            "presence_penalty": self.presence_penalty,
            "frequency_penalty": self.frequency_penalty,
        }
        if self.max_generation_token is not None:
            payload["max_tokens"] = self.max_generation_token
        if self.stop_sequence is not None:
            payload["stop"] = self.stop_sequence
        if self.logit_bias is not None:
            payload["logit_bias"] = self.logit_bias
        if self.user_identifier:
            payload["user"] = self.user_identifier
        if stream:
            timeout = TIMEOUT_STREAMING
        else:
            timeout = TIMEOUT_ALL
        # If a custom api-host is configured, send the request there; otherwise use the default
        # if shared.state.chat_completion_url != CHAT_COMPLETION_URL:
        #     logger.debug(f"Using custom API URL: {shared.state.chat_completion_url}")
        # with config.retrieve_proxy():
        #     try:
        #         response = requests.post(
        #             shared.state.chat_completion_url,
        #             headers=headers,
        #             json=payload,
        #             stream=stream,
        #             timeout=timeout,
        #         )
        #     except Exception as e:
        #         logger.error(f"Error: {e}")
        #         response = None
        #     return response
        if self.client is None:
            self.client = ZhipuAI(api_key=zhipuai_api_key)
        response = self.client.chat.completions.create(
            model=self.model_name,
            # model="glm-3-turbo",
            messages=history,
            temperature=self.temperature,
            top_p=self.top_p,
            stream=stream,
            # Not forwarded to the SDK call (kept from the REST payload above):
            # "n": self.n_choices,
            # "presence_penalty": self.presence_penalty,
            # "frequency_penalty": self.frequency_penalty,
        )
        return response
    # todo: billing lookup is not implemented for ZhipuAI yet
    def billing_info(self):
        status_text = "Failed to retrieve API usage: the ZhipuAI billing code has not been updated."
        return status_text
    # try:
    #     curr_time = datetime.datetime.now()
    #     last_day_of_month = get_last_day_of_month(
    #         curr_time).strftime("%Y-%m-%d")
    #     first_day_of_month = curr_time.replace(day=1).strftime("%Y-%m-%d")
    #     usage_url = f"{shared.state.usage_api_url}?start_date={first_day_of_month}&end_date={last_day_of_month}"
    #     try:
    #         usage_data = self._get_billing_data(usage_url)
    #     except Exception as e:
    #         logger.warning("Failed to get API usage: " + str(e))
    #         return i18n("**获取API使用情况失败**")
    #     rounded_usage = "{:.5f}".format(usage_data["total_usage"] / 100)
    #     return i18n("**本月使用金额** ") + f"\u3000 ${rounded_usage}"
    # except requests.exceptions.ConnectTimeout:
    #     status_text = (
    #         STANDARD_ERROR_MSG + CONNECTION_TIMEOUT_MSG + ERROR_RETRIEVE_MSG
    #     )
    #     return status_text
    # except requests.exceptions.ReadTimeout:
    #     status_text = STANDARD_ERROR_MSG + READ_TIMEOUT_MSG + ERROR_RETRIEVE_MSG
    #     return status_text
    # except Exception as e:
    #     import traceback
    #     traceback.print_exc()
    #     logger.error(i18n("获取API使用情况失败:") + str(e))
    #     return STANDARD_ERROR_MSG + ERROR_RETRIEVE_MSG
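
# A minimal usage sketch, assuming a valid ZHIPUAI_API_KEY environment variable
# and that BaseLLMModel (via OpenAIClient) exposes `history` as an OpenAI-style
# message list, as the code above relies on. The model name is illustrative.
if __name__ == "__main__":
    client = ZhipuAIClient(
        model_name="glm-4",  # illustrative; any GLM chat model name works here
        api_key=os.environ.get("ZHIPUAI_API_KEY", ""),
    )
    client.history = [construct_user("Hello! Briefly introduce yourself.")]
    final_reply = ""
    for partial_reply in client.get_answer_stream_iter():
        # Each yielded value is the full reply accumulated so far
        final_reply = partial_reply
    print(final_reply)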