# ChatWorld/src/ChatWorld.py
# Uploaded by JiangYH via huggingface_hub ("Upload folder using huggingface_hub", commit 4ab98db).
import json
from jinja2 import Template
from .DataBase import ChromaDB
from .Models import GLM, GLM_api
from .utils import *
class ChatWorld:
    """Role-play chat front end backed by a story vector store.

    Story snippets retrieved from ``ChromaDB`` are rendered into a system
    prompt that asks the model to stay in character. The conversation is
    then sent either to the remote ``GLM_api`` client or to a locally
    loaded ``GLM`` model.
    """

    def __init__(
        self,
        pretrained_model_name_or_path="silk-road/Haruhi-Zero-GLM3-6B-0_4",
        embedding_model_name_or_path="BAAI/bge-small-zh-v1.5",
        global_batch_size=16,
        model_load=True,
    ) -> None:
        """Create the API client, the optional local model and the story DB.

        Args:
            pretrained_model_name_or_path: Recorded as ``self.model_name``.
                NOTE(review): it is not forwarded to ``GLM()`` here; confirm
                whether the local model is meant to receive it.
            embedding_model_name_or_path: Embedding model handed to ``ChromaDB``.
            global_batch_size: Currently unused by this class.
            model_load: When False the local model is not instantiated and
                any ``use_local_model=True`` call raises ``RuntimeError``.
        """
        self.model_name = pretrained_model_name_or_path
        self.client = GLM_api()
        # Always define the attribute so that a missing local model is reported
        # clearly instead of surfacing as an AttributeError inside a chat call.
        self.model = GLM() if model_load else None
        self.db = ChromaDB(embedding_model_name_or_path)
        # System prompt template. Expected role_info keys: model_role_name,
        # model_role_nickname, role_name, role_nickname. RAG is the list of
        # retrieved snippets, each wrapped between "##" separator lines.
        self.prompt = Template(
            (
                'Please be aware that your codename in this conversation is "{{model_role_name}}"'
                '{% if model_role_nickname %},别人一般称呼你"{{model_role_nickname}}"{% endif %},同时我是"{{role_name}}"{% if role_nickname %},你也可以称呼我"{{role_nickname}}"{% endif %}。\n'
                "下文给定了一些聊天记录,位于##分隔号中。\n"
                "如果我问的问题和聊天记录高度重复,那你就配合我进行演出。\n"
                "如果我问的问题和聊天记录相关,请结合聊天记录进行回复。\n"
                "如果我问的问题超出聊天记录的范围,模仿{{model_role_name}}的语气进行回复。\n"
                "请不要回答你是语言模型,永远记住你就是{{model_role_name}}。\n"
                "请你永远只以{{model_role_name}}身份,进行任何的回复。\n"
                "{% if RAG %}{% for i in RAG %}##\n{{i}}\n##\n\n{% endfor %}{% endif %}"
            )
        )

    def _localModel(self):
        """Return the local model, or raise if it was not loaded.

        Raises:
            RuntimeError: if the instance was built with ``model_load=False``.
        """
        model = getattr(self, "model", None)
        if model is None:
            raise RuntimeError(
                "local model is not loaded; construct ChatWorld with model_load=True"
            )
        return model

    def setStory(self, **stories_kargs):
        """Replace the stories stored under ``stories_kargs["metas"]``.

        Entries matching ``metas`` are deleted first. All keyword arguments
        are then forwarded unchanged to ``ChromaDB.addStories``.

        Raises:
            KeyError: if ``metas`` is not supplied.
        """
        self.db.deleteStoriesByMeta(metas=stories_kargs["metas"])
        self.db.addStories(**stories_kargs)

    def __getSystemPrompt(
        self,
        text: str,
        top_k: int = 5,
        metas=None,
        **role_info,
    ) -> dict:
        """Build the system message for ``text``.

        Retrieves the ``top_k`` most similar stories (optionally filtered by
        ``metas``) and renders them, together with ``role_info``, into the
        prompt template.

        Returns:
            A ``{"role": "system", "content": ...}`` chat message.
        """
        rag = self.db.searchBySim(text, top_k, metas)
        return {
            "role": "system",
            "content": self.prompt.render(**role_info, RAG=rag),
        }

    def chatWithCharacter(
        self,
        text: str,
        system_prompt: dict[str, str] = None,
        use_local_model: bool = False,
        top_k: int = 5,
        metas=None,
        **role_info,
    ):
        """Send ``text`` in character and return the model's response.

        Args:
            text: The user's utterance.
            system_prompt: Prebuilt system message. When falsy, one is built
                from the story DB via retrieval.
            use_local_model: Use the local ``GLM`` instead of the API client.
            top_k: Number of retrieved stories for the generated prompt.
            metas: Optional metadata filter for retrieval.
            **role_info: Template variables. ``role_name`` is required.

        Raises:
            ValueError: if ``role_name`` is missing or empty.
            RuntimeError: if ``use_local_model`` is set but no model is loaded.
        """
        # Validate before doing the (comparatively expensive) retrieval.
        user_role_name = role_info.get("role_name")
        if not user_role_name:
            raise ValueError("role_name is required")

        if not system_prompt:
            system_prompt = self.__getSystemPrompt(
                text=text, top_k=top_k, metas=metas, **role_info
            )

        message = [
            system_prompt,
            {"role": "user", "content": f"{user_role_name}:「{text}」"},
        ]
        logging_info(f"message: {message}")

        if use_local_model:
            return self._localModel().get_response(message)
        return self.client.chat(message)

    def chatWithoutCharacter(
        self,
        text: str,
        system_prompt: dict[str, str] = None,
        use_local_model: bool = False,
    ):
        """Send ``text`` as a plain user message and return the response.

        Args:
            text: The user message.
            system_prompt: Optional system message, prepended when given.
            use_local_model: Use the local ``GLM`` instead of the API client.

        Raises:
            RuntimeError: if ``use_local_model`` is set but no model is loaded.
        """
        logging_info(f"text: {text}")
        message = [
            {"role": "user", "content": f"{text}"},
        ]
        if system_prompt:
            message.insert(0, system_prompt)

        # Both backends receive the same message list, matching chatWithCharacter.
        if use_local_model:
            return self._localModel().get_response(message)
        return self.client.chat(message)

    @staticmethod
    def _extractJsonBlock(text: str) -> str:
        """Return the body of the first fenced code block in ``text``.

        A ```` ```json ```` fence is preferred over a bare ```` ``` ```` fence.
        If the closing fence is missing, the rest of the text is used. If
        there is no fence at all, the whole (stripped) text is returned.
        """
        for fence in ("```json", "```"):
            start = text.find(fence)
            if start == -1:
                continue
            start += len(fence)
            end = text.find("```", start)
            if end == -1:
                end = len(text)
            return text[start:end].strip()
        return text.strip()

    @staticmethod
    def _parseRoles(json_str: str):
        """Parse a JSON list of ``{"name": ..., "nickname": ...}`` objects.

        Returns:
            ``(names, nicknames)``. ``names`` lists each name once, in
            first-seen order. ``nicknames`` maps name to nickname; a missing
            or null nickname becomes "", and a later non-empty nickname fills
            in an empty one. Entries without a non-empty string ``name`` are
            skipped. A single JSON object is treated as a one-element list.

        Raises:
            ValueError: if ``json_str`` is not valid JSON or not a list/object.
        """
        data = json.loads(json_str)
        if isinstance(data, dict):
            data = [data]
        if not isinstance(data, list):
            raise ValueError(
                f"expected a JSON array of roles, got {type(data).__name__}"
            )

        role_name_list = []
        role_name_dict = {}
        for entry in data:
            if not isinstance(entry, dict):
                continue
            name = entry.get("name")
            if not isinstance(name, str) or not name.strip():
                continue
            name = name.strip()
            nickname = str(entry.get("nickname") or "")
            if name not in role_name_dict:
                role_name_list.append(name)
                role_name_dict[name] = nickname
            elif not role_name_dict[name] and nickname:
                role_name_dict[name] = nickname
        return role_name_list, role_name_dict

    def getRoleNameFromFile(self, input_file: str):
        """Ask the API model to list the people mentioned in ``input_file``.

        Args:
            input_file: The text content to analyse (not a path).

        Returns:
            ``(role_name_list, role_name_dict)``. The list holds unique names
            and the dict maps each name to its nickname ("" when none). Both
            are empty if the model's answer cannot be parsed.
        """
        prompt = (
            f"{input_file}\n"
            + '请你提取包含“人”(name,nickname)类型的所有信息,如果nickname不存在则设置为空字符串,并输出JSON格式。并且不要提取出重复的同一个人。例如格式如下:\n```json\n [{"name": "小明","nickname": "小明"},{"name": "小红","nickname": ""}]```'
        )
        response = self.chatWithoutCharacter(prompt, use_local_model=False) or ""
        json_str = self._extractJsonBlock(response)
        try:
            return self._parseRoles(json_str)
        except ValueError as e:  # json.JSONDecodeError is a ValueError subclass
            logging_info(f"failed to parse role list: {e}; raw: {json_str}")
            return [], {}