|
import torch |
|
from typing import TypedDict, Literal, List, Optional, Tuple, Iterator |
|
|
|
|
|
|
|
|
|
|
|
# Chat roles accepted by CharacterGLM: the human speaker or the bot.
RoleType = Literal["user", "assistant"]
|
|
|
class Msg(TypedDict):
    """A single chat message: who spoke (`role`) and what was said (`content`)."""

    role: RoleType
    content: str
|
|
|
|
|
class SessionMeta(TypedDict):
    """Per-session persona configuration for a CharacterGLM conversation.

    NOTE(review): only ``user_info`` is Optional; the other three fields are
    required strings. Empty/None persona descriptions are skipped when the
    prompt is built (truthiness check via ``.get``).
    """

    user_name: str
    bot_name: str
    bot_info: str
    user_info: Optional[str]
|
|
|
|
|
# A conversation: an ordered list of role/content messages.
HistoryType = List[Msg]
|
|
|
|
|
class CharacterGLMGenerationUtils:
    """Helpers for converting chat histories and building CharacterGLM text prompts."""

    @staticmethod
    def convert_chatglm_history_to_characterglm_history(user_query: str, history: "List[Tuple[str, str]]") -> "HistoryType":
        """Convert ChatGLM pair-style history to CharacterGLM message history.

        Args:
            user_query: The latest user message; appended as the final turn.
            history: Past rounds as ``(query, response)`` pairs. If the very
                first query is the empty string, that round is treated as a
                bot-initiated greeting and no user message is emitted for it.

        Returns:
            A list of ``{"role", "content"}`` messages ending with ``user_query``
            as a user turn.
        """
        characterglm_history: HistoryType = []
        for i, (query, response) in enumerate(history):
            # Skip the user turn only for an empty first query (bot spoke first).
            if i > 0 or query != '':
                characterglm_history.append({
                    "role": "user",
                    "content": query
                })
            characterglm_history.append({
                "role": "assistant",
                "content": response
            })

        characterglm_history.append({
            "role": "user",
            "content": user_query
        })
        return characterglm_history

    @staticmethod
    def build_inputs(session_meta: "SessionMeta", history: "HistoryType") -> str:
        """Build the flat text prompt for the local CharacterGLM model.

        Note: assumes the last message in ``history`` is the user's query.

        Args:
            session_meta: Names and persona descriptions for both parties.
                Empty/None ``bot_info``/``user_info`` lines are omitted.
            history: Role/content messages; must be non-empty and end with a
                user message.

        Returns:
            A newline-joined prompt ending with the bot-name tag, ready for
            the model to continue.

        Raises:
            AssertionError: If ``history`` is empty or does not end with a
                user message. (Kept as ``assert`` to preserve the original
                exception contract.)
        """
        texts = []
        texts.append(
            f"以下是一段{session_meta['bot_name']}和{session_meta['user_name']}之间的对话。")
        if session_meta.get("bot_info"):
            texts.append(f"关于{session_meta['bot_name']}的信息:{session_meta['bot_info']}")
        if session_meta.get("user_info"):
            texts.append(
                f"关于{session_meta['user_name']}的信息:{session_meta['user_info']}")

        assert history and history[-1]['role'] == 'user'
        for msg in history:
            name = session_meta['user_name'] if msg['role'] == 'user' else session_meta['bot_name']
            texts.append(f"[{name}]" + msg['content'].strip())

        # One line per segment: embedded newlines would break the line-per-turn format.
        texts = [text.replace('\n', ' ') for text in texts]
        texts.append(f"[{session_meta['bot_name']}]")
        return '\n'.join(texts)
|
|
|
|
|
class CharacterGLMAPI:
    """Thin wrapper around the zhipuai CharacterGLM HTTP API."""

    @staticmethod
    def build_api_arguments(session_meta: "SessionMeta", history: "HistoryType") -> dict:
        """Assemble the keyword arguments accepted by ``zhipuai.model_api``."""
        return {
            "model": "characterglm",
            "meta": session_meta,
            "prompt": history
        }

    @classmethod
    def async_invoke(cls, session_meta: "SessionMeta", history: "HistoryType"):
        """Submit a generation request asynchronously.

        Notes:
            1. Set ``zhipuai.api_key`` first.
            2. ``return_type='text'`` is passed so the result is plain text
               instead of a JSON string.
            3. Call ``zhipuai.model_api.query_async_invoke_result`` afterwards
               to fetch the generation result.

        Reference:
            https://open.bigmodel.cn/dev/api#characterglm
        """
        import zhipuai
        kwargs = cls.build_api_arguments(session_meta, history)
        return zhipuai.model_api.async_invoke(**kwargs, return_type='text')

    @classmethod
    def invoke(cls, session_meta: "SessionMeta", history: "HistoryType"):
        """Submit a generation request synchronously and return the raw result.

        Notes:
            1. Set ``zhipuai.api_key`` first.
            2. ``return_type='text'`` is passed so the result is plain text
               instead of a JSON string.

        Reference:
            https://open.bigmodel.cn/dev/api#characterglm
        """
        import zhipuai
        kwargs = cls.build_api_arguments(session_meta, history)
        return zhipuai.model_api.invoke(**kwargs, return_type='text')

    @classmethod
    def generate(cls, session_meta: "SessionMeta", history: "HistoryType") -> str:
        """Invoke the API synchronously and return the generated reply text.

        Raises:
            RuntimeError: If the API response reports failure; the full
                response dict is attached for debugging.
        """
        result = cls.invoke(session_meta, history)
        if not result['success']:
            raise RuntimeError(result)
        return result['data']['choices'][0]['content']

    @classmethod
    def stream_generate(cls, session_meta: "SessionMeta", history: "HistoryType") -> "Iterator[str]":
        """Pseudo-streaming: generate the full reply first, then yield it one
        character at a time. The underlying API call is NOT actually streamed."""
        return iter(cls.generate(session_meta, history))
|
|