qgyd2021 commited on
Commit
d33b446
1 Parent(s): 9a65fac
.gitattributes CHANGED
@@ -32,4 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.xz filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
32
  *.xz filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ .git/
4
+ .idea/
5
+
6
+ cache/
7
+ dotenv/
8
+ trained_models/
9
+ **/cache/
10
+ **/__pycache__/
11
+
12
+ **/*.jpg
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
- title: Dingtalk
3
- emoji: 🐢
4
- colorFrom: red
5
- colorTo: gray
6
  sdk: docker
7
  pinned: false
8
  license: apache-2.0
 
1
  ---
2
+ title: Ding Talk
3
+ emoji: 🔥
4
+ colorFrom: purple
5
+ colorTo: red
6
  sdk: docker
7
  pinned: false
8
  license: apache-2.0
dingtalk_develop.md ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## 钉钉机器人开发
2
+
3
+ ### 操作流程
4
+
5
+ #### 创建机器人
6
+
7
+ 参考文档:
8
+ ```text
9
+ https://open.dingtalk.com/document/orgapp/create-an-application
10
+ ```
11
+
12
+ 我创建的机器人是像一个用户一样的,你可以发信息给它,然后它会回复你。
13
+
14
+ 步骤大致如下:
15
+
16
+ 首先访问 [钉钉开放平台](https://open-dev.dingtalk.com/),并登录。
17
+
18
+ 点击上方菜单栏的 [应用开发](https://open-dev.dingtalk.com/fe/app#/corp/app),
19
+ ![app_develop.jpg](docs/pictures/app_develop.jpg)
20
+
21
+ 然后在钉钉应用中点击创建应用,输入 `应用名称`,`应用描述`,`应用图标` 后保存。
22
+ ![create_app.jpg](docs/pictures/create_app.jpg)
23
+ ![create_app_save.jpg](docs/pictures/create_app_save.jpg)
24
+
25
+ 点击进入创建的应用。
26
+ ![in_app.jpg](docs/pictures/in_app.jpg)
27
+
28
+ 在 `添加应用能力` 中 找到 `其它应用能力`中点击添加 `机器人`。
29
+ 对机器人配置完成后点发布。
30
+
31
+ 在凭证与基础信息中获取 Client ID, Client Secret
32
+ ![client_info.jpg](docs/pictures/client_info.jpg)
33
+
34
+
35
+ #### 创建服务后台
36
+
37
+ 参考文档:
38
+ ```text
39
+ https://open.dingtalk.com/document/orgapp/robot-receive-message
40
+ https://github.com/open-dingtalk/dingtalk-tutorial-python
41
+ ```
42
+
43
+
44
+ ### 使用效果
45
+
46
+ 聊天示例:
47
+ ![talk_example.jpg](docs/pictures/talk_example.jpg)
examples/echo_text_bot.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import logging
5
+
6
+ import dingtalk_stream
7
+ from dingtalk_stream import AckMessage
8
+
9
+ from project_settings import project_path
10
+ from project_settings import environment
11
+
12
+
13
+ def get_args():
14
+ parser = argparse.ArgumentParser()
15
+ parser.add_argument(
16
+ "--client_id",
17
+ default=environment.get("client_id"),
18
+ type=str,
19
+ )
20
+ parser.add_argument(
21
+ "--client_secret",
22
+ default=environment.get("client_secret"),
23
+ type=str,
24
+ )
25
+ args = parser.parse_args()
26
+ return args
27
+
28
+
29
+ class EchoTextHandler(dingtalk_stream.ChatbotHandler):
30
+ def __init__(self):
31
+ super(EchoTextHandler, self).__init__()
32
+
33
+ async def process(self, callback: dingtalk_stream.CallbackMessage):
34
+ incoming_message = dingtalk_stream.ChatbotMessage.from_dict(callback.data)
35
+ text = incoming_message.text.content.strip()
36
+ self.reply_text(text, incoming_message)
37
+ return AckMessage.STATUS_OK, "OK"
38
+
39
+
40
+ def main():
41
+ args = get_args()
42
+
43
+ credential = dingtalk_stream.Credential(
44
+ client_id=args.client_id,
45
+ client_secret=args.client_secret,
46
+ )
47
+ client = dingtalk_stream.DingTalkStreamClient(credential)
48
+
49
+ client.register_callback_handler(
50
+ dingtalk_stream.chatbot.ChatbotMessage.TOPIC,
51
+ EchoTextHandler()
52
+ )
53
+ client.start_forever()
54
+ return
55
+
56
+
57
+ if __name__ == '__main__':
58
+ main()
examples/nxlink_bot.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import logging
5
+
6
+ import dingtalk_stream
7
+ from dingtalk_stream import AckMessage
8
+
9
+ from project_settings import project_path
10
+ from project_settings import environment
11
+ from toolbox.agent_x.question_answer import AgentX
12
+
13
+
14
+ def get_args():
15
+ parser = argparse.ArgumentParser()
16
+ parser.add_argument(
17
+ "--client_id",
18
+ default=environment.get("client_id"),
19
+ type=str,
20
+ )
21
+ parser.add_argument(
22
+ "--client_secret",
23
+ default=environment.get("client_secret"),
24
+ type=str,
25
+ )
26
+ parser.add_argument(
27
+ "--agent_x_api_key",
28
+ default=environment.get("agent_x_api_key", default=None),
29
+ type=str
30
+ )
31
+ parser.add_argument(
32
+ "--agent_x_agent_name",
33
+ default=environment.get("agent_x_agent_name", default=None),
34
+ type=str
35
+ )
36
+ args = parser.parse_args()
37
+ return args
38
+
39
+
40
+ class EchoTextHandler(dingtalk_stream.ChatbotHandler):
41
+ def __init__(self):
42
+ super(EchoTextHandler, self).__init__()
43
+
44
+ async def process(self, callback: dingtalk_stream.CallbackMessage):
45
+ incoming_message = dingtalk_stream.ChatbotMessage.from_dict(callback.data)
46
+ text = incoming_message.text.content.strip()
47
+ self.reply_text(text, incoming_message)
48
+ return AckMessage.STATUS_OK, "OK"
49
+
50
+
51
+ class AgentXNXLinkHandler(dingtalk_stream.ChatbotHandler):
52
+ def __init__(self, agent: AgentX):
53
+ super(AgentXNXLinkHandler, self).__init__()
54
+ self.agent = agent
55
+
56
+ async def process(self, callback: dingtalk_stream.CallbackMessage):
57
+ incoming_message = dingtalk_stream.ChatbotMessage.from_dict(callback.data)
58
+ text = incoming_message.text.content.strip()
59
+
60
+ response = self.agent.question_answer(
61
+ question=text,
62
+ streaming=True
63
+ )
64
+ answer = response["answer"]
65
+ answer = answer.split("{\"message\":")[0]
66
+
67
+ self.reply_text(answer, incoming_message)
68
+ return AckMessage.STATUS_OK, "OK"
69
+
70
+
71
+ def main():
72
+ args = get_args()
73
+
74
+ agent = AgentX(
75
+ api_key=args.agent_x_api_key,
76
+ agent_name=args.agent_x_agent_name,
77
+ )
78
+
79
+ agent_x_nxlink_handler = AgentXNXLinkHandler(agent=agent)
80
+
81
+ client = dingtalk_stream.DingTalkStreamClient(
82
+ credential=dingtalk_stream.Credential(
83
+ client_id=args.client_id,
84
+ client_secret=args.client_secret,
85
+ )
86
+ )
87
+
88
+ client.register_callback_handler(
89
+ dingtalk_stream.chatbot.ChatbotMessage.TOPIC,
90
+ agent_x_nxlink_handler
91
+ )
92
+ client.start_forever()
93
+ return
94
+
95
+
96
+ if __name__ == '__main__':
97
+ main()
main.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import logging
5
+ import os
6
+ import platform
7
+ import re
8
+ import string
9
+ from typing import List, Tuple
10
+
11
+ from project_settings import project_path
12
+
13
+ os.environ["HUGGINGFACE_HUB_CACHE"] = (project_path / "cache/huggingface/hub").as_posix()
14
+
15
+ logging.basicConfig(
16
+ level=logging.INFO if platform.system() == "Windows" else logging.INFO,
17
+ format="%(asctime)s %(levelname)s %(message)s",
18
+ datefmt="%Y-%m-%d %H:%M:%S",
19
+ )
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+ import dingtalk_stream
24
+ from dingtalk_stream import AckMessage
25
+ import gradio as gr
26
+ from threading import Thread
27
+ import torch
28
+ from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
29
+ from transformers.models.bert.tokenization_bert import BertTokenizer
30
+
31
+ from project_settings import environment
32
+
33
+
34
+ def get_args():
35
+ parser = argparse.ArgumentParser()
36
+ parser.add_argument(
37
+ "--client_id",
38
+ default=environment.get("client_id"),
39
+ type=str,
40
+ )
41
+ parser.add_argument(
42
+ "--client_secret",
43
+ default=environment.get("client_secret"),
44
+ type=str,
45
+ )
46
+ parser.add_argument(
47
+ "--model_name",
48
+ default=(project_path / "trained_models/lib_service_4chan").as_posix() if platform.system() == "Windows" else "qgyd2021/lip_service_4chan",
49
+ type=str,
50
+ )
51
+ parser.add_argument(
52
+ "--dingtalk_develop_md_file",
53
+ default="dingtalk_develop.md",
54
+ type=str,
55
+ )
56
+ args = parser.parse_args()
57
+ return args
58
+
59
+
60
+ class LipService4ChanHandler(dingtalk_stream.ChatbotHandler):
61
+ def __init__(self,
62
+ model_name: str = "qgyd2021/lip_service_4chan",
63
+ max_input_len: int = 512,
64
+ max_new_tokens: int = 512,
65
+ top_p: float = 0.9,
66
+ temperature: float = 0.35,
67
+ repetition_penalty: float = 1.0,
68
+ device: str = "cuda" if torch.cuda.is_available() else "cpu",
69
+ ):
70
+ super(LipService4ChanHandler, self).__init__()
71
+ self.model_name = model_name
72
+ self.max_input_len = max_input_len
73
+ self.max_new_tokens = max_new_tokens
74
+ self.top_p = top_p
75
+ self.temperature = temperature
76
+ self.repetition_penalty = repetition_penalty
77
+ self.device = device
78
+
79
+ tokenizer = BertTokenizer.from_pretrained(model_name)
80
+ model = GPT2LMHeadModel.from_pretrained(model_name)
81
+ model = model.eval()
82
+ self.model = model
83
+ self.tokenizer = tokenizer
84
+
85
+ async def process(self, callback: dingtalk_stream.CallbackMessage):
86
+ incoming_message = dingtalk_stream.ChatbotMessage.from_dict(callback.data)
87
+ text = incoming_message.text.content.strip()
88
+
89
+ answer = self.get_answer(text)
90
+ self.reply_text(answer, incoming_message)
91
+
92
+ logger.info("incoming message: {}; reply text: {}".format(text, answer))
93
+
94
+ return AckMessage.STATUS_OK, "OK"
95
+
96
+ @staticmethod
97
+ def remove_space_between_cn_en(text: str):
98
+ splits = re.split(" ", text)
99
+ if len(splits) < 2:
100
+ return text
101
+
102
+ result = ""
103
+ for t in splits:
104
+ if t == "":
105
+ continue
106
+ if re.search(f"[a-zA-Z0-9{string.punctuation}]$", result) and re.search("^[a-zA-Z0-9]", t):
107
+ result += " "
108
+ result += t
109
+ else:
110
+ if not result == "":
111
+ result += t
112
+ else:
113
+ result = t
114
+
115
+ if text.endswith(" "):
116
+ result += " "
117
+ return result
118
+
119
+ def get_answer(self, text: str):
120
+ prompt_encoded = self.tokenizer.__call__(text, add_special_tokens=True)
121
+ input_ids: List[int] = prompt_encoded["input_ids"]
122
+ input_ids = torch.tensor([input_ids], dtype=torch.long)
123
+ input_ids = input_ids[:, -self.max_input_len:]
124
+
125
+ self.tokenizer.eos_token = self.tokenizer.sep_token
126
+ self.tokenizer.eos_token_id = self.tokenizer.sep_token_id
127
+
128
+ # generate
129
+ with torch.no_grad():
130
+ outputs = self.model.generate(
131
+ input_ids=input_ids,
132
+ max_new_tokens=self.max_new_tokens,
133
+ do_sample=True,
134
+ top_p=self.top_p,
135
+ temperature=self.temperature,
136
+ repetition_penalty=self.repetition_penalty,
137
+ eos_token_id=self.tokenizer.sep_token_id,
138
+ pad_token_id=self.tokenizer.pad_token_id,
139
+ )
140
+ outputs = outputs.tolist()[0][len(input_ids[0]):]
141
+ answer = self.tokenizer.decode(outputs)
142
+ answer = answer.strip().replace(self.tokenizer.eos_token, "").strip()
143
+ answer = self.remove_space_between_cn_en(answer)
144
+
145
+ return answer
146
+
147
+
148
+ def dingtalk_server(client: dingtalk_stream.DingTalkStreamClient):
149
+ client.start_forever()
150
+
151
+
152
+ def main():
153
+ args = get_args()
154
+
155
+ # ding talk
156
+ credential = dingtalk_stream.Credential(
157
+ client_id=args.client_id,
158
+ client_secret=args.client_secret,
159
+ )
160
+ client = dingtalk_stream.DingTalkStreamClient(credential, logger)
161
+
162
+ client.register_callback_handler(
163
+ dingtalk_stream.chatbot.ChatbotMessage.TOPIC,
164
+ LipService4ChanHandler(
165
+ model_name=args.model_name
166
+ )
167
+ )
168
+ # client.start_forever()
169
+
170
+ # background task
171
+ thread = Thread(target=dingtalk_server, kwargs={"client": client})
172
+ thread.start()
173
+
174
+ with open(args.dingtalk_develop_md_file, "r", encoding="utf-8") as f:
175
+ dingtalk_develop_md = f.read()
176
+
177
+ # ui
178
+ with gr.Blocks() as blocks:
179
+ gr.Markdown(value=dingtalk_develop_md)
180
+
181
+ blocks.queue().launch(
182
+ share=False if platform.system() == "Windows" else False,
183
+ server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
184
+ server_port=7860
185
+ )
186
+
187
+ return
188
+
189
+
190
+ if __name__ == '__main__':
191
+ main()
project_settings.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import os
4
+ from pathlib import Path
5
+
6
+ from toolbox.os.environment import EnvironmentManager
7
+
8
+
9
+ project_path = os.path.abspath(os.path.dirname(__file__))
10
+ project_path = Path(project_path)
11
+
12
+
13
+ environment = EnvironmentManager(
14
+ path=os.path.join(project_path, "dotenv"),
15
+ env=os.environ.get("environment", "miyuki"),
16
+ )
17
+
18
+
19
+ if __name__ == '__main__':
20
+ pass
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ dingtalk-stream==0.17.2
2
+ python-dotenv==1.0.0
3
+ gradio==4.12.0
4
+ transformers==4.39.1
5
+ torch==2.2.1
toolbox/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == '__main__':
6
+ pass
toolbox/agent_x/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == '__main__':
6
+ pass
toolbox/agent_x/question_answer.py ADDED
@@ -0,0 +1,362 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import json
5
+ import os
6
+ import random
7
+ import sys
8
+ import time
9
+ from typing import List
10
+
11
+ pwd = os.path.abspath(os.path.dirname(__file__))
12
+ sys.path.append(os.path.join(pwd, '../../'))
13
+
14
+ import requests
15
+
16
+ from project_settings import environment
17
+
18
+
19
+ def get_args():
20
+ parser = argparse.ArgumentParser()
21
+ parser.add_argument(
22
+ "--api_key",
23
+ default=environment.get("agent_x_api_key", default=None),
24
+ type=str
25
+ )
26
+
27
+ args = parser.parse_args()
28
+ return args
29
+
30
+
31
+ class AgentX(object):
32
+ def __init__(self,
33
+ api_key: str,
34
+ agent_name: str = "NXLink智能助手",
35
+ url_host: str = "https://api.agentx.so"
36
+ ):
37
+ self.api_key = api_key
38
+ self.agent_name = agent_name
39
+ self.url_host = url_host
40
+
41
+ self.agent_id = self.get_agent_id()
42
+
43
+ def __str__(self):
44
+ result = "<{}; agent_name: {}; agent_id: {}; api_key: {}>".format(
45
+ self.__class__.__name__, self.agent_name, self.agent_id, self.api_key)
46
+ return result
47
+
48
+ def get_agent_id(self):
49
+ url = "{}/api/v1/access/agents".format(self.url_host)
50
+
51
+ headers = {
52
+ "accept": "*/*",
53
+ "x-api-key": self.api_key
54
+ }
55
+ resp = requests.request(
56
+ "GET",
57
+ url=url,
58
+ headers=headers,
59
+ )
60
+ if resp.status_code != 200:
61
+ print(resp.status_code)
62
+ print(resp.text)
63
+ exit(0)
64
+ js = resp.json()
65
+
66
+ result = None
67
+ for e in js:
68
+ if e["name"] == self.agent_name:
69
+ result = e["_id"]
70
+
71
+ if result is None:
72
+ raise AssertionError("agent not found")
73
+ return result
74
+
75
+ def get_agent_config(self):
76
+ url = "{}/api/v1/access/agents/{}".format(self.url_host, self.agent_id)
77
+
78
+ headers = {
79
+ "accept": "*/*",
80
+ "x-api-key": self.api_key
81
+ }
82
+ resp = requests.request(
83
+ "GET",
84
+ url=url,
85
+ headers=headers,
86
+ )
87
+ js = resp.json()
88
+ return js
89
+
90
+ def get_conversation_list(self):
91
+ url = "{}/api/v1/access/agents/{}/conversations".format(self.url_host, self.agent_id)
92
+
93
+ headers = {
94
+ "accept": "*/*",
95
+ "x-api-key": self.api_key
96
+ }
97
+ resp = requests.request(
98
+ "GET",
99
+ url=url,
100
+ headers=headers,
101
+ )
102
+ js = resp.json()
103
+ return js
104
+
105
+ def post_message(self, message: str, conversation_id: str, context: int = 0):
106
+ url = "{}/api/v1/access/conversations/{}/message".format(self.url_host, conversation_id)
107
+
108
+ headers = {
109
+ "accept": "*/*",
110
+ "Content-type": "application/json",
111
+ "x-api-key": self.api_key
112
+ }
113
+ data = {
114
+ "message": message,
115
+ "context": context,
116
+ }
117
+ resp = requests.request(
118
+ "POST",
119
+ url=url,
120
+ headers=headers,
121
+ data=json.dumps(data)
122
+ )
123
+ if resp.status_code != 200:
124
+ print(resp.status_code)
125
+ print(resp.text)
126
+ exit(0)
127
+ js = resp.json()
128
+ return js
129
+
130
+ def post_message_by_sse(self, message: str, conversation_id: str, context: int = 0):
131
+ url = "{}/api/v1/access/conversations/{}/messagesse".format(self.url_host, conversation_id)
132
+
133
+ headers = {
134
+ "accept": "*/*",
135
+ "Content-type": "application/json",
136
+ "x-api-key": self.api_key
137
+ }
138
+ data = {
139
+ "message": message,
140
+ "context": context,
141
+ }
142
+ resp = requests.request(
143
+ "POST",
144
+ url=url,
145
+ headers=headers,
146
+ data=json.dumps(data),
147
+ stream=True
148
+ )
149
+ # print(resp.headers)
150
+
151
+ trace_id = resp.headers["x-trace-id"]
152
+
153
+ if resp.status_code == 200:
154
+ def generator():
155
+ result = ""
156
+ buf = b""
157
+
158
+ for chunk in resp.iter_content():
159
+ buf += chunk
160
+ try:
161
+ chunk = buf.decode("utf-8")
162
+ except UnicodeDecodeError:
163
+ continue
164
+ result += chunk
165
+ buf = b""
166
+
167
+ yield chunk
168
+ return generator(), trace_id
169
+
170
+ else:
171
+ print(resp.status_code)
172
+ print(resp.headers["Content-Type"])
173
+ raise AssertionError
174
+
175
+ def get_trace_by_message_id(self, message_id: str):
176
+ url = "{}/api/v1/access/messages/{}/trace".format(self.url_host, message_id)
177
+
178
+ headers = {
179
+ "accept": "*/*",
180
+ "x-api-key": self.api_key
181
+ }
182
+ resp = requests.request(
183
+ "GET",
184
+ url=url,
185
+ headers=headers,
186
+ )
187
+
188
+ js = resp.json()
189
+ return js
190
+
191
+ def get_trace_by_trace_id(self, trace_id: str):
192
+ url = "{}/api/v1/access/traces/{}".format(self.url_host, trace_id)
193
+
194
+ headers = {
195
+ "accept": "*/*",
196
+ "x-api-key": self.api_key
197
+ }
198
+ resp = requests.request(
199
+ "GET",
200
+ url=url,
201
+ headers=headers,
202
+ )
203
+
204
+ js = resp.json()
205
+ return js
206
+
207
+ def post_new_conversation_id(self):
208
+ url = "{}/api/v1/access/agents/{}/conversations/new".format(self.url_host, self.agent_id)
209
+
210
+ headers = {
211
+ "accept": "*/*",
212
+ "x-api-key": self.api_key
213
+ }
214
+ resp = requests.request(
215
+ "POST",
216
+ url=url,
217
+ headers=headers,
218
+ )
219
+ js = resp.json()
220
+
221
+ conversation_id = js["_id"]
222
+ return conversation_id
223
+
224
+ def delete_conversation(self, conversation_id: str):
225
+ url = "{}/api/v1/access/conversations/{}".format(self.url_host, conversation_id)
226
+
227
+ headers = {
228
+ "accept": "*/*",
229
+ "Content-type": "application/json",
230
+ "x-api-key": self.api_key
231
+ }
232
+ resp = requests.request(
233
+ "DELETE",
234
+ url=url,
235
+ headers=headers,
236
+ )
237
+ js = resp.json()
238
+ return js
239
+
240
+ def update_context(self, messages: List[dict], conversation_id: str):
241
+ url = "{}/api/v1/access/conversations/{}/update-context".format(self.url_host, conversation_id)
242
+
243
+ headers = {
244
+ "accept": "*/*",
245
+ "Content-type": "application/json",
246
+ "x-api-key": self.api_key
247
+ }
248
+ data = {
249
+ "messages": messages,
250
+ }
251
+ resp = requests.request(
252
+ "PUT",
253
+ url=url,
254
+ headers=headers,
255
+ data=json.dumps(data),
256
+ )
257
+ js = resp.json()
258
+ return js
259
+
260
+ def question_answer(self, question: str, conversation_id: str = None, context: List[dict] = None, streaming: bool = False):
261
+ if conversation_id is None:
262
+ conversation_id = self.post_new_conversation_id()
263
+
264
+ if context is not None:
265
+ self.update_context(context, conversation_id)
266
+
267
+ result = {
268
+ "answer": None,
269
+ "reference": None
270
+ }
271
+ try:
272
+ if streaming:
273
+ resp_stream, trace_id = self.post_message_by_sse(question, conversation_id,
274
+ context=0 if context is None else 1)
275
+ answer = ""
276
+ for chunk in resp_stream:
277
+ print(chunk, end="")
278
+ answer += chunk
279
+ print("\n")
280
+ result["answer"] = answer
281
+ # print(answer)
282
+ # exit(0)
283
+
284
+ # [{"title": "", "source": ""}, ...]
285
+ trace = self.get_trace_by_trace_id(trace_id)
286
+
287
+ if trace == "No trace":
288
+ reference = "No trace"
289
+ else:
290
+ reference = list()
291
+ for t in trace:
292
+ reference.append((t["title"], t["source"]))
293
+
294
+ result["reference"] = reference
295
+
296
+ else:
297
+ js = self.post_message(question, conversation_id,
298
+ context=0 if context is None else 1)
299
+ answer = js["text"]
300
+ result["answer"] = answer
301
+
302
+ message_id = js["_id"]
303
+ trace = self.get_trace_by_message_id(message_id)
304
+ # print(trace)
305
+
306
+ if trace == "No trace":
307
+ reference = "No trace"
308
+ else:
309
+ reference = list()
310
+ for t in trace:
311
+ reference.append((t["title"], t["source"]))
312
+
313
+ result["reference"] = reference
314
+
315
+ finally:
316
+ self.delete_conversation(conversation_id)
317
+
318
+ return result
319
+
320
+
321
+ def main():
322
+ args = get_args()
323
+
324
+ agent = AgentX(
325
+ api_key=args.api_key,
326
+ agent_name="Yutong Bus",
327
+ )
328
+ print(agent)
329
+
330
+ context = [
331
+ {
332
+ "user": "你好"
333
+ },
334
+ {
335
+ "assistant": "你好,我们是宇通客车公司,有什么可以帮到您的吗?"
336
+ },
337
+ {
338
+ "user": "需要一辆55座客车。"
339
+ },
340
+ {
341
+ "assistant": "Which country will the bus be used in?"
342
+ },
343
+ {
344
+ "user": "你可以说中文吗。"
345
+ },
346
+ {
347
+ "assistant": "可以的,请问您需要在哪个国家使用客车?"
348
+ },
349
+ ]
350
+
351
+ question = "你好"
352
+ time_begin = time.time()
353
+ response = agent.question_answer(question, context=context, streaming=True)
354
+ time_cost = time.time() - time_begin
355
+
356
+ print(response)
357
+ print("time cost: {}".format(time_cost))
358
+ return
359
+
360
+
361
+ if __name__ == '__main__':
362
+ main()
toolbox/json/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == '__main__':
6
+ pass
toolbox/json/misc.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ from typing import Callable
4
+
5
+
6
+ def traverse(js, callback: Callable, *args, **kwargs):
7
+ if isinstance(js, list):
8
+ result = list()
9
+ for l in js:
10
+ l = traverse(l, callback, *args, **kwargs)
11
+ result.append(l)
12
+ return result
13
+ elif isinstance(js, tuple):
14
+ result = list()
15
+ for l in js:
16
+ l = traverse(l, callback, *args, **kwargs)
17
+ result.append(l)
18
+ return tuple(result)
19
+ elif isinstance(js, dict):
20
+ result = dict()
21
+ for k, v in js.items():
22
+ k = traverse(k, callback, *args, **kwargs)
23
+ v = traverse(v, callback, *args, **kwargs)
24
+ result[k] = v
25
+ return result
26
+ elif isinstance(js, int):
27
+ return callback(js, *args, **kwargs)
28
+ elif isinstance(js, str):
29
+ return callback(js, *args, **kwargs)
30
+ else:
31
+ return js
32
+
33
+
34
+ def demo1():
35
+ d = {
36
+ "env": "ppe",
37
+ "mysql_connect": {
38
+ "host": "$mysql_connect_host",
39
+ "port": 3306,
40
+ "user": "callbot",
41
+ "password": "NxcloudAI2021!",
42
+ "database": "callbot_ppe",
43
+ "charset": "utf8"
44
+ },
45
+ "es_connect": {
46
+ "hosts": ["10.20.251.8"],
47
+ "http_auth": ["elastic", "ElasticAI2021!"],
48
+ "port": 9200
49
+ }
50
+ }
51
+
52
+ def callback(s):
53
+ if isinstance(s, str) and s.startswith('$'):
54
+ return s[1:]
55
+ return s
56
+
57
+ result = traverse(d, callback=callback)
58
+ print(result)
59
+ return
60
+
61
+
62
+ if __name__ == '__main__':
63
+ demo1()
toolbox/os/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == '__main__':
6
+ pass
toolbox/os/environment.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import json
4
+ import os
5
+
6
+ from dotenv import load_dotenv
7
+ from dotenv.main import DotEnv
8
+
9
+ from toolbox.json.misc import traverse
10
+
11
+
12
+ class EnvironmentManager(object):
13
+ def __init__(self, path, env, override=False):
14
+ filename = os.path.join(path, '{}.env'.format(env))
15
+ self.filename = filename
16
+
17
+ load_dotenv(
18
+ dotenv_path=filename,
19
+ override=override
20
+ )
21
+
22
+ self._environ = dict()
23
+
24
+ def open_dotenv(self, filename: str = None):
25
+ filename = filename or self.filename
26
+ dotenv = DotEnv(
27
+ dotenv_path=filename,
28
+ stream=None,
29
+ verbose=False,
30
+ interpolate=False,
31
+ override=False,
32
+ encoding="utf-8",
33
+ )
34
+ result = dotenv.dict()
35
+ return result
36
+
37
+ def get(self, key, default=None, dtype=str):
38
+ result = os.environ.get(key)
39
+ if result is None:
40
+ if default is None:
41
+ result = None
42
+ else:
43
+ result = default
44
+ else:
45
+ result = dtype(result)
46
+ self._environ[key] = result
47
+ return result
48
+
49
+
50
+ _DEFAULT_DTYPE_MAP = {
51
+ 'int': int,
52
+ 'float': float,
53
+ 'str': str,
54
+ 'json.loads': json.loads
55
+ }
56
+
57
+
58
+ class JsonConfig(object):
59
+ """
60
+ 将 json 中, 形如 `$float:threshold` 的值, 处理为:
61
+ 从环境变量中查到 threshold, 再将其转换为 float 类型.
62
+ """
63
+ def __init__(self, dtype_map: dict = None, environment: EnvironmentManager = None):
64
+ self.dtype_map = dtype_map or _DEFAULT_DTYPE_MAP
65
+ self.environment = environment or os.environ
66
+
67
+ def sanitize_by_filename(self, filename: str):
68
+ with open(filename, 'r', encoding='utf-8') as f:
69
+ js = json.load(f)
70
+
71
+ return self.sanitize_by_json(js)
72
+
73
+ def sanitize_by_json(self, js):
74
+ js = traverse(
75
+ js,
76
+ callback=self.sanitize,
77
+ environment=self.environment
78
+ )
79
+ return js
80
+
81
+ def sanitize(self, string, environment):
82
+ """支持 $ 符开始的, 环境变量配置"""
83
+ if isinstance(string, str) and string.startswith('$'):
84
+ dtype, key = string[1:].split(':')
85
+ dtype = self.dtype_map[dtype]
86
+
87
+ value = environment.get(key)
88
+ if value is None:
89
+ raise AssertionError('environment not exist. key: {}'.format(key))
90
+
91
+ value = dtype(value)
92
+ result = value
93
+ else:
94
+ result = string
95
+ return result
96
+
97
+
98
+ def demo1():
99
+ import json
100
+
101
+ from project_settings import project_path
102
+
103
+ environment = EnvironmentManager(
104
+ path=os.path.join(project_path, 'server/callbot_server/dotenv'),
105
+ env='dev',
106
+ )
107
+ init_scenes = environment.get(key='init_scenes', dtype=json.loads)
108
+ print(init_scenes)
109
+ print(environment._environ)
110
+ return
111
+
112
+
113
+ if __name__ == '__main__':
114
+ demo1()
toolbox/os/other.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import inspect
3
+
4
+
5
+ def pwd():
6
+ """你在哪个文件调用此函数, 它就会返回那个文件所在的 dir 目标"""
7
+ frame = inspect.stack()[1]
8
+ module = inspect.getmodule(frame[0])
9
+ return os.path.dirname(os.path.abspath(module.__file__))