|
from __future__ import annotations |
|
from typing import TYPE_CHECKING, List |
|
|
|
import logging |
|
import json |
|
import commentjson as cjson |
|
import os |
|
import sys |
|
import requests |
|
import urllib3 |
|
import platform |
|
import base64 |
|
from io import BytesIO |
|
from PIL import Image |
|
|
|
from tqdm import tqdm |
|
import colorama |
|
from duckduckgo_search import ddg |
|
import asyncio |
|
import aiohttp |
|
from enum import Enum |
|
import uuid |
|
|
|
from .presets import * |
|
from .llama_func import * |
|
from .utils import * |
|
from . import shared |
|
from .config import retrieve_proxy |
|
from modules import config |
|
from .base_model import BaseLLMModel, ModelType |
|
|
|
|
|
class OpenAIClient(BaseLLMModel): |
|
def __init__( |
|
self, |
|
model_name, |
|
api_key, |
|
system_prompt=INITIAL_SYSTEM_PROMPT, |
|
temperature=1.0, |
|
top_p=1.0, |
|
) -> None: |
|
super().__init__( |
|
model_name=model_name, |
|
temperature=temperature, |
|
top_p=top_p, |
|
system_prompt=system_prompt, |
|
) |
|
self.api_key = api_key |
|
self.need_api_key = True |
|
self._refresh_header() |
|
|
|
def get_answer_stream_iter(self): |
|
response = self._get_response(stream=True) |
|
if response is not None: |
|
iter = self._decode_chat_response(response) |
|
partial_text = "" |
|
for i in iter: |
|
partial_text += i |
|
yield partial_text |
|
else: |
|
yield STANDARD_ERROR_MSG + GENERAL_ERROR_MSG |
|
|
|
def get_answer_at_once(self): |
|
response = self._get_response() |
|
response = json.loads(response.text) |
|
content = response["choices"][0]["message"]["content"] |
|
total_token_count = response["usage"]["total_tokens"] |
|
return content, total_token_count |
|
|
|
def count_token(self, user_input): |
|
input_token_count = count_token(construct_user(user_input)) |
|
if self.system_prompt is not None and len(self.all_token_counts) == 0: |
|
system_prompt_token_count = count_token( |
|
construct_system(self.system_prompt) |
|
) |
|
return input_token_count + system_prompt_token_count |
|
return input_token_count |
|
|
|
def billing_info(self): |
|
try: |
|
curr_time = datetime.datetime.now() |
|
last_day_of_month = get_last_day_of_month( |
|
curr_time).strftime("%Y-%m-%d") |
|
first_day_of_month = curr_time.replace(day=1).strftime("%Y-%m-%d") |
|
usage_url = f"{shared.state.usage_api_url}?start_date={first_day_of_month}&end_date={last_day_of_month}" |
|
try: |
|
usage_data = self._get_billing_data(usage_url) |
|
except Exception as e: |
|
logging.error(f"获取API使用情况失败:" + str(e)) |
|
return i18n("**获取API使用情况失败**") |
|
rounded_usage = "{:.5f}".format(usage_data["total_usage"] / 100) |
|
return i18n("**本月使用金额** ") + f"\u3000 ${rounded_usage}" |
|
except requests.exceptions.ConnectTimeout: |
|
status_text = ( |
|
STANDARD_ERROR_MSG + CONNECTION_TIMEOUT_MSG + ERROR_RETRIEVE_MSG |
|
) |
|
return status_text |
|
except requests.exceptions.ReadTimeout: |
|
status_text = STANDARD_ERROR_MSG + READ_TIMEOUT_MSG + ERROR_RETRIEVE_MSG |
|
return status_text |
|
except Exception as e: |
|
import traceback |
|
traceback.print_exc() |
|
logging.error(i18n("获取API使用情况失败:") + str(e)) |
|
return STANDARD_ERROR_MSG + ERROR_RETRIEVE_MSG |
|
|
|
def set_token_upper_limit(self, new_upper_limit): |
|
pass |
|
|
|
@shared.state.switching_api_key |
|
def _get_response(self, stream=False): |
|
openai_api_key = self.api_key |
|
system_prompt = self.system_prompt |
|
history = self.history |
|
logging.debug(colorama.Fore.YELLOW + |
|
f"{history}" + colorama.Fore.RESET) |
|
headers = { |
|
"Content-Type": "application/json", |
|
"Authorization": f"Bearer {openai_api_key}", |
|
} |
|
|
|
if system_prompt is not None: |
|
history = [construct_system(system_prompt), *history] |
|
|
|
payload = { |
|
"model": self.model_name, |
|
"messages": history, |
|
"temperature": self.temperature, |
|
"top_p": self.top_p, |
|
"n": self.n_choices, |
|
"stream": stream, |
|
"presence_penalty": self.presence_penalty, |
|
"frequency_penalty": self.frequency_penalty, |
|
} |
|
|
|
if self.max_generation_token is not None: |
|
payload["max_tokens"] = self.max_generation_token |
|
if self.stop_sequence is not None: |
|
payload["stop"] = self.stop_sequence |
|
if self.logit_bias is not None: |
|
payload["logit_bias"] = self.logit_bias |
|
if self.user_identifier is not None: |
|
payload["user"] = self.user_identifier |
|
|
|
if stream: |
|
timeout = TIMEOUT_STREAMING |
|
else: |
|
timeout = TIMEOUT_ALL |
|
|
|
|
|
if shared.state.completion_url != COMPLETION_URL: |
|
logging.info(f"使用自定义API URL: {shared.state.completion_url}") |
|
|
|
with retrieve_proxy(): |
|
try: |
|
response = requests.post( |
|
shared.state.completion_url, |
|
headers=headers, |
|
json=payload, |
|
stream=stream, |
|
timeout=timeout, |
|
) |
|
except: |
|
return None |
|
return response |
|
|
|
def _refresh_header(self): |
|
self.headers = { |
|
"Content-Type": "application/json", |
|
"Authorization": f"Bearer {self.api_key}", |
|
} |
|
|
|
def _get_billing_data(self, billing_url): |
|
with retrieve_proxy(): |
|
response = requests.get( |
|
billing_url, |
|
headers=self.headers, |
|
timeout=TIMEOUT_ALL, |
|
) |
|
|
|
if response.status_code == 200: |
|
data = response.json() |
|
return data |
|
else: |
|
raise Exception( |
|
f"API request failed with status code {response.status_code}: {response.text}" |
|
) |
|
|
|
def _decode_chat_response(self, response): |
|
error_msg = "" |
|
for chunk in response.iter_lines(): |
|
if chunk: |
|
chunk = chunk.decode() |
|
chunk_length = len(chunk) |
|
try: |
|
chunk = json.loads(chunk[6:]) |
|
except json.JSONDecodeError: |
|
print(i18n("JSON解析错误,收到的内容: ") + f"{chunk}") |
|
error_msg += chunk |
|
continue |
|
if chunk_length > 6 and "delta" in chunk["choices"][0]: |
|
if chunk["choices"][0]["finish_reason"] == "stop": |
|
break |
|
try: |
|
yield chunk["choices"][0]["delta"]["content"] |
|
except Exception as e: |
|
|
|
continue |
|
if error_msg: |
|
raise Exception(error_msg) |
|
|
|
def set_key(self, new_access_key): |
|
ret = super().set_key(new_access_key) |
|
self._refresh_header() |
|
return ret |
|
|
|
|
|
class ChatGLM_Client(BaseLLMModel): |
|
def __init__(self, model_name) -> None: |
|
super().__init__(model_name=model_name) |
|
from transformers import AutoTokenizer, AutoModel |
|
import torch |
|
global CHATGLM_TOKENIZER, CHATGLM_MODEL |
|
if CHATGLM_TOKENIZER is None or CHATGLM_MODEL is None: |
|
system_name = platform.system() |
|
model_path = None |
|
if os.path.exists("models"): |
|
model_dirs = os.listdir("models") |
|
if model_name in model_dirs: |
|
model_path = f"models/{model_name}" |
|
if model_path is not None: |
|
model_source = model_path |
|
else: |
|
model_source = f"THUDM/{model_name}" |
|
CHATGLM_TOKENIZER = AutoTokenizer.from_pretrained( |
|
model_source, trust_remote_code=True |
|
) |
|
quantified = False |
|
if "int4" in model_name: |
|
quantified = True |
|
model = AutoModel.from_pretrained( |
|
model_source, trust_remote_code=True |
|
) |
|
if torch.cuda.is_available(): |
|
|
|
logging.info("CUDA is available, using CUDA") |
|
model = model.half().cuda() |
|
|
|
elif system_name == "Darwin" and model_path is not None and not quantified: |
|
logging.info("Running on macOS, using MPS") |
|
|
|
model = model.half().to("mps") |
|
else: |
|
logging.info("GPU is not available, using CPU") |
|
model = model.float() |
|
model = model.eval() |
|
CHATGLM_MODEL = model |
|
|
|
def _get_glm_style_input(self): |
|
history = [x["content"] for x in self.history] |
|
query = history.pop() |
|
logging.debug(colorama.Fore.YELLOW + |
|
f"{history}" + colorama.Fore.RESET) |
|
assert ( |
|
len(history) % 2 == 0 |
|
), f"History should be even length. current history is: {history}" |
|
history = [[history[i], history[i + 1]] |
|
for i in range(0, len(history), 2)] |
|
return history, query |
|
|
|
def get_answer_at_once(self): |
|
history, query = self._get_glm_style_input() |
|
response, _ = CHATGLM_MODEL.chat( |
|
CHATGLM_TOKENIZER, query, history=history) |
|
return response, len(response) |
|
|
|
def get_answer_stream_iter(self): |
|
history, query = self._get_glm_style_input() |
|
for response, history in CHATGLM_MODEL.stream_chat( |
|
CHATGLM_TOKENIZER, |
|
query, |
|
history, |
|
max_length=self.token_upper_limit, |
|
top_p=self.top_p, |
|
temperature=self.temperature, |
|
): |
|
yield response |
|
|
|
|
|
class LLaMA_Client(BaseLLMModel): |
|
def __init__( |
|
self, |
|
model_name, |
|
lora_path=None, |
|
) -> None: |
|
super().__init__(model_name=model_name) |
|
from lmflow.datasets.dataset import Dataset |
|
from lmflow.pipeline.auto_pipeline import AutoPipeline |
|
from lmflow.models.auto_model import AutoModel |
|
from lmflow.args import ModelArguments, DatasetArguments, InferencerArguments |
|
|
|
self.max_generation_token = 1000 |
|
self.end_string = "\n\n" |
|
|
|
data_args = DatasetArguments(dataset_path=None) |
|
self.dataset = Dataset(data_args) |
|
self.system_prompt = "" |
|
|
|
global LLAMA_MODEL, LLAMA_INFERENCER |
|
if LLAMA_MODEL is None or LLAMA_INFERENCER is None: |
|
model_path = None |
|
if os.path.exists("models"): |
|
model_dirs = os.listdir("models") |
|
if model_name in model_dirs: |
|
model_path = f"models/{model_name}" |
|
if model_path is not None: |
|
model_source = model_path |
|
else: |
|
model_source = f"decapoda-research/{model_name}" |
|
|
|
if lora_path is not None: |
|
lora_path = f"lora/{lora_path}" |
|
model_args = ModelArguments(model_name_or_path=model_source, lora_model_path=lora_path, model_type=None, config_overrides=None, config_name=None, tokenizer_name=None, cache_dir=None, |
|
use_fast_tokenizer=True, model_revision='main', use_auth_token=False, torch_dtype=None, use_lora=False, lora_r=8, lora_alpha=32, lora_dropout=0.1, use_ram_optimized_load=True) |
|
pipeline_args = InferencerArguments( |
|
local_rank=0, random_seed=1, deepspeed='configs/ds_config_chatbot.json', mixed_precision='bf16') |
|
|
|
with open(pipeline_args.deepspeed, "r") as f: |
|
ds_config = json.load(f) |
|
LLAMA_MODEL = AutoModel.get_model( |
|
model_args, |
|
tune_strategy="none", |
|
ds_config=ds_config, |
|
) |
|
LLAMA_INFERENCER = AutoPipeline.get_pipeline( |
|
pipeline_name="inferencer", |
|
model_args=model_args, |
|
data_args=data_args, |
|
pipeline_args=pipeline_args, |
|
) |
|
|
|
def _get_llama_style_input(self): |
|
history = [] |
|
instruction = "" |
|
if self.system_prompt: |
|
instruction = (f"Instruction: {self.system_prompt}\n") |
|
for x in self.history: |
|
if x["role"] == "user": |
|
history.append(f"{instruction}Input: {x['content']}") |
|
else: |
|
history.append(f"Output: {x['content']}") |
|
context = "\n\n".join(history) |
|
context += "\n\nOutput: " |
|
return context |
|
|
|
def get_answer_at_once(self): |
|
context = self._get_llama_style_input() |
|
|
|
input_dataset = self.dataset.from_dict( |
|
{"type": "text_only", "instances": [{"text": context}]} |
|
) |
|
|
|
output_dataset = LLAMA_INFERENCER.inference( |
|
model=LLAMA_MODEL, |
|
dataset=input_dataset, |
|
max_new_tokens=self.max_generation_token, |
|
temperature=self.temperature, |
|
) |
|
|
|
response = output_dataset.to_dict()["instances"][0]["text"] |
|
return response, len(response) |
|
|
|
def get_answer_stream_iter(self): |
|
context = self._get_llama_style_input() |
|
partial_text = "" |
|
step = 1 |
|
for _ in range(0, self.max_generation_token, step): |
|
input_dataset = self.dataset.from_dict( |
|
{"type": "text_only", "instances": [ |
|
{"text": context + partial_text}]} |
|
) |
|
output_dataset = LLAMA_INFERENCER.inference( |
|
model=LLAMA_MODEL, |
|
dataset=input_dataset, |
|
max_new_tokens=step, |
|
temperature=self.temperature, |
|
) |
|
response = output_dataset.to_dict()["instances"][0]["text"] |
|
if response == "" or response == self.end_string: |
|
break |
|
partial_text += response |
|
yield partial_text |
|
|
|
|
|
class XMChat(BaseLLMModel): |
|
def __init__(self, api_key): |
|
super().__init__(model_name="xmchat") |
|
self.api_key = api_key |
|
self.session_id = None |
|
self.reset() |
|
self.image_bytes = None |
|
self.image_path = None |
|
self.xm_history = [] |
|
self.url = "https://xmbot.net/web" |
|
self.last_conv_id = None |
|
|
|
def reset(self): |
|
self.session_id = str(uuid.uuid4()) |
|
self.last_conv_id = None |
|
return [], "已重置" |
|
|
|
def image_to_base64(self, image_path): |
|
|
|
img = Image.open(image_path) |
|
|
|
|
|
width, height = img.size |
|
|
|
|
|
max_dimension = 2048 |
|
scale_ratio = min(max_dimension / width, max_dimension / height) |
|
|
|
if scale_ratio < 1: |
|
|
|
new_width = int(width * scale_ratio) |
|
new_height = int(height * scale_ratio) |
|
img = img.resize((new_width, new_height), Image.ANTIALIAS) |
|
|
|
|
|
buffer = BytesIO() |
|
if img.mode == "RGBA": |
|
img = img.convert("RGB") |
|
img.save(buffer, format='JPEG') |
|
binary_image = buffer.getvalue() |
|
|
|
|
|
base64_image = base64.b64encode(binary_image).decode('utf-8') |
|
|
|
return base64_image |
|
|
|
def try_read_image(self, filepath): |
|
def is_image_file(filepath): |
|
|
|
valid_image_extensions = [".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff"] |
|
file_extension = os.path.splitext(filepath)[1].lower() |
|
return file_extension in valid_image_extensions |
|
|
|
if is_image_file(filepath): |
|
logging.info(f"读取图片文件: {filepath}") |
|
self.image_bytes = self.image_to_base64(filepath) |
|
self.image_path = filepath |
|
else: |
|
self.image_bytes = None |
|
self.image_path = None |
|
|
|
def like(self): |
|
if self.last_conv_id is None: |
|
return "点赞失败,你还没发送过消息" |
|
data = { |
|
"uuid": self.last_conv_id, |
|
"appraise": "good" |
|
} |
|
response = requests.post(self.url, json=data) |
|
return "👍点赞成功,,感谢反馈~" |
|
|
|
def dislike(self): |
|
if self.last_conv_id is None: |
|
return "点踩失败,你还没发送过消息" |
|
data = { |
|
"uuid": self.last_conv_id, |
|
"appraise": "bad" |
|
} |
|
response = requests.post(self.url, json=data) |
|
return "👎点踩成功,感谢反馈~" |
|
|
|
def prepare_inputs(self, real_inputs, use_websearch, files, reply_language, chatbot): |
|
fake_inputs = real_inputs |
|
display_append = "" |
|
limited_context = False |
|
return limited_context, fake_inputs, display_append, real_inputs, chatbot |
|
|
|
def handle_file_upload(self, files, chatbot): |
|
"""if the model accepts multi modal input, implement this function""" |
|
if files: |
|
for file in files: |
|
if file.name: |
|
logging.info(f"尝试读取图像: {file.name}") |
|
self.try_read_image(file.name) |
|
if self.image_path is not None: |
|
chatbot = chatbot + [((self.image_path,), None)] |
|
if self.image_bytes is not None: |
|
logging.info("使用图片作为输入") |
|
|
|
self.reset() |
|
conv_id = str(uuid.uuid4()) |
|
data = { |
|
"user_id": self.api_key, |
|
"session_id": self.session_id, |
|
"uuid": conv_id, |
|
"data_type": "imgbase64", |
|
"data": self.image_bytes |
|
} |
|
response = requests.post(self.url, json=data) |
|
response = json.loads(response.text) |
|
logging.info(f"图片回复: {response['data']}") |
|
return None, chatbot, None |
|
|
|
def get_answer_at_once(self): |
|
question = self.history[-1]["content"] |
|
conv_id = str(uuid.uuid4()) |
|
self.last_conv_id = conv_id |
|
data = { |
|
"user_id": self.api_key, |
|
"session_id": self.session_id, |
|
"uuid": conv_id, |
|
"data_type": "text", |
|
"data": question |
|
} |
|
response = requests.post(self.url, json=data) |
|
try: |
|
response = json.loads(response.text) |
|
return response["data"], len(response["data"]) |
|
except Exception as e: |
|
return response.text, len(response.text) |
|
|
|
|
|
|
|
|
|
def get_model( |
|
model_name, |
|
lora_model_path=None, |
|
access_key=None, |
|
temperature=None, |
|
top_p=None, |
|
system_prompt=None, |
|
) -> BaseLLMModel: |
|
msg = i18n("模型设置为了:") + f" {model_name}" |
|
model_type = ModelType.get_type(model_name) |
|
lora_selector_visibility = False |
|
lora_choices = [] |
|
dont_change_lora_selector = False |
|
if model_type != ModelType.OpenAI: |
|
config.local_embedding = True |
|
|
|
model = None |
|
try: |
|
if model_type == ModelType.OpenAI: |
|
logging.info(f"正在加载OpenAI模型: {model_name}") |
|
model = OpenAIClient( |
|
model_name=model_name, |
|
api_key=access_key, |
|
system_prompt=system_prompt, |
|
temperature=temperature, |
|
top_p=top_p, |
|
) |
|
elif model_type == ModelType.ChatGLM: |
|
logging.info(f"正在加载ChatGLM模型: {model_name}") |
|
model = ChatGLM_Client(model_name) |
|
elif model_type == ModelType.LLaMA and lora_model_path == "": |
|
msg = f"现在请为 {model_name} 选择LoRA模型" |
|
logging.info(msg) |
|
lora_selector_visibility = True |
|
if os.path.isdir("lora"): |
|
lora_choices = get_file_names( |
|
"lora", plain=True, filetypes=[""]) |
|
lora_choices = ["No LoRA"] + lora_choices |
|
elif model_type == ModelType.LLaMA and lora_model_path != "": |
|
logging.info(f"正在加载LLaMA模型: {model_name} + {lora_model_path}") |
|
dont_change_lora_selector = True |
|
if lora_model_path == "No LoRA": |
|
lora_model_path = None |
|
msg += " + No LoRA" |
|
else: |
|
msg += f" + {lora_model_path}" |
|
model = LLaMA_Client(model_name, lora_model_path) |
|
elif model_type == ModelType.XMChat: |
|
if os.environ.get("XMCHAT_API_KEY") != "": |
|
access_key = os.environ.get("XMCHAT_API_KEY") |
|
model = XMChat(api_key=access_key) |
|
elif model_type == ModelType.Unknown: |
|
raise ValueError(f"未知模型: {model_name}") |
|
logging.info(msg) |
|
except Exception as e: |
|
logging.error(e) |
|
msg = f"{STANDARD_ERROR_MSG}: {e}" |
|
if dont_change_lora_selector: |
|
return model, msg |
|
else: |
|
return model, msg, gr.Dropdown.update(choices=lora_choices, visible=lora_selector_visibility) |
|
|
|
|
|
if __name__ == "__main__": |
|
with open("config.json", "r") as f: |
|
openai_api_key = cjson.load(f)["openai_api_key"] |
|
|
|
logging.basicConfig(level=logging.DEBUG) |
|
|
|
client = get_model(model_name="chatglm-6b-int4") |
|
chatbot = [] |
|
stream = False |
|
|
|
logging.info(colorama.Back.GREEN + "测试账单功能" + colorama.Back.RESET) |
|
logging.info(client.billing_info()) |
|
|
|
logging.info(colorama.Back.GREEN + "测试问答" + colorama.Back.RESET) |
|
question = "巴黎是中国的首都吗?" |
|
for i in client.predict(inputs=question, chatbot=chatbot, stream=stream): |
|
logging.info(i) |
|
logging.info(f"测试问答后history : {client.history}") |
|
|
|
logging.info(colorama.Back.GREEN + "测试记忆力" + colorama.Back.RESET) |
|
question = "我刚刚问了你什么问题?" |
|
for i in client.predict(inputs=question, chatbot=chatbot, stream=stream): |
|
logging.info(i) |
|
logging.info(f"测试记忆力后history : {client.history}") |
|
|
|
logging.info(colorama.Back.GREEN + "测试重试功能" + colorama.Back.RESET) |
|
for i in client.retry(chatbot=chatbot, stream=stream): |
|
logging.info(i) |
|
logging.info(f"重试后history : {client.history}") |
|
|
|
|
|
|
|
|
|
|
|
|