File size: 3,684 Bytes
353256e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import json

from websockets.exceptions import ConnectionClosedOK
from websockets.sync.client import connect

from chatglm2_6b.modelClient import ChatGLM2
import abc


class ChatClient(abc.ABC):
    """Common interface for ChatGLM2 chat back-ends.

    The concrete clients in this module implement both methods as generators
    that yield ``(response, history)`` pairs while a reply streams in.
    """

    @abc.abstractmethod
    def simple_chat(self, query, history, temperature, top_p):
        """Stream a reply to ``query`` given the prior conversation ``history``.

        ``temperature`` and ``top_p`` are sampling parameters that the
        concrete clients forward to the model.
        """

    @abc.abstractmethod
    def instruct_chat(self, message, chat_history, instructions, temperature, top_p):
        """Stream a reply to ``message`` under the preset ``instructions``.

        ``chat_history`` holds the previous turns; ``temperature`` and
        ``top_p`` are forwarded to the model by the concrete clients.
        """


def format_chat_prompt(message: str, chat_history, instructions: str) -> str:
    """Flatten instructions, past turns and a new message into one prompt.

    The prompt starts with the background-setting header followed by the
    stripped ``instructions``. Each past ``(user, bot)`` turn is then
    rendered as a numbered ``[Round i]`` section with question/answer labels.
    The prompt ends with a final round for ``message`` whose answer label is
    left open for the model to complete. Sections are separated by a blank
    line.

    Args:
        message: the new user message.
        chat_history: sequence of ``(user_message, bot_message)`` pairs;
            ``None`` is treated as an empty history.
        instructions: background/system instructions; surrounding whitespace
            (spaces, newlines, CR, tabs in any order) is removed. ``None`` is
            treated as empty.

    Returns:
        The complete prompt string.
    """
    # A single .strip() removes any mix of leading/trailing whitespace,
    # including " \n" combinations and CRLF line endings that the previous
    # strip(" ").strip("\n") chain left behind.
    instructions = (instructions or "").strip()
    history = list(chat_history or ())

    sections = [f"对话背景设定:{instructions}"]
    for round_no, (user_message, bot_message) in enumerate(history, start=1):
        sections.append(f"[Round {round_no}]\n\n问:{user_message}\n\n答:{bot_message}")
    sections.append(f"[Round {len(history) + 1}]\n\n问:{message}\n\n答:")
    return "\n\n".join(sections)


class ChatGLM2APIClient(ChatClient):
    """Chat client that talks to a ChatGLM2-6B server over websockets.

    Each call opens a fresh websocket connection to ``{ws_url}/<endpoint>``,
    sends one JSON request, and yields the JSON replies as they arrive until
    the server closes the connection normally.
    """

    def __init__(self, ws_url=None):
        """
        Args:
            ws_url: base websocket URL of the server. Falls back to
                ``ws://localhost:10001`` when ``None`` or empty.
        """
        self.ws_url = ws_url or "ws://localhost:10001"

    def simple_chat(self, query, history, temperature, top_p):
        """Stream a reply using the model's own chat method (``/streamChat``).

        Yields:
            ``(response, history)`` pairs taken from the ``resp`` and
            ``history`` fields of each server message.
        """
        url = f"{self.ws_url}/streamChat"
        request = json.dumps({
            "query": query, "history": history,
            "temperature": temperature, "top_p": top_p,
        })
        with connect(url) as websocket:
            websocket.send(request)
            try:
                while True:
                    data = json.loads(websocket.recv())
                    yield data['resp'], data['history']
            except ConnectionClosedOK:
                # A normal close from the server is treated as end-of-stream.
                print("generation is finished")

    def instruct_chat(self, message, chat_history, instructions, temperature, top_p):
        """Stream a reply under preset instructions via text generation (``/streamGenerate``).

        The conversation is flattened into a single prompt by
        ``format_chat_prompt``.

        Yields:
            ``(response, history)`` pairs. Each ``history`` is a new list: the
            turns of ``chat_history`` followed by ``[message, response]``.
        """
        url = f"{self.ws_url}/streamGenerate"

        prompt = format_chat_prompt(message, chat_history, instructions)
        # Copy once and build a fresh list per yield. The previous version
        # popped from the list it had just yielded, mutating histories the
        # caller already received.
        base_history = list(chat_history)
        request = json.dumps({"prompt": prompt, "temperature": temperature, "top_p": top_p})
        with connect(url) as websocket:
            websocket.send(request)
            try:
                while True:
                    data = json.loads(websocket.recv())
                    resp = data['text']
                    yield resp, base_history + [[message, resp]]
            except ConnectionClosedOK:
                # A normal close from the server is treated as end-of-stream.
                print("generation is finished")


class ChatGLM2ModelClient(ChatClient):
    """Chat client that runs a ``ChatGLM2`` model in-process."""

    def __init__(self, model_path=None):
        """
        Args:
            model_path: forwarded unchanged to the ``ChatGLM2`` constructor.
        """
        self.model = ChatGLM2(model_path)

    def simple_chat(self, query, history, temperature, top_p):
        """Stream a reply using the model's ``stream_chat`` method.

        Yields:
            ``(response, history)`` pairs exactly as produced by
            ``self.model.stream_chat``.
        """
        kwargs = {
            "query": query, "history": history,
            "temperature": temperature, "top_p": top_p,
        }
        # Distinct loop name so the ``history`` argument is not shadowed.
        for resp, new_history in self.model.stream_chat(**kwargs):
            yield resp, new_history

    def instruct_chat(self, message, chat_history, instructions, temperature, top_p):
        """Stream a reply under preset instructions via ``stream_generate``.

        The conversation is flattened into a single prompt by
        ``format_chat_prompt``.

        Yields:
            ``(response, history)`` pairs. Each ``history`` is a new list: the
            turns of ``chat_history`` followed by ``[message, response]``.
        """
        prompt = format_chat_prompt(message, chat_history, instructions)
        # Copy once and build a fresh list per yield. The previous version
        # popped from the list it had just yielded, mutating histories the
        # caller already received.
        base_history = list(chat_history)
        kwargs = {"prompt": prompt, "temperature": temperature, "top_p": top_p}
        for resp in self.model.stream_generate(**kwargs):
            yield resp, base_history + [[message, resp]]