import json
from typing import Optional

from jinja2 import Template

from .DataBase import ChromaDB
from .Models import GLM, GLM_api
from .utils import *  # supplies logging_info used below


class ChatWorld:
    """RAG-style role-play chat: retrieves similar chat-log snippets from a
    vector store and injects them into a character system prompt."""

    def __init__(
        self,
        pretrained_model_name_or_path="silk-road/Haruhi-Zero-GLM3-6B-0_4",
        embedding_model_name_or_path="BAAI/bge-small-zh-v1.5",
        global_batch_size=16,
        model_load=True,
    ) -> None:
        self.model_name = pretrained_model_name_or_path
        self.client = GLM_api()

        # Only load the local model when requested; API-only usage can skip it.
        if model_load:
            self.model = GLM()

        self.db = ChromaDB(embedding_model_name_or_path)

        # System-prompt template (mostly Chinese). It instructs the model to
        # stay in character as {{model_role_name}}, address the user as
        # {{role_name}}, and treat the retrieved chat records placed between
        # "##" separators as reference material for its replies.
        self.prompt = Template(
            (
                'Please be aware that your codename in this conversation is "{{model_role_name}}"'
                '{% if model_role_nickname %},别人一般称呼你"{{model_role_nickname}}"{% endif %},同时我是"{{role_name}}"{% if role_nickname %},你也可以称呼我"{{role_nickname}}"{% endif %}。\n'
                "下文给定了一些聊天记录,位于##分隔号中。\n"
                "如果我问的问题和聊天记录高度重复,那你就配合我进行演出。\n"
                "如果我问的问题和聊天记录相关,请结合聊天记录进行回复。\n"
                "如果我问的问题超出聊天记录的范围,模仿{{model_role_name}}的语气进行回复。\n"
                "请不要回答你是语言模型,永远记住你就是{{model_role_name}}。\n"
                "请你永远只以{{model_role_name}}身份,进行任何的回复。\n"
                "{% if RAG %}{% for i in RAG %}##\n{{i}}\n##\n\n{% endfor %}{% endif %}"
            )
        )

    def setStory(self, **stories_kargs):
        # Drop any previously indexed stories with the same metadata, then
        # index the new ones. Keyword arguments are forwarded to ChromaDB
        # unchanged and must include "metas".
        self.db.deleteStoriesByMeta(metas=stories_kargs["metas"])
        self.db.addStories(**stories_kargs)

    def __getSystemPrompt(
        self,
        text: str,
        top_k: int = 5,
        metas=None,
        **role_info,
    ):
        # Retrieve the top_k chat records most similar to the user's text and
        # render them into the system prompt alongside the role information.
        rag = self.db.searchBySim(text, top_k, metas)
        return {
            "role": "system",
            "content": self.prompt.render(
                **role_info,
                RAG=rag,
            ),
        }
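
    # For reference, the system message returned above is combined in
    # chatWithCharacter with a user turn into an OpenAI-style message list,
    # roughly as follows (the values are illustrative placeholders, not real
    # output from any run):
    #
    #   [
    #       {"role": "system", "content": 'Please be aware that your codename ... ##\n...\n##'},
    #       {"role": "user", "content": "阿虚:「你今天想做什么?」"},
    #   ]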

    def chatWithCharacter(
        self,
        text: str,
        system_prompt: Optional[dict[str, str]] = None,
        use_local_model: bool = False,
        top_k: int = 5,
        metas=None,
        **role_info,
    ):
        # The user's own role name is required so the user turn can be tagged.
        user_role_name = role_info.get("role_name")
        if not user_role_name:
            raise ValueError("role_name is required")

        if not system_prompt:
            system_prompt = self.__getSystemPrompt(
                text=text, **role_info, top_k=top_k, metas=metas
            )

        message = [
            system_prompt,
            {"role": "user", "content": f"{user_role_name}:「{text}」"},
        ]
        logging_info(f"message: {message}")

        if use_local_model:
            response = self.model.get_response(message)
        else:
            response = self.client.chat(message)
        return response

    def chatWithoutCharacter(
        self,
        text: str,
        system_prompt: Optional[dict[str, str]] = None,
        use_local_model: bool = False,
    ):
        logging_info(f"text: {text}")
        message = [
            {"role": "user", "content": f"{text}"},
        ]
        # Prepend the optional system prompt if one was supplied.
        if system_prompt:
            message.insert(0, system_prompt)

        if use_local_model:
            # Pass the message list (not the raw text) for consistency with
            # chatWithCharacter.
            response = self.model.get_response(message)
        else:
            response = self.client.chat(message)
        return response

    def getRoleNameFromFile(self, input_file: str):
        # # Read the file content
        # logging_info(f"file content: {input_file}")
        # # Keep the file content
        # input_text_list = input_file.split("\n")
        # role_name_set = set()
        # # Collect the role names
        # for line in input_text_list:
        #     role_name_set.add(line.split(":")[0])
        # role_name_list = [i for i in role_name_set if i != ""]
        # logging_info(f"role_name_list: {role_name_list}")

        # Ask the model to extract every person (name, nickname) mentioned in
        # the file content, without duplicates, as a JSON list such as
        # [{"name": "小明", "nickname": "小明"}, {"name": "小红", "nickname": ""}].
        prompt = (
            f"{input_file}\n"
            + '请你提取包含“人”(name,nickname)类型的所有信息,如果nickname不存在则设置为空字符串,并输出JSON格式。并且不要提取出重复的同一个人。例如格式如下:\n```json\n [{"name": "小明","nickname": "小明"},{"name": "小红","nickname": ""}]```'
        )
        response = self.chatWithoutCharacter(prompt, use_local_model=False)

        # Pull the JSON payload out of the ```json ... ``` fence in the reply.
        json_start_index = response.find("```json")
        json_end_index = response.find("```", json_start_index + 7)
        json_str = response[json_start_index + 7 : json_end_index]
        logging_info(f"extracted json: {json_str}")

        try:
            parsed = json.loads(json_str)
            role_name_list = [i["name"] for i in parsed]
            role_name_dict = {i["name"]: i["nickname"] for i in parsed}
        except Exception as e:
            logging_info(f"failed to parse role names: {e}")
            role_name_list = []
            role_name_dict = {}
        return role_name_list, role_name_dict
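

# --- Usage sketch (illustrative) ---
# A minimal example of how this class appears to be driven, inferred from the
# methods above. The "stories" keyword passed to setStory is an assumption:
# setStory only requires a "metas" key and forwards everything else to
# ChromaDB.addStories unchanged. Role names and texts are placeholder values.
if __name__ == "__main__":
    chat_world = ChatWorld(model_load=False)  # skip loading the local GLM weights

    chat_world.setStory(
        stories=["凉宫春日:「有趣的事情要自己去找!」"],  # hypothetical keyword and content
        metas=[{"role_name": "凉宫春日"}],
    )

    reply = chat_world.chatWithCharacter(
        text="你今天想做什么?",
        model_role_name="凉宫春日",
        role_name="阿虚",
    )
    print(reply)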