File size: 4,208 Bytes
da34456
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import torch
from typing import TypedDict, Literal, List, Optional, Tuple, Iterator


#### data types #########
# The data type definitions below match the CharacterGLM API, but differ from
# the `chat` method of modeling_chatglm.py.
# Reference: https://open.bigmodel.cn/dev/api#characterglm
# Author of a chat message, as accepted by the CharacterGLM API.
RoleType = Literal["user", "assistant"]

class Msg(TypedDict):
    """A single chat message in CharacterGLM API format."""
    role: RoleType  # who authored the message: "user" or "assistant"
    content: str  # plain message text


class SessionMeta(TypedDict):
    """Persona/session metadata sent with every CharacterGLM request."""
    user_name: str  # display name of the human user
    bot_name: str  # display name of the character (bot)
    bot_info: str  # character description / persona text
    user_info: Optional[str]  # optional description of the user


# A conversation is an ordered list of messages; by convention the last
# entry is the pending user query (see CharacterGLMGenerationUtils.build_inputs).
HistoryType = List[Msg]


class CharacterGLMGenerationUtils:
    """Helpers that convert ChatGLM-style histories into CharacterGLM prompts."""

    @staticmethod
    def convert_chatglm_history_to_characterglm_history(user_query: str, history: List[Tuple[str, str]]) -> "HistoryType":
        """Flatten ChatGLM ``(query, response)`` pairs into CharacterGLM messages.

        Args:
            user_query: The current (newest) user message; appended last.
            history: Earlier turns as ``(query, response)`` pairs.

        Returns:
            A list of ``{"role", "content"}`` dicts, ending with ``user_query``
            as a user message.
        """
        characterglm_history: "HistoryType" = []
        for i, (query, response) in enumerate(history):
            if i == 0 and query == '':
                # The first query may be an empty placeholder (ChatGLM
                # convention); keep its response but emit no user message.
                pass
            else:
                characterglm_history.append({
                    "role": "user",
                    "content": query
                })
            characterglm_history.append({
                "role": "assistant",
                "content": response
            })

        characterglm_history.append({
            "role": "user",
            "content": user_query
        })
        return characterglm_history

    @staticmethod
    def build_inputs(session_meta: "SessionMeta", history: "HistoryType") -> str:
        """Render the CharacterGLM prompt string for local inference.

        Note: assumes the last message of ``history`` is the user's query.

        Args:
            session_meta: Names and persona descriptions for both speakers.
            history: Conversation so far; must be non-empty and end with a
                user message.

        Returns:
            One line per segment: a preamble, optional persona lines, one
            ``[name]content`` line per message, ending with an empty
            ``[bot_name]`` tag for the model to complete.

        Raises:
            ValueError: if ``history`` is empty or does not end with a user
                message.
        """
        # Explicit check instead of `assert`: assertions are stripped under
        # `python -O`, which would silently disable this validation.
        if not history or history[-1]['role'] != 'user':
            raise ValueError("history must be non-empty and its last message must have role 'user'")

        texts = []
        texts.append(
            f"以下是一段{session_meta['bot_name']}和{session_meta['user_name']}之间的对话。")
        if session_meta.get("bot_info"):
            texts.append(f"关于{session_meta['bot_name']}的信息:{session_meta['bot_info']}")
        if session_meta.get("user_info"):
            texts.append(
                f"关于{session_meta['user_name']}的信息:{session_meta['user_info']}")

        for msg in history:
            name = session_meta['user_name'] if msg['role'] == 'user' else session_meta['bot_name']
            texts.append(f"[{name}]" + msg['content'].strip())

        # Newlines inside a segment would break the line-oriented prompt format.
        texts = [text.replace('\n', ' ') for text in texts]
        texts.append(f"[{session_meta['bot_name']}]")
        return '\n'.join(texts)


class CharacterGLMAPI:
    """Thin wrapper around the zhipuai CharacterGLM HTTP API."""

    @staticmethod
    def build_api_arguments(session_meta: SessionMeta, history: HistoryType) -> dict:
        """Pack session metadata and history into zhipuai call arguments."""
        arguments = {
            "model": "characterglm",
            "meta": session_meta,
            "prompt": history,
        }
        return arguments

    @classmethod
    def async_invoke(cls, session_meta: SessionMeta, history: HistoryType):
        """Submit a generation request asynchronously.

        Notes:
            1. Set ``zhipuai.api_key`` beforehand.
            2. ``return_type='text'`` is passed so results come back as plain
               text rather than a JSON string.
            3. The generated text must be fetched later via
               ``zhipuai.model_api.query_async_invoke_result``.

        Reference:
            https://open.bigmodel.cn/dev/api#characterglm
        """
        import zhipuai
        api_args = cls.build_api_arguments(session_meta, history)
        return zhipuai.model_api.async_invoke(**api_args, return_type='text')

    @classmethod
    def invoke(cls, session_meta: SessionMeta, history: HistoryType):
        """Call the API synchronously and return the raw response dict.

        Notes:
            1. Set ``zhipuai.api_key`` beforehand.
            2. ``return_type='text'`` is passed so results come back as plain
               text rather than a JSON string.

        Reference:
            https://open.bigmodel.cn/dev/api#characterglm
        """
        import zhipuai
        api_args = cls.build_api_arguments(session_meta, history)
        return zhipuai.model_api.invoke(**api_args, return_type='text')

    @classmethod
    def generate(cls, session_meta: SessionMeta, history: HistoryType) -> str:
        """Synchronously produce one reply and return its text.

        Raises:
            RuntimeError: when the API response reports failure.
        """
        response = cls.invoke(session_meta, history)
        if not response['success']:
            raise RuntimeError(response)
        return response['data']['choices'][0]['content']

    @classmethod
    def stream_generate(cls, session_meta: SessionMeta, history: HistoryType) -> Iterator[str]:
        """Pseudo-streaming: run one full generation, then iterate the reply character by character."""
        full_reply = cls.generate(session_meta, history)
        return iter(full_reply)