File size: 8,905 Bytes
0bae6cd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 |
from __future__ import annotations
import logging
import os
import colorama
import commentjson as cjson
from modules import config
from ..index_func import *
from ..presets import *
from ..utils import *
from .base_model import BaseLLMModel, ModelType
def get_model(
    model_name,
    lora_model_path=None,
    access_key=None,
    temperature=None,
    top_p=None,
    system_prompt=None,
    user_name="",
    original_model=None,
    common_model=None,
    common_tokenizer=None,
) -> tuple:
    """Create the client for ``model_name`` plus the UI updates that go with it.

    The model family is resolved with ``ModelType.get_type`` and the matching
    client module is imported lazily, so only the selected backend's
    dependencies are loaded.

    Args:
        model_name: Name of the model to load; also used as the chatbot label.
        lora_model_path: LLaMA only. ``""`` means "no choice yet": nothing is
            loaded and the LoRA selector is made visible. ``"No LoRA"`` loads
            the base model. Any other value is passed to ``LLaMA_Client``.
        access_key: API key supplied by the caller. For several backends an
            environment variable takes precedence (see the branches below).
        temperature: Accepted but not used by this function.
        top_p: Accepted but not used by this function.
        system_prompt: Forwarded to the clients that accept it.
        user_name: Forwarded to most clients.
        original_model: Previously active model. It is returned unchanged when
            no new client is created, and its ``history`` and
            ``history_file_path`` are copied onto the new client.
        common_model: Forwarded to ``XMChat`` only.
        common_tokenizer: Forwarded to ``XMChat`` only.

    Returns:
        A 6-tuple ``(model, msg, chatbot, lora_selector, access_key,
        presudo_key)``: the client (or ``original_model``, possibly ``None``),
        a status or error message, the chatbot label update, the LoRA dropdown
        update, the key that was actually used, and that key passed through
        ``hide_middle_chars``.

    Any exception raised while building the client is caught, printed with
    its traceback, and reported through ``msg`` instead of propagating.
    """
    msg = i18n("模型设置为了:") + f" {model_name}"
    model_type = ModelType.get_type(model_name)
    lora_selector_visibility = False
    lora_choices = ["No LoRA"]
    dont_change_lora_selector = False
    if model_type != ModelType.OpenAI:
        config.local_embedding = True
    # Start from the previous model so that a failed or deferred load keeps it.
    model = original_model
    chatbot = gr.Chatbot.update(label=model_name)
    try:
        if model_type == ModelType.OpenAI:
            logging.info(f"正在加载OpenAI模型: {model_name}")
            from .OpenAI import OpenAIClient
            access_key = os.environ.get("OPENAI_API_KEY", access_key)
            model = OpenAIClient(
                model_name=model_name,
                api_key=access_key,
                system_prompt=system_prompt,
                user_name=user_name,
            )
        elif model_type == ModelType.OpenAIInstruct:
            logging.info(f"正在加载OpenAI Instruct模型: {model_name}")
            from .OpenAIInstruct import OpenAI_Instruct_Client
            access_key = os.environ.get("OPENAI_API_KEY", access_key)
            model = OpenAI_Instruct_Client(
                model_name, api_key=access_key, user_name=user_name)
        elif model_type == ModelType.OpenAIVision:
            logging.info(f"正在加载OpenAI Vision模型: {model_name}")
            from .OpenAIVision import OpenAIVisionClient
            access_key = os.environ.get("OPENAI_API_KEY", access_key)
            model = OpenAIVisionClient(
                model_name, api_key=access_key, user_name=user_name)
        elif model_type == ModelType.ChatGLM:
            logging.info(f"正在加载ChatGLM模型: {model_name}")
            from .ChatGLM import ChatGLM_Client
            model = ChatGLM_Client(model_name, user_name=user_name)
        elif model_type == ModelType.LLaMA and lora_model_path == "":
            # No LoRA chosen yet: load nothing, just reveal the selector.
            msg = f"现在请为 {model_name} 选择LoRA模型"
            logging.info(msg)
            lora_selector_visibility = True
            if os.path.isdir("lora"):
                lora_choices = ["No LoRA"] + get_file_names_by_pinyin("lora", filetypes=[""])
        elif model_type == ModelType.LLaMA and lora_model_path != "":
            logging.info(f"正在加载LLaMA模型: {model_name} + {lora_model_path}")
            from .LLaMA import LLaMA_Client
            # The user has already picked an entry; leave the dropdown alone.
            dont_change_lora_selector = True
            if lora_model_path == "No LoRA":
                lora_model_path = None
                msg += " + No LoRA"
            else:
                msg += f" + {lora_model_path}"
            model = LLaMA_Client(
                model_name, lora_model_path, user_name=user_name)
        elif model_type == ModelType.XMChat:
            from .XMChat import XMChat
            # Prefer the environment key, but fall back to the caller's key
            # when the variable is unset or empty. Comparing against "" alone
            # let an unset variable (None) overwrite access_key.
            access_key = os.environ.get("XMCHAT_API_KEY") or access_key
            model = XMChat(api_key=access_key, user_name=user_name, common_model=common_model, common_tokenizer=common_tokenizer)
        elif model_type == ModelType.StableLM:
            from .StableLM import StableLM_Client
            model = StableLM_Client(model_name, user_name=user_name)
        elif model_type == ModelType.MOSS:
            from .MOSS import MOSS_Client
            model = MOSS_Client(model_name, user_name=user_name)
        elif model_type == ModelType.YuanAI:
            from .inspurai import Yuan_Client
            model = Yuan_Client(model_name, api_key=access_key,
                                user_name=user_name, system_prompt=system_prompt)
        elif model_type == ModelType.Minimax:
            from .minimax import MiniMax_Client
            # Same fallback rule as the XMChat branch.
            access_key = os.environ.get("MINIMAX_API_KEY") or access_key
            model = MiniMax_Client(
                model_name, api_key=access_key, user_name=user_name, system_prompt=system_prompt)
        elif model_type == ModelType.ChuanhuAgent:
            from .ChuanhuAgent import ChuanhuAgent_Client
            model = ChuanhuAgent_Client(model_name, access_key, user_name=user_name)
            # For the agent, the status message lists its enabled tools.
            msg = i18n("启用的工具:") + ", ".join([i.name for i in model.tools])
        elif model_type == ModelType.GooglePaLM:
            from .GooglePaLM import Google_PaLM_Client
            access_key = os.environ.get("GOOGLE_PALM_API_KEY", access_key)
            model = Google_PaLM_Client(
                model_name, access_key, user_name=user_name)
        elif model_type == ModelType.LangchainChat:
            from .Azure import Azure_OpenAI_Client
            model = Azure_OpenAI_Client(model_name, user_name=user_name)
        elif model_type == ModelType.Midjourney:
            from .midjourney import Midjourney_Client
            mj_proxy_api_secret = os.getenv("MIDJOURNEY_PROXY_API_SECRET")
            model = Midjourney_Client(
                model_name, mj_proxy_api_secret, user_name=user_name)
        elif model_type == ModelType.Spark:
            from .spark import Spark_Client
            model = Spark_Client(model_name, os.getenv("SPARK_APPID"), os.getenv(
                "SPARK_API_KEY"), os.getenv("SPARK_API_SECRET"), user_name=user_name)
        elif model_type == ModelType.Claude:
            from .Claude import Claude_Client
            # NOTE(review): the client is always created as "claude-2",
            # regardless of model_name, and without user_name. Confirm intended.
            model = Claude_Client(model_name="claude-2", api_secret=os.getenv("CLAUDE_API_SECRET"))
        elif model_type == ModelType.Qwen:
            from .Qwen import Qwen_Client
            model = Qwen_Client(model_name, user_name=user_name)
        elif model_type == ModelType.ERNIE:
            from .ERNIE import ERNIE_Client
            model = ERNIE_Client(model_name, api_key=os.getenv("ERNIE_APIKEY"),secret_key=os.getenv("ERNIE_SECRETKEY"))
        elif model_type == ModelType.DALLE3:
            from .DALLE3 import OpenAI_DALLE3_Client
            access_key = os.environ.get("OPENAI_API_KEY", access_key)
            model = OpenAI_DALLE3_Client(model_name, api_key=access_key, user_name=user_name)
        elif model_type == ModelType.Unknown:
            raise ValueError(f"未知模型: {model_name}")
        logging.info(msg)
    except Exception as e:
        # Loading failures are reported through msg rather than raised, so the
        # caller still receives a complete tuple.
        import traceback
        traceback.print_exc()
        msg = f"{STANDARD_ERROR_MSG}: {e}"
    presudo_key = hide_middle_chars(access_key)
    # Carry the conversation over from the previous model.
    if original_model is not None and model is not None:
        model.history = original_model.history
        model.history_file_path = original_model.history_file_path
    if dont_change_lora_selector:
        return model, msg, chatbot, gr.update(), access_key, presudo_key
    else:
        return model, msg, chatbot, gr.Dropdown.update(choices=lora_choices, visible=lora_selector_visibility), access_key, presudo_key
if __name__ == "__main__":
    # Manual smoke test: load a local model and exercise billing, chat,
    # conversation memory and retry, logging every intermediate result.
    with open("config.json", "r", encoding="utf-8") as f:
        openai_api_key = cjson.load(f)["openai_api_key"]
    # set logging level to debug
    logging.basicConfig(level=logging.DEBUG)
    # client = ModelManager(model_name="gpt-3.5-turbo", access_key=openai_api_key)
    # get_model returns a 6-tuple whose first two items are the client and a
    # status message. Binding the whole tuple to `client` made every call
    # below fail with AttributeError.
    client, load_msg, *_ = get_model(model_name="chatglm-6b-int4")
    if client is None:
        # get_model reports load failures through its message, not by raising.
        raise SystemExit(load_msg)
    chatbot = []
    stream = False
    # Test the billing feature
    logging.info(colorama.Back.GREEN + "测试账单功能" + colorama.Back.RESET)
    logging.info(client.billing_info())
    # Test question answering
    logging.info(colorama.Back.GREEN + "测试问答" + colorama.Back.RESET)
    question = "巴黎是中国的首都吗?"
    for i in client.predict(inputs=question, chatbot=chatbot, stream=stream):
        logging.info(i)
    logging.info(f"测试问答后history : {client.history}")
    # Test conversation memory
    logging.info(colorama.Back.GREEN + "测试记忆力" + colorama.Back.RESET)
    question = "我刚刚问了你什么问题?"
    for i in client.predict(inputs=question, chatbot=chatbot, stream=stream):
        logging.info(i)
    logging.info(f"测试记忆力后history : {client.history}")
    # Test the retry feature
    logging.info(colorama.Back.GREEN + "测试重试功能" + colorama.Back.RESET)
    for i in client.retry(chatbot=chatbot, stream=stream):
        logging.info(i)
    logging.info(f"重试后history : {client.history}")
    # Summarization test (currently disabled):
    # print(colorama.Back.GREEN + "测试总结功能" + colorama.Back.RESET)
    # chatbot, msg = client.reduce_token_size(chatbot=chatbot)
    # print(chatbot, msg)
    # print(f"总结后history: {client.history}")
|