File size: 5,001 Bytes
6f179e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import json
from jinja2 import Template
from .DataBase import ChromaDB

from .Models import GLM, GLM_api

from .utils import *


class ChatWorld:
    """Role-play chat front end built on a story database and GLM back ends.

    The class renders a system prompt from a Jinja2 template (filled with the
    role names and the passages returned by the story database) and sends the
    resulting messages either to the ``GLM_api`` client or to a locally
    loaded ``GLM`` model.
    """

    def __init__(
        self,
        pretrained_model_name_or_path="silk-road/Haruhi-Zero-GLM3-6B-0_4",
        embedding_model_name_or_path="BAAI/bge-small-zh-v1.5",
        global_batch_size=16,
        model_load=True,
    ) -> None:
        """Create the API client, the optional local model and the story DB.

        Args:
            pretrained_model_name_or_path: Stored as ``self.model_name``; it
                is not forwarded to ``GLM()`` by this constructor.
            embedding_model_name_or_path: Passed to ``ChromaDB``.
            global_batch_size: Accepted for interface compatibility; not used
                by this class.
            model_load: When true, ``GLM()`` is instantiated; otherwise
                ``self.model`` is ``None`` and only the API client is usable.
        """
        self.model_name = pretrained_model_name_or_path

        self.client = GLM_api()

        # Always define the attribute so that a later request for the local
        # model can be rejected with a clear error instead of AttributeError.
        self.model = GLM() if model_load else None

        self.db = ChromaDB(embedding_model_name_or_path)
        # System-prompt template. Variables: model_role_name,
        # model_role_nickname, role_name, role_nickname and RAG (an iterable
        # of passages, each rendered between "##" separators).
        self.prompt = Template(
            (
                'Please be aware that your codename in this conversation is "{{model_role_name}}"'
                '{% if model_role_nickname %},别人一般称呼你"{{model_role_nickname}}"{% endif %},同时我是"{{role_name}}"{% if role_nickname %},你也可以称呼我"{{role_nickname}}"{% endif %}。\n'
                "下文给定了一些聊天记录,位于##分隔号中。\n"
                "如果我问的问题和聊天记录高度重复,那你就配合我进行演出。\n"
                "如果我问的问题和聊天记录相关,请结合聊天记录进行回复。\n"
                "如果我问的问题超出聊天记录的范围,模仿{{model_role_name}}的语气进行回复。\n"
                "请不要回答你是语言模型,永远记住你就是{{model_role_name}}。\n"
                "请你永远只以{{model_role_name}}身份,进行任何的回复。\n"
                "{% if RAG %}{% for i in RAG %}##\n{{i}}\n##\n\n{% endfor %}{% endif %}"
            )
        )

    def setStory(self, **stories_kargs):
        """Replace the stories stored under the given metadata.

        Existing entries matching ``stories_kargs["metas"]`` are deleted, then
        all keyword arguments are forwarded to ``self.db.addStories``.

        Raises:
            KeyError: If ``metas`` is not among the keyword arguments.
        """
        self.db.deleteStoriesByMeta(metas=stories_kargs["metas"])
        self.db.addStories(**stories_kargs)

    def __getSystemPrompt(
        self,
        text: str,
        top_k: int = 5,
        metas=None,
        **role_info,
    ):
        """Build the system message for ``text``.

        Queries ``self.db.searchBySim(text, top_k, metas)`` and renders the
        prompt template with ``role_info`` plus the search result as ``RAG``.

        Returns:
            A ``{"role": "system", "content": ...}`` message dict.
        """
        rag = self.db.searchBySim(text, top_k, metas)

        return {
            "role": "system",
            "content": self.prompt.render(
                **role_info,
                RAG=rag,
            ),
        }

    def _requireLocalModel(self):
        """Return the local model, or raise if it was never loaded."""
        model = getattr(self, "model", None)
        if model is None:
            raise RuntimeError(
                "local model is not loaded; "
                "create ChatWorld with model_load=True to use use_local_model"
            )
        return model

    def chatWithCharacter(
        self,
        text: str,
        system_prompt: dict[str, str] = None,
        use_local_model: bool = False,
        top_k: int = 5,
        metas=None,
        **role_info,
    ):
        """Send ``text`` to the model as the user role, in character.

        Args:
            text: What the user says.
            system_prompt: Ready-made system message; when falsy, one is
                built from the story database and ``role_info``.
            use_local_model: Use ``self.model`` instead of the API client.
            top_k: Forwarded to the database search when building the prompt.
            metas: Forwarded to the database search when building the prompt.
            **role_info: Template variables; ``role_name`` is mandatory.

        Returns:
            Whatever the selected back end returns.

        Raises:
            ValueError: If ``role_name`` is missing or empty.
            RuntimeError: If ``use_local_model`` is set but no local model
                was loaded.
        """
        # Validate first so a bad call fails before the database lookup.
        user_role_name = role_info.get("role_name")
        if not user_role_name:
            raise ValueError("role_name is required")

        if not system_prompt:
            system_prompt = self.__getSystemPrompt(
                text=text, **role_info, top_k=top_k, metas=metas
            )

        message = [
            system_prompt,
            {"role": "user", "content": f"{user_role_name}:「{text}」"},
        ]

        logging_info(f"message: {message}")

        if use_local_model:
            return self._requireLocalModel().get_response(message)
        return self.client.chat(message)

    def chatWithoutCharacter(
        self,
        text: str,
        system_prompt: dict[str, str] = None,
        use_local_model: bool = False,
    ):
        """Send ``text`` to the model as a plain user message.

        Args:
            text: The user message.
            system_prompt: Currently ignored.
                NOTE(review): confirm whether it should be prepended to the
                message list as in ``chatWithCharacter``.
            use_local_model: Use ``self.model`` instead of the API client.

        Returns:
            Whatever the selected back end returns.

        Raises:
            RuntimeError: If ``use_local_model`` is set but no local model
                was loaded.
        """
        logging_info(f"text: {text}")

        if use_local_model:
            # NOTE(review): the raw string is passed here, whereas
            # chatWithCharacter passes a message list to the same method;
            # confirm which form GLM.get_response expects.
            return self._requireLocalModel().get_response(text)

        message = [
            {"role": "user", "content": f"{text}"},
        ]
        return self.client.chat(message)

    @staticmethod
    def _extractJsonPayload(response: str) -> str:
        """Return the JSON text embedded in a model reply.

        Prefers a "```json" fenced block, then any "```" fenced block, and
        finally falls back to the whole reply. A missing closing fence means
        "until the end of the reply".
        """
        fence = "```"
        labelled = "```json"

        start = response.find(labelled)
        if start != -1:
            body_start = start + len(labelled)
        else:
            start = response.find(fence)
            if start == -1:
                return response.strip()
            body_start = start + len(fence)

        end = response.find(fence, body_start)
        if end == -1:
            return response[body_start:].strip()
        return response[body_start:end].strip()

    def getRoleNameFromFile(self, input_file: str):
        """Ask the API model to list the people mentioned in ``input_file``.

        Despite the parameter name, ``input_file`` is used as text: it is
        embedded directly in the prompt. An earlier heuristic that took the
        text before ":" on every line was replaced by this model-based
        extraction.

        Returns:
            ``(role_name_list, role_name_dict)`` where the dict maps each name
            to its nickname ("" when the model gave none). Both are empty when
            the reply cannot be parsed.
        """
        prompt = (
            f"{input_file}\n"
            + '请你提取包含“人”(name,nickname)类型的所有信息,如果nickname不存在则设置为空字符串,并输出JSON格式。并且不要提取出重复的同一个人。例如格式如下:\n```json\n [{"name": "小明","nickname": "小明"},{"name": "小红","nickname": ""}]```'
        )

        response = self.chatWithoutCharacter(prompt, use_local_model=False)

        try:
            json_str = self._extractJsonPayload(response)
            print(json_str)

            entries = json.loads(json_str)
            role_name_list = [entry["name"] for entry in entries]
            # The prompt asks for an empty nickname when none exists, but the
            # model may omit the key entirely; treat that the same way.
            role_name_dict = {
                entry["name"]: entry.get("nickname", "") for entry in entries
            }
        except (ValueError, TypeError, KeyError, AttributeError) as e:
            # Best effort: an unusable reply yields empty results, not a crash.
            print(e)
            return [], {}

        return role_name_list, role_name_dict