Spaces:
Sleeping
Sleeping
first
Browse files. This view is limited to 50 files because it contains too many changes.
See raw diff
- app.py +187 -0
- cfg_tool_template.json +45 -0
- my_modelscope_agent/__init__.py +0 -0
- my_modelscope_agent/agent.py +408 -0
- my_modelscope_agent/agent_types.py +20 -0
- my_modelscope_agent/llm/__init__.py +2 -0
- my_modelscope_agent/llm/base.py +64 -0
- my_modelscope_agent/llm/custom_llm.py +97 -0
- my_modelscope_agent/llm/dashscope_llm.py +125 -0
- my_modelscope_agent/llm/llm_factory.py +28 -0
- my_modelscope_agent/llm/modelscope_llm.py +132 -0
- my_modelscope_agent/llm/openai.py +71 -0
- my_modelscope_agent/llm/utils.py +39 -0
- my_modelscope_agent/output_parser.py +181 -0
- my_modelscope_agent/output_wrapper.py +219 -0
- my_modelscope_agent/prompt/__init__.py +6 -0
- my_modelscope_agent/prompt/chatglm3_prompt.py +41 -0
- my_modelscope_agent/prompt/messages_prompt.py +93 -0
- my_modelscope_agent/prompt/mrkl_prompt.py +118 -0
- my_modelscope_agent/prompt/ms_prompt.py +34 -0
- my_modelscope_agent/prompt/prompt.py +232 -0
- my_modelscope_agent/prompt/prompt_factory.py +16 -0
- my_modelscope_agent/prompt/raw_prompt_builder.py +34 -0
- my_modelscope_agent/retrieve.py +115 -0
- my_modelscope_agent/tools/__init__.py +36 -0
- my_modelscope_agent/tools/amap_weather.py +64 -0
- my_modelscope_agent/tools/code_interperter.py +125 -0
- my_modelscope_agent/tools/code_interpreter_jupyter.py +319 -0
- my_modelscope_agent/tools/code_interpreter_utils/__init__.py +5 -0
- my_modelscope_agent/tools/code_interpreter_utils/base_code_interpreter.py +13 -0
- my_modelscope_agent/tools/code_interpreter_utils/code_interpreter_init_kernel.py +50 -0
- my_modelscope_agent/tools/code_interpreter_utils/create_code_interpreter.py +12 -0
- my_modelscope_agent/tools/code_interpreter_utils/language_map.py +19 -0
- my_modelscope_agent/tools/code_interpreter_utils/languages/__init__.py +0 -0
- my_modelscope_agent/tools/code_interpreter_utils/languages/applescript.py +67 -0
- my_modelscope_agent/tools/code_interpreter_utils/languages/html.py +26 -0
- my_modelscope_agent/tools/code_interpreter_utils/languages/javascript.py +66 -0
- my_modelscope_agent/tools/code_interpreter_utils/languages/powershell.py +75 -0
- my_modelscope_agent/tools/code_interpreter_utils/languages/python.py +161 -0
- my_modelscope_agent/tools/code_interpreter_utils/languages/r.py +71 -0
- my_modelscope_agent/tools/code_interpreter_utils/languages/shell.py +89 -0
- my_modelscope_agent/tools/code_interpreter_utils/subprocess_code_interpreter.py +152 -0
- my_modelscope_agent/tools/code_interpreter_utils/truncate_output.py +15 -0
- my_modelscope_agent/tools/hf_tool.py +22 -0
- my_modelscope_agent/tools/image_chat_tool.py +51 -0
- my_modelscope_agent/tools/openapi_plugin.py +370 -0
- my_modelscope_agent/tools/pipeline_tool.py +40 -0
- my_modelscope_agent/tools/plugin_tool.py +30 -0
- my_modelscope_agent/tools/text_address_tool.py +20 -0
- my_modelscope_agent/tools/text_ie_tool.py +32 -0
app.py
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Gradio demo entry point: wires a ModelScope agent to a chat UI in which the
# human plays the role of the LLM (writes the "model output" by hand).
import os
import re

# Must be set before libraries bundling their own OpenMP runtime are imported,
# otherwise loading duplicate libiomp copies can abort the process.
os.environ['KMP_DUPLICATE_LIB_OK'] = "TRUE"

from modelscope.utils.config import Config

from my_modelscope_agent.agent import AgentExecutor

from my_modelscope_agent.output_parser import MsOutputParser

from my_modelscope_agent.prompt import MSPromptGenerator
import gradio as gr

# Example "LLM answer" pre-filled in the input box: a tool invocation in the
# ModelScope <|startofthink|> ... <|endofthink|> format.
default_text = """收到!
<|startofthink|>
{
"api_name": "modelscope_text-address",
"parameters": {
"input": "浙江杭州市江干区九堡镇三村村一区"}
}
<|endofthink|>"""

# System prompt prepended to every conversation by MSPromptGenerator.
MS_DEFAULT_SYSTEM_TEMPLATE = """<|system|>:你是Datawhale与ModelScope联合培养的人工大模型,拥有超百万亿的参数(神经突触),遥遥领先于GPT-4,你有多种能力,可以通过插件集成魔搭社区的模型api来回复用户的问题,还能解答用户使用模型遇到的问题和模型知识相关问答。
"""

# NOTE: a stray, side-effect-free dict literal (a duplicate of the payload in
# default_text) used to sit here at module level; it was dead code and has
# been removed.

# Markdown banner rendered at the top of the demo page.
hello_info = """# 关于我转生变成Agent这档事
众所周知,换位思考可以增进人与人之间的理解。假如我们能换一个角度,站在LLM的角度上去处理用户提出的各种问题,会碰撞出什么样的火花呢?
"""
class my_llm:
    """Minimal stand-in for a real LLM backend.

    Only ``set_agent_type`` is implemented because this demo drives the
    conversation manually (the human supplies the "LLM" output via the UI),
    so the agent's ``generate``/``stream_generate`` paths are never used.
    """

    def set_agent_type(self, agent_type):
        # Called by AgentExecutor.__init__; just record the type.
        self.agent_type = agent_type
41 |
+
|
42 |
+
|
def generate_history(txt):
    """Convert a raw prompt transcript into Gradio chatbot history pairs.

    The transcript interleaves role markers (``<|system|>:``, ``<|user|>:``,
    ``<|assistant|>:`` ...) with their contents. The first (system) turn is
    dropped, and the remaining turns are paired off as
    ``[user_text, assistant_text]`` rows; a trailing marker with no content
    yet contributes an empty string.

    Args:
        txt: the prompt transcript; an empty string or empty list yields [].

    Returns:
        list[list[str]]: chat history in the format gr.Chatbot expects.
    """
    # Raw string: '\|' is an invalid escape sequence in a plain string
    # literal (SyntaxWarning on Python 3.12+). Compile once, use twice.
    marker = re.compile(r'<\|.*?\|>:')

    # Generalizes the original `txt == []` check: '' produced [] anyway.
    if not txt:
        return []

    # Split out turn contents and the role markers; drop the system turn.
    contents = [piece for piece in marker.split(txt) if piece != ''][1:]
    markers = marker.findall(txt)[1:]

    # A trailing marker with no content means the last turn is still empty.
    if len(contents) + 1 == len(markers):
        contents.append('')

    # Pair consecutive turns: even indexes are the "user" side, odd indexes
    # the "assistant" side. zip() silently drops an unmatched trailing
    # element, matching the original slicing-based implementation.
    return [[user, assistant]
            for user, assistant in zip(contents[::2], contents[1::2])]
# Module-level singletons shared by every Gradio session.
llm = my_llm()
# Tool endpoint configuration; path is resolved relative to the working
# directory, so the app must be launched from the repository root.
tool_cfg = Config.from_file(r'cfg_tool_template.json')
def agent_remake(state_llm, history, agent):
    """Reset the demo to a clean slate.

    Empties the per-session LLM state and chat history in place, resets the
    agent's prompt history/state, and returns the cleared widget values
    (prompt box, history state, chatbot, state dict).
    """
    for container in (state_llm, history):
        container.clear()
    agent.reset()
    return '', history, history, state_llm
def agent_init(init_cmd, state_llm, history, agent, enable_list):
    """Start a fresh agent run for *init_cmd* and seed the session state.

    Returns the first prompt text, the rebuilt chat history (twice: once for
    the gr.State and once for the Chatbot widget), and the state dict.
    """
    agent.set_available_tools(enable_list)

    (tool_list, knowledge_list, function_list, llm_result, exec_result,
     idx, final_res, remote, print_info) = agent.custom_run_init(
        init_cmd, remote=True)
    llm_artifacts, idx = agent.custom_gene_prompt(llm_result, exec_result, idx)

    # Persist everything the next deal_LLM call needs to resume the round.
    state_llm.update({
        'tool_list': tool_list,
        'knowledge_list': knowledge_list,
        'function_list': function_list,
        'exec_result': exec_result,
        'idx': idx,
        'final_res': final_res,
        'remote': remote,
        'print_info': print_info,
        'llm_artifacts': llm_artifacts,
        'is_end': False,
    })

    rebuilt = generate_history(llm_artifacts)

    return llm_artifacts, rebuilt, rebuilt, state_llm
106 |
+
def deal_LLM(input_data, history, state_llm, agent, enable_list):
|
107 |
+
agent.set_available_tools(enable_list)
|
108 |
+
|
109 |
+
llm_artifacts = state_llm['llm_artifacts']
|
110 |
+
llm_result = input_data
|
111 |
+
idx = state_llm['idx']
|
112 |
+
final_res = state_llm['final_res']
|
113 |
+
remote = state_llm['remote']
|
114 |
+
print_info = state_llm['print_info']
|
115 |
+
|
116 |
+
history = generate_history(llm_artifacts)
|
117 |
+
|
118 |
+
result = agent.custom_parse_llm(llm_artifacts, llm_result, idx, final_res, remote, print_info)[0]
|
119 |
+
if 'end_res' in result:
|
120 |
+
state_llm['is_end'] = True
|
121 |
+
state_llm['final_res'] = result['end_res']
|
122 |
+
history[-1][1] += '\n' + llm_result
|
123 |
+
|
124 |
+
return '', history, history, state_llm
|
125 |
+
|
126 |
+
elif 'exec_result' in result:
|
127 |
+
llm_artifacts, idx = agent.custom_gene_prompt(llm_result, result['exec_result'], idx)
|
128 |
+
state_llm['llm_artifacts'] = llm_artifacts
|
129 |
+
state_llm['idx'] = idx
|
130 |
+
history = generate_history(llm_artifacts)
|
131 |
+
return llm_artifacts, history, history, state_llm
|
132 |
+
|
133 |
+
elif 'no_stop' in result:
|
134 |
+
state_llm['llm_result'] = result['no_stop']['llm_result']
|
135 |
+
state_llm['exec_result'] = result['no_stop']['exec_result']
|
136 |
+
state_llm['idx'] = result['no_stop']['idx']
|
137 |
+
state_llm['final_res'] = result['no_stop']['final_res']
|
138 |
+
|
139 |
+
llm_artifacts, idx = agent.custom_gene_prompt(state_llm['llm_result'], state_llm['exec_result'],
|
140 |
+
state_llm['idx'])
|
141 |
+
history = generate_history(llm_artifacts)
|
142 |
+
state_llm['llm_artifacts'] = llm_artifacts
|
143 |
+
state_llm['idx'] = idx
|
144 |
+
return llm_artifacts, history, history, state_llm
|
145 |
+
else:
|
146 |
+
raise ValueError('Unknown result type')
|
147 |
+
|
# UI layout and event wiring. NOTE(review): original indentation was lost in
# extraction; widget nesting below follows the source line order — confirm
# against a rendered copy of the app.
with gr.Blocks() as demo:
    gr.Markdown(hello_info)
    # Build the agent once at startup; tool retrieval is disabled so the
    # CheckboxGroup below fully controls which tools are available.
    prompt_generator = MSPromptGenerator(system_template=MS_DEFAULT_SYSTEM_TEMPLATE)
    output_parser = MsOutputParser()
    agent = gr.State(AgentExecutor(llm, tool_cfg=tool_cfg, tool_retrieval=False,
                                   prompt_generator=prompt_generator, output_parser=output_parser))

    with gr.Row():
        query_box = gr.TextArea(label="给Agent的指令",
                                value='使用地址识别模型,从下面的地址中找到省市区等元素,地址:浙江杭州市江干区九堡镇三村村一区')
        enable_list = gr.CheckboxGroup(agent.value.available_tool_list, label="启用的Tools",
                                       value=['modelscope_text-address'])

    with gr.Row():
        agent_start = gr.Button("Agent, 启动!")
        agent_reset = gr.Button("Agent, 重置!")

    with gr.Row():
        with gr.Column():
            # Input widgets
            prompt_box = gr.Text(label="Prompt Box")

            input_box = gr.TextArea(label="Input Box", max_lines=100, value=default_text)
            # Action button
            chatbot_btn = gr.Button("Chat")
            # Output widget
            output = gr.Chatbot(elem_id="chatbot", height=900)

    # Per-session state: chat history and the LLM round-bookkeeping dict.
    history = gr.State([])
    state_llm = gr.State({})

    # Wire button click events to the handlers defined above.
    agent_start.click(agent_init, [query_box, state_llm, history, agent, enable_list],
                      [prompt_box, history, output, state_llm])
    chatbot_btn.click(deal_LLM, [input_box, history, state_llm, agent, enable_list],
                      [prompt_box, history, output, state_llm])
    agent_reset.click(agent_remake, [state_llm, history, agent], [prompt_box, history, output, state_llm])

demo.launch()
cfg_tool_template.json
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"modelscope_text-address": {
|
3 |
+
"url": "https://api-inference.modelscope.cn/api-inference/v1/models/damo/mgeo_geographic_elements_tagging_chinese_base",
|
4 |
+
"use": true
|
5 |
+
},
|
6 |
+
"modelscope_text-ner": {
|
7 |
+
"url": "https://api-inference.modelscope.cn/api-inference/v1/models/damo/nlp_raner_named-entity-recognition_chinese-base-cmeee",
|
8 |
+
"use": true
|
9 |
+
},
|
10 |
+
"modelscope_text-ie": {
|
11 |
+
"url": "https://api-inference.modelscope.cn/api-inference/v1/models/damo/nlp_structbert_siamese-uie_chinese-base",
|
12 |
+
"use": true
|
13 |
+
},
|
14 |
+
"modelscope_speech-generation": {
|
15 |
+
"url": "https://api-inference.modelscope.cn/api-inference/v1/models/damo/speech_sambert-hifigan_tts_zh-cn_16k",
|
16 |
+
"use": true
|
17 |
+
},
|
18 |
+
"modelscope_video-generation": {
|
19 |
+
"url": "https://api-inference.modelscope.cn/api-inference/v1/models/damo/text-to-video-synthesis",
|
20 |
+
"use": true
|
21 |
+
},
|
22 |
+
"modelscope_image-chat": {
|
23 |
+
"url": "https://api-inference.modelscope.cn/api-inference/v1/models/damo/multi-modal_mplug_owl_multimodal-dialogue_7b",
|
24 |
+
"use": true
|
25 |
+
},
|
26 |
+
"modelscope_text-translation-en2zh": {
|
27 |
+
"url": "https://api-inference.modelscope.cn/api-inference/v1/models/damo/nlp_csanmt_translation_en2zh",
|
28 |
+
"use": true
|
29 |
+
},
|
30 |
+
"modelscope_text-translation-zh2en": {
|
31 |
+
"url": "https://api-inference.modelscope.cn/api-inference/v1/models/damo/nlp_csanmt_translation_zh2en",
|
32 |
+
"use": true
|
33 |
+
},
|
34 |
+
"image_gen": {
|
35 |
+
"url": "https://api-inference.modelscope.cn/api-inference/v1/models/AI-ModelScope/stable-diffusion-xl-base-1.0",
|
36 |
+
"use": true,
|
37 |
+
"pipeline_params": {
|
38 |
+
"use_safetensors": true
|
39 |
+
}
|
40 |
+
},
|
41 |
+
"amap_weather": {
|
42 |
+
"use": false,
|
43 |
+
"token": "need to be filled when you use weather"
|
44 |
+
}
|
45 |
+
}
|
my_modelscope_agent/__init__.py
ADDED
File without changes
|
my_modelscope_agent/agent.py
ADDED
@@ -0,0 +1,408 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
from typing import Dict, List, Optional, Union
|
3 |
+
|
4 |
+
from .agent_types import AgentType
|
5 |
+
from .llm import LLM
|
6 |
+
from .output_parser import OutputParser, get_output_parser
|
7 |
+
from .output_wrapper import display
|
8 |
+
from .prompt import PromptGenerator, get_prompt_generator
|
9 |
+
from .retrieve import KnowledgeRetrieval, ToolRetrieval
|
10 |
+
from .tools import TOOL_INFO_LIST
|
11 |
+
|
12 |
+
|
13 |
+
class AgentExecutor:
|
def custom_run_init(self,
                    task: str,
                    remote: bool = False,
                    print_info: bool = False,
                    append_files: Optional[list] = None):
    """Prepare a task run without entering the LLM loop.

    Performs the same setup as ``run`` (tool retrieval, knowledge retrieval,
    prompt initialization) and returns the loop's starting state so an
    external driver (e.g. the Gradio demo) can step through rounds manually.

    Args:
        task (str): concrete task.
        remote (bool, optional): whether tools should execute remotely.
        print_info (bool, optional): whether to print prompt info.
        append_files (list, optional): files attached to this run. Defaults
            to None; the original mutable default ``[]`` was shared across
            calls.

    Returns:
        tuple: (tool_list, knowledge_list, function_list, llm_result,
        exec_result, idx, final_res, remote, print_info).
    """
    if append_files is None:
        append_files = []

    tool_list = self.retrieve_tools(task)
    knowledge_list = self.get_knowledge(task)

    self.prompt_generator.init_prompt(
        task, tool_list, knowledge_list, append_files=append_files)
    function_list = self.prompt_generator.get_function_list(tool_list)

    llm_result, exec_result = '', ''

    idx = 0
    final_res = []

    return tool_list, knowledge_list, function_list, llm_result, exec_result, idx, final_res, remote, print_info
def custom_gene_prompt(self, llm_result, exec_result, idx):
    """Build the next-round prompt from the previous LLM/tool outputs.

    Returns the generated prompt text together with the incremented round
    index.
    """
    next_idx = idx + 1
    prompt = self.prompt_generator.generate(llm_result, exec_result)
    return prompt, next_idx
def custom_parse_llm(self, llm_artifacts, llm_result, idx, final_res, remote, print_info):
    """Parse one externally-supplied LLM completion and act on it.

    This is the body of one ``run`` loop iteration, refactored so that the
    LLM output (``llm_result``) is provided by the caller instead of being
    generated here. Always returns a single-element list whose dict is one
    of: ``{'end_res': ...}`` (conversation finished), ``{'exec_result': ...}``
    (parse/tool error text to feed back), or ``{'no_stop': {...}}`` (a tool
    ran successfully; loop state to carry into the next round).
    """
    if print_info:
        print(f'|LLM inputs in round {idx}: {llm_artifacts}')

    # parse and get tool name and arguments
    try:
        action, action_args = self.output_parser.parse_response(
            llm_result)
    except ValueError as e:
        return [{'exec_result': f'{e}'}]

    if action is None:
        # in chat mode, the final result of last instructions should be updated to prompt history
        _ = self.prompt_generator.generate(llm_result, '')

        # for summarize
        # display(llm_result, {}, idx, self.agent_type)
        return [{'end_res': final_res}]

    if action in self.available_tool_list:
        action_args = self.parse_action_args(action_args)
        tool = self.tool_list[action]

        # TODO @wenmeng.zwm remove this hack logic for image generation
        if action == 'image_gen' and self.seed:
            action_args['seed'] = self.seed
        try:
            exec_result = tool(**action_args, remote=remote)
            if print_info:
                print(f'|exec_result: {exec_result}')

            # parse exec result and store result to agent state
            final_res.append(exec_result)
            self.parse_exec_result(exec_result)
        except Exception as e:
            exec_result = f'Action call error: {action}: {action_args}. \n Error message: {e}'
            return [{'exec_result': exec_result}]
    else:
        exec_result = f"Unknown action: '{action}'. "
        return [{'exec_result': exec_result}]

    # display result
    # display(llm_result, exec_result, idx, self.agent_type)

    return [{'no_stop': {'llm_result': llm_result, 'exec_result': exec_result, 'idx': idx, 'final_res': final_res}}]
def __init__(self,
             llm: LLM,
             tool_cfg: Optional[Dict] = None,
             agent_type: AgentType = AgentType.DEFAULT,
             additional_tool_list: Optional[Dict] = None,
             prompt_generator: Optional[PromptGenerator] = None,
             output_parser: Optional[OutputParser] = None,
             tool_retrieval: Optional[Union[bool, ToolRetrieval]] = True,
             knowledge_retrieval: Optional[KnowledgeRetrieval] = None):
    """
    the core class of ms agent. It is responsible for the interaction between user, llm and tools,
    and return the execution result to user.

    Args:
        llm (LLM): llm model, can be load from local or a remote server.
        tool_cfg (Optional[Dict]): cfg of default tools. Defaults to None
            (treated as {}); the original mutable default ``{}`` was shared
            across instances.
        agent_type (AgentType, optional): agent type. Defaults to AgentType.DEFAULT, decide which type of agent
            reasoning type to use
        additional_tool_list (Optional[Dict], optional): user-defined additional tool list. Defaults to None
            (treated as {}).
        prompt_generator (Optional[PromptGenerator], optional): this module is responsible for generating prompt
            according to interaction result. Defaults to use MSPromptGenerator.
        output_parser (Optional[OutputParser], optional): this module is responsible for parsing output of llm
            to executable actions. Defaults to use MsOutputParser.
        tool_retrieval (Optional[Union[bool, ToolRetrieval]], optional): Retrieve related tools by input task,
            since most of the tools may be useless for LLM in specific task.
            If it is bool type and is True, will use default tool_retrieval. Defaults to True.
        knowledge_retrieval (Optional[KnowledgeRetrieval], optional): If user want to use extra knowledge,
            this component can be used to retrieve related knowledge. Defaults to None.
    """
    # None sentinels replace the shared mutable-dict defaults.
    tool_cfg = {} if tool_cfg is None else tool_cfg
    additional_tool_list = {} if additional_tool_list is None else additional_tool_list

    self.llm = llm

    self.agent_type = agent_type
    self.llm.set_agent_type(agent_type)
    self.prompt_generator = prompt_generator or get_prompt_generator(
        agent_type)
    self.output_parser = output_parser or get_output_parser(agent_type)

    self._init_tools(tool_cfg, additional_tool_list)

    # tool_retrieval=True means "build the default retriever".
    if isinstance(tool_retrieval, bool) and tool_retrieval:
        tool_retrieval = ToolRetrieval()
    self.tool_retrieval = tool_retrieval
    if self.tool_retrieval:
        self.tool_retrieval.construct(
            [str(t) for t in self.tool_list.values()])
    self.knowledge_retrieval = knowledge_retrieval
    self.reset()
    # Optional fixed seed for the image_gen tool (see run/stream_run hack).
    self.seed = None
def _init_tools(self,
                tool_cfg: Optional[Dict] = None,
                additional_tool_list: Optional[Dict] = None):
    """init tool list of agent. We provide a default tool list, which is initialized by a cfg file.
    user can also provide user-defined tools by additional_tool_list.
    The key of additional_tool_list is tool name, and the value is corresponding object.

    Args:
        tool_cfg (Dict): default tool cfg. Defaults to None (treated as {});
            the original mutable default ``{}`` was shared across calls.
        additional_tool_list (Dict, optional): user-defined tools. Defaults to None (treated as {}).
    """
    tool_cfg = {} if tool_cfg is None else tool_cfg
    additional_tool_list = {} if additional_tool_list is None else additional_tool_list

    self.tool_list = {}
    tool_info_list = {**TOOL_INFO_LIST, **additional_tool_list}
    # tools_module = importlib.import_module('modelscope_agent.tools')
    from . import tools as tools_module

    for cfg_name in tool_cfg:
        if not tool_cfg[cfg_name].get('use', False):
            continue
        assert cfg_name in tool_info_list, f'Invalid tool name: {cfg_name}, ' \
            f'available ones are: {tool_info_list.keys()}'
        tool_class = getattr(tools_module, tool_info_list[cfg_name])
        # Register under the class's canonical name, which may differ from
        # the cfg key (the original rebound the loop variable here).
        self.tool_list[tool_class.name] = tool_class(tool_cfg)

    # User-supplied tool objects override the cfg-built ones.
    self.tool_list = {**self.tool_list, **additional_tool_list}
    # self.available_tool_list = deepcopy(self.tool_list)
    self.set_available_tools(self.tool_list.keys())
def set_available_tools(self, available_tool_list):
    """Restrict the agent to a subset of its registered tools.

    Raises ValueError if any requested name is not a registered tool.
    """
    # TODO @wenmeng.zwm refine tool init
    for name in available_tool_list:
        if name not in self.tool_list:
            raise ValueError(
                f'Unsupported tools found:{name}, please check, valid ones: {self.tool_list.keys()}'
            )

    self.available_tool_list = {name: self.tool_list[name] for name in available_tool_list}
def retrieve_tools(self, query: str) -> List[str]:
    """Narrow the available tools to those relevant to *query*.

    When no tool retriever is configured, the current availability is left
    untouched. Returns the (possibly updated) available tool objects.

    Args:
        query (str): query
    """
    if self.tool_retrieval:
        matched = self.tool_retrieval.retrieve(query)
        self.set_available_tools(available_tool_list=matched.keys())
    return self.available_tool_list.values()
def get_knowledge(self, query: str) -> List[str]:
    """Fetch knowledge snippets related to *query*.

    Returns an empty list when no knowledge retrieval component is
    configured.

    Args:
        query (str): query
    """
    if not self.knowledge_retrieval:
        return []
    return self.knowledge_retrieval.retrieve(query)
def run(self,
        task: str,
        remote: bool = False,
        print_info: bool = False,
        append_files: Optional[list] = None) -> List[Dict]:
    """ use llm and tools to execute task given by user

    Args:
        task (str): concrete task
        remote (bool, optional): whether to execute tool in remote mode. Defaults to False.
        print_info (bool, optional): whether to print prompt info. Defaults to False.
        append_files (list, optional): files attached to this run. Defaults
            to None; the original mutable default ``[]`` was shared across
            calls.

    Returns:
        List[Dict]: execute result. One task may need to interact with llm multiple times,
            so a list of dict is returned. Each dict contains the result of one interaction.
    """
    if append_files is None:
        append_files = []

    # retrieve tools
    tool_list = self.retrieve_tools(task)
    knowledge_list = self.get_knowledge(task)

    self.prompt_generator.init_prompt(
        task, tool_list, knowledge_list, append_files=append_files)
    function_list = self.prompt_generator.get_function_list(tool_list)

    llm_result, exec_result = '', ''

    idx = 0
    final_res = []

    # Round loop: generate -> parse -> execute tool -> feed result back,
    # until the parser finds no further action (or an error aborts the run).
    while True:
        idx += 1

        # generate prompt and call llm
        llm_artifacts = self.prompt_generator.generate(
            llm_result, exec_result)
        try:
            llm_result = self.llm.generate(llm_artifacts, function_list)
        except RuntimeError as e:
            return [{'exec_result': str(e)}]

        if print_info:
            print(f'|LLM inputs in round {idx}: {llm_artifacts}')

        # parse and get tool name and arguments
        try:
            action, action_args = self.output_parser.parse_response(
                llm_result)
        except ValueError as e:
            return [{'exec_result': f'{e}'}]

        if action is None:
            # in chat mode, the final result of last instructions should be updated to prompt history
            _ = self.prompt_generator.generate(llm_result, '')

            # for summarize
            display(llm_result, {}, idx, self.agent_type)
            return final_res

        if action in self.available_tool_list:
            action_args = self.parse_action_args(action_args)
            tool = self.tool_list[action]

            # TODO @wenmeng.zwm remove this hack logic for image generation
            if action == 'image_gen' and self.seed:
                action_args['seed'] = self.seed
            try:
                exec_result = tool(**action_args, remote=remote)
                if print_info:
                    print(f'|exec_result: {exec_result}')

                # parse exec result and store result to agent state
                final_res.append(exec_result)
                self.parse_exec_result(exec_result)
            except Exception as e:
                exec_result = f'Action call error: {action}: {action_args}. \n Error message: {e}'
                return [{'exec_result': exec_result}]
        else:
            exec_result = f"Unknown action: '{action}'. "
            return [{'exec_result': exec_result}]

        # display result
        display(llm_result, exec_result, idx, self.agent_type)
def stream_run(self,
               task: str,
               remote: bool = True,
               print_info: bool = False,
               append_files: Optional[list] = None) -> Dict:
    """this is a stream version of run, which can be used in scenario like gradio.
    It will yield the result of each interaction, so that the caller can display the result

    Args:
        task (str): concrete task
        remote (bool, optional): whether to execute tool in remote mode. Defaults to True.
        print_info (bool, optional): whether to print prompt info. Defaults to False.
        append_files (list, optional): files that individually used in each run, no need
            to record to global state. Defaults to None; the original mutable
            default ``[]`` was shared across calls.

    Yields:
        Iterator[Dict]: iterator of llm response and tool execution result
    """
    if append_files is None:
        append_files = []

    # retrieve tools
    tool_list = self.retrieve_tools(task)
    knowledge_list = self.get_knowledge(task)

    self.prompt_generator.init_prompt(
        task,
        tool_list,
        knowledge_list,
        append_files=append_files,
    )
    function_list = self.prompt_generator.get_function_list(tool_list)

    llm_result, exec_result = '', ''

    idx = 0

    while True:
        idx += 1
        llm_artifacts = self.prompt_generator.generate(
            llm_result, exec_result)
        if print_info:
            print(f'|LLM inputs in round {idx}:\n{llm_artifacts}')

        llm_result = ''
        try:
            for s in self.llm.stream_generate(llm_artifacts,
                                              function_list):
                llm_result += s
                yield {'llm_text': s}
        except RuntimeError:
            # Fall back to non-streaming generation when streaming fails.
            s = self.llm.generate(llm_artifacts)
            llm_result += s
            yield {'llm_text': s}
        except Exception as e:
            # NOTE(review): the error is surfaced but the loop still parses
            # the partial llm_result below — confirm this is intended.
            yield {'llm_text': str(e)}

        # parse and get tool name and arguments
        try:
            action, action_args = self.output_parser.parse_response(
                llm_result)
        except ValueError as e:
            yield {'exec_result': f'{e}'}
            return

        if action is None:
            # in chat mode, the final result of last instructions should be updated to prompt history
            _ = self.prompt_generator.generate(llm_result, '')
            yield {'is_final': True}
            return

        if action in self.available_tool_list:
            # yield observation to as end of action input symbol asap
            yield {'llm_text': 'Observation: '}
            action_args = self.parse_action_args(action_args)
            tool = self.tool_list[action]

            # TODO @wenmeng.zwm remove this hack logic for image generation
            if action == 'image_gen' and self.seed:
                action_args['seed'] = self.seed
            try:
                exec_result = tool(**action_args, remote=remote)
                yield {'exec_result': exec_result}

                # parse exec result and update state
                self.parse_exec_result(exec_result)
            except Exception as e:
                exec_result = f'Action call error: {action}: {action_args}. \n Error message: {e}'
                yield {'exec_result': exec_result}
                self.prompt_generator.reset()
                return
        else:
            exec_result = f"Unknown action: '{action}'. "
            yield {'exec_result': exec_result}
            self.prompt_generator.reset()
            return
def reset(self):
    """
    clear history and agent state

    Resets the prompt generator's conversation history and empties the
    value store built up by parse_exec_result.
    """
    self.prompt_generator.reset()
    self.agent_state = {}
def parse_action_args(self, action_args):
    """Resolve tool arguments against the agent state.

    String placeholders that were stored by parse_exec_result are swapped
    for their wrapped Image/Video/Audio objects so tools receive real data.
    """
    resolved = {}
    for key, raw in action_args.items():
        try:
            value = self.agent_state.get(raw, raw)
        except Exception as e:
            # Unhashable argument values cannot be looked up; keep them as-is.
            print(f'Error when parsing action args: {e}, using fall back')
            value = raw
        resolved[key] = value
    return resolved
def parse_exec_result(self, exec_result, *args, **kwargs):
    """
    update exec result to agent state.
    key is the str representation of the result.
    """
    # Only the values matter here; the original iterated .items() and
    # discarded the keys.
    for v in exec_result.values():
        self.agent_state[str(v)] = v
my_modelscope_agent/agent_types.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from enum import Enum
|
2 |
+
|
3 |
+
|
class AgentType(str, Enum):
    """Reasoning/prompting styles an agent can be configured with.

    Inherits from ``str`` so members compare equal to their string values
    and serialize naturally.
    """

    DEFAULT = 'default'
    """Default agent type used when no explicit type is requested."""

    MS_AGENT = 'ms-agent'
    """An agent that uses the ModelScope-agent specific format and does a reasoning step before acting.
    """

    MRKL = 'mrkl'
    """An agent that does a reasoning step before acting with mrkl"""

    REACT = 'react'
    """An agent that does a reasoning step before acting with react"""

    Messages = 'messages'
    """An agent optimized for using OpenAI functions."""
my_modelscope_agent/llm/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .base import LLM
|
2 |
+
from .llm_factory import LLMFactory
|
my_modelscope_agent/llm/base.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import abstractmethod
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
import json
|
5 |
+
|
6 |
+
|
class LLM:
    """Common interface for language-model backends.

    Concrete subclasses (DashScope, OpenAI, ModelScope, custom HTTP, ...)
    implement :meth:`generate` and optionally :meth:`stream_generate`,
    :meth:`tokenize` and :meth:`detokenize`.
    """

    # Registry-style identifier; overridden by subclasses.
    name = ''

    def __init__(self, cfg):
        # Backend-specific configuration mapping.
        self.cfg = cfg
        # Prompting strategy; filled in later via set_agent_type().
        self.agent_type = None
        # Model handle/name; subclasses assign a real value.
        self.model = None
        self.model_id = self.model

    def set_agent_type(self, agent_type):
        """Record which agent type drives prompt formatting for this LLM."""
        self.agent_type = agent_type

    @abstractmethod
    def generate(self, prompt: str, functions: list = [], **kwargs) -> str:
        """Produce a full (non-streaming) completion for ``prompt``.

        Args:
            prompt (str): prompt text.
            functions (list): function schemas (name, description,
                parameters) offered to the model.

        Returns:
            str: the model response.
        """
        raise NotImplementedError

    @abstractmethod
    def stream_generate(self,
                        prompt: str,
                        functions: list = [],
                        **kwargs) -> str:
        """Yield the completion incrementally, one chunk per step.

        Args:
            prompt (str): prompt text.
            functions (list): function schemas offered to the model.

        Yields:
            Iterator[str]: successive response fragments.
        """
        raise NotImplementedError

    def tokenize(self, input_text: str) -> List[int]:
        """Convert text to token ids (used to enforce input-length limits).

        Args:
            input_text (str): text to tokenize.

        Returns:
            list[int]: token ids.
        """
        raise NotImplementedError

    def detokenize(self, input_ids: List[int]) -> str:
        """Convert token ids back to text.

        Args:
            input_ids (list[int]): token ids.

        Returns:
            str: decoded text.
        """
        raise NotImplementedError
my_modelscope_agent/llm/custom_llm.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import json
|
4 |
+
import requests
|
5 |
+
from ..agent_types import AgentType
|
6 |
+
|
7 |
+
from .base import LLM
|
8 |
+
from .utils import DEFAULT_MESSAGE
|
9 |
+
|
10 |
+
|
class CustomLLM(LLM):
    """LLM served behind a plain HTTP endpoint.

    Connection details come from environment variables:
    ``HTTP_LLM_TOKEN`` (required bearer token), ``HTTP_LLM_MODEL`` and
    ``HTTP_LLM_URL``. Subclass and override :meth:`http_request` or the
    result parsing if the service speaks a different envelope.
    """
    name = 'custom_llm'

    def __init__(self, cfg):
        super().__init__(cfg)
        self.token = os.getenv('HTTP_LLM_TOKEN', None)
        self.model = os.getenv('HTTP_LLM_MODEL', None)
        self.model_id = self.model
        self.url = os.getenv('HTTP_LLM_URL', None)

        if self.token is None:
            raise ValueError('HTTP_LLM_TOKEN is not set')
        self.agent_type = self.cfg.get('agent_type', AgentType.DEFAULT)

    def http_request(self, data):
        """POST ``data`` as JSON to the configured endpoint and decode the reply."""
        headers = {
            'Content-Type': 'application/json',
            'Authorization': f'Bearer {self.token}'
        }
        response = requests.post(self.url, json=data, headers=headers)
        return json.loads(response.content)

    def generate(self,
                 llm_artifacts,
                 functions=[],
                 function_call='none',
                 **kwargs):
        """Call the HTTP service (with retries) and post-process the reply.

        Args:
            llm_artifacts: prompt string, or a message list when the agent
                type is ``AgentType.Messages``.
            functions (list): optional function-calling schemas; when
                non-empty, function_call is forced to 'auto'.
            function_call (str): function-call policy forwarded to the
                service.

        Returns:
            str for text-style agents (truncated at ``<|endofthink|>`` for
            MS_AGENT), or an OpenAI-style message dict for
            ``AgentType.Messages``.
        """
        if self.agent_type != AgentType.Messages:
            messages = [{'role': 'user', 'content': llm_artifacts}]
        else:
            messages = llm_artifacts if len(
                llm_artifacts) > 0 else DEFAULT_MESSAGE

        data = {'model': self.model, 'messages': messages, 'n': 1}

        assert isinstance(functions, list)
        if len(functions) > 0:
            function_call = 'auto'
            data['functions'] = functions
            data['function_call'] = function_call

        retry_count = 0
        max_retries = 3
        message = {'content': ''}
        while retry_count <= max_retries:
            try:
                response = self.http_request(data)
            except Exception as e:
                retry_count += 1
                if retry_count > max_retries:
                    import traceback
                    traceback.print_exc()
                    print(f'input: {messages}, original error: {str(e)}')
                    raise e
                # BUGFIX: retry the request instead of falling through to
                # inspect `response`, which is unbound on a first-attempt
                # failure (and stale on later ones).
                continue

            if response['code'] == 200:
                message = response['data']['response'][0]
                break
            else:
                retry_count += 1
                if retry_count > max_retries:
                    print('maximum retry reached, return default message')

        # truncate content
        content = message['content']

        if self.agent_type == AgentType.MS_AGENT:
            # Keep only the reasoning block, up to and including the
            # <|endofthink|> marker.
            idx = content.find('<|endofthink|>')
            if idx != -1:
                content = content[:idx + len('<|endofthink|>')]
            return content
        elif self.agent_type == AgentType.Messages:
            new_message = {
                'content': content,
                'role': message.get('response_role', 'assistant')
            }
            if 'function_call' in message and message['function_call'] != {}:
                new_message['function_call'] = message.get('function_call')
            return new_message
        else:
            return content
my_modelscope_agent/llm/dashscope_llm.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
import traceback
|
4 |
+
from http import HTTPStatus
|
5 |
+
from typing import Union
|
6 |
+
|
7 |
+
import dashscope
|
8 |
+
import json
|
9 |
+
from dashscope import Generation
|
10 |
+
from ..agent_types import AgentType
|
11 |
+
|
12 |
+
from .base import LLM
|
13 |
+
from .utils import DEFAULT_MESSAGE, CustomOutputWrapper
|
14 |
+
|
15 |
+
dashscope.api_key = os.getenv('DASHSCOPE_API_KEY')
|
16 |
+
|
17 |
+
|
class DashScopeLLM(LLM):
    """LLM backend on top of the DashScope Generation API."""

    name = 'dashscope_llm'

    def __init__(self, cfg):
        super().__init__(cfg)
        self.model = self.cfg.get('model', 'modelscope-agent-llm-v1')
        self.model_id = self.model
        self.generate_cfg = self.cfg.get('generate_cfg', {})
        self.agent_type = self.cfg.get('agent_type', AgentType.DEFAULT)

    def generate(self,
                 llm_artifacts: Union[str, dict],
                 functions=[],
                 **kwargs):
        """Non-streaming completion.

        Args:
            llm_artifacts: prompt string, or a message list when the agent
                type is ``AgentType.Messages``.
            functions (list): unused here; accepted for interface parity.

        Returns:
            str for text-style agents (truncated at ``<|endofthink|>`` for
            MS_AGENT), or a message dict for ``AgentType.Messages``.

        Raises:
            RuntimeError: when the DashScope call fails.
        """
        # TODO retry and handle message
        try:
            if self.agent_type == AgentType.Messages:
                messages = llm_artifacts if len(
                    llm_artifacts) > 0 else DEFAULT_MESSAGE
                self.generate_cfg['use_raw_prompt'] = False
                response = dashscope.Generation.call(
                    model=self.model,
                    messages=messages,
                    # set the random seed, optional, default to 1234 if not set
                    seed=random.randint(1, 10000),
                    result_format=
                    'message',  # set the result to be "message" format.
                    stream=False,
                    **self.generate_cfg)
                llm_result = CustomOutputWrapper.handle_message_chat_completion(
                    response)
            else:
                response = Generation.call(
                    model=self.model,
                    prompt=llm_artifacts,
                    stream=False,
                    **self.generate_cfg)
                llm_result = CustomOutputWrapper.handle_message_text_completion(
                    response)
            # BUGFIX: no early `return llm_result` here -- it made the
            # agent-type post-processing below unreachable dead code.
        except Exception as e:
            error = traceback.format_exc()
            error_msg = f'LLM error with input {llm_artifacts} \n dashscope error: {str(e)} with traceback {error}'
            print(error_msg)
            raise RuntimeError(error)

        if self.agent_type == AgentType.MS_AGENT:
            # in the form of text
            idx = llm_result.find('<|endofthink|>')
            if idx != -1:
                llm_result = llm_result[:idx + len('<|endofthink|>')]
            return llm_result
        elif self.agent_type == AgentType.Messages:
            # in the form of message
            return llm_result
        else:
            # in the form of text
            return llm_result

    def stream_generate(self,
                        llm_artifacts: Union[str, dict],
                        functions=[],
                        **kwargs):
        """Streaming completion: yields only the newly generated suffix of
        each partial response (DashScope streams cumulative text).

        Raises:
            RuntimeError: when the call or any streamed chunk fails.
        """
        total_response = ''
        try:
            if self.agent_type == AgentType.Messages:
                self.generate_cfg['use_raw_prompt'] = False
                responses = Generation.call(
                    model=self.model,
                    messages=llm_artifacts,
                    stream=True,
                    result_format='message',
                    **self.generate_cfg)
            else:
                responses = Generation.call(
                    model=self.model,
                    prompt=llm_artifacts,
                    stream=True,
                    **self.generate_cfg)
        except Exception as e:
            error = traceback.format_exc()
            error_msg = f'LLM error with input {llm_artifacts} \n dashscope error: {str(e)} with traceback {error}'
            print(error_msg)
            raise RuntimeError(error)

        for response in responses:
            if response.status_code == HTTPStatus.OK:
                if self.agent_type == AgentType.Messages:
                    llm_result = CustomOutputWrapper.handle_message_chat_completion(
                        response)
                    frame_text = llm_result['content'][len(total_response):]
                else:
                    llm_result = CustomOutputWrapper.handle_message_text_completion(
                        response)
                    frame_text = llm_result[len(total_response):]
                yield frame_text

                if self.agent_type == AgentType.Messages:
                    total_response = llm_result['content']
                else:
                    total_response = llm_result
            else:
                err_msg = 'Error Request id: %s, Code: %d, status: %s, message: %s' % (
                    response.request_id, response.status_code, response.code,
                    response.message)
                print(err_msg)
                raise RuntimeError(err_msg)
my_modelscope_agent/llm/llm_factory.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_llm_cls(llm_type, model_name):
    """Map a backend identifier to its LLM implementation class.

    Imports are deferred so unused backends (and their heavyweight
    dependencies) are never loaded.

    Args:
        llm_type (str): one of 'dashscope', 'custom_llm', 'openai',
            'modelscope'.
        model_name (str): selects a specialised class for some backends
            (e.g. 'chatglm3-6b' on the modelscope backend).

    Returns:
        type: the LLM subclass to instantiate.

    Raises:
        ValueError: for an unknown ``llm_type``.
    """
    if llm_type == 'modelscope':
        if model_name == 'chatglm3-6b':
            from .modelscope_llm import ModelScopeChatGLM
            return ModelScopeChatGLM
        from .modelscope_llm import ModelScopeLLM
        return ModelScopeLLM
    if llm_type == 'dashscope':
        from .dashscope_llm import DashScopeLLM
        return DashScopeLLM
    if llm_type == 'custom_llm':
        from .custom_llm import CustomLLM
        return CustomLLM
    if llm_type == 'openai':
        from .openai import OpenAi
        return OpenAi
    raise ValueError(f'Invalid llm_type {llm_type}')
class LLMFactory:
    """Builds a configured LLM instance from a config mapping."""

    @staticmethod
    def build_llm(model_name, cfg):
        """Instantiate the LLM configured under ``cfg[model_name]``.

        Note: pops the 'type' key from the model's config entry (mutating
        ``cfg`` in place) and hands the remaining entry to the backend
        class constructor.
        """
        model_cfg = cfg[model_name]
        llm_type = model_cfg.pop('type')
        llm_cls = get_llm_cls(llm_type, model_name)
        return llm_cls(cfg=model_cfg)
my_modelscope_agent/llm/modelscope_llm.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from ..agent_types import AgentType
|
6 |
+
from swift import Swift
|
7 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
|
8 |
+
|
9 |
+
from modelscope import GenerationConfig, snapshot_download
|
10 |
+
from .base import LLM
|
11 |
+
|
12 |
+
|
class ModelScopeLLM(LLM):
    """Locally hosted LLM loaded from a ModelScope model id or path.

    Downloads the checkpoint if needed, loads model + tokenizer via
    transformers, optionally applies a LoRA adapter, and runs generation
    on-device.
    """

    def __init__(self, cfg):
        super().__init__(cfg)

        model_id = self.cfg.get('model_id', '')
        self.model_id = model_id
        model_revision = self.cfg.get('model_revision', None)
        cache_dir = self.cfg.get('cache_dir', None)

        # Treat `model_id` as a local path when it exists on disk,
        # otherwise fetch the snapshot from the ModelScope hub.
        if not os.path.exists(model_id):
            model_dir = snapshot_download(
                model_id, model_revision, cache_dir=cache_dir)
        else:
            model_dir = model_id
        self.model_dir = model_dir
        # Some checkpoints ship custom python modules next to the weights.
        sys.path.append(self.model_dir)

        self.model_cls = self.cfg.get('model_cls', AutoModelForCausalLM)
        self.tokenizer_cls = self.cfg.get('tokenizer_cls', AutoTokenizer)

        self.device_map = self.cfg.get('device_map', 'auto')
        self.generation_cfg = GenerationConfig(
            **self.cfg.get('generate_cfg', {}))

        self.use_lora = self.cfg.get('use_lora', False)
        self.lora_ckpt_dir = self.cfg.get('lora_ckpt_dir',
                                          None) if self.use_lora else None

        # When True, generate() delegates to the model's own chat() method.
        self.custom_chat = self.cfg.get('custom_chat', False)

        # Output is truncated at this marker (see generate()).
        self.end_token = self.cfg.get('end_token', '<|endofthink|>')
        self.include_end = self.cfg.get('include_end', True)

        self.setup()
        self.agent_type = self.cfg.get('agent_type', AgentType.DEFAULT)

    def setup(self):
        """Load model and tokenizer (fp16), then optional LoRA weights."""
        model_cls = self.model_cls
        tokenizer_cls = self.tokenizer_cls

        self.model = model_cls.from_pretrained(
            self.model_dir,
            device_map=self.device_map,
            # device='cuda:0',
            torch_dtype=torch.float16,
            trust_remote_code=True)
        self.tokenizer = tokenizer_cls.from_pretrained(
            self.model_dir, trust_remote_code=True)
        self.model = self.model.eval()

        if self.use_lora:
            self.load_from_lora()

        if self.cfg.get('use_raw_generation_config', False):
            self.model.generation_config = GenerationConfig.from_pretrained(
                self.model_dir, trust_remote_code=True)

    def generate(self, prompt, functions=[], **kwargs):
        """Generate a completion for `prompt`, truncated at `self.end_token`.

        NOTE(review): `self.model.chat` is a plain attribute access used as
        a truthiness test -- it raises AttributeError when the model has no
        `chat` member; presumably custom_chat is only enabled for models
        that provide one. Confirm before relying on it.
        """
        if self.custom_chat and self.model.chat:
            response = self.model.chat(
                self.tokenizer, prompt, history=[], system='')[0]
        else:
            response = self.chat(prompt)

        # Truncate at the end marker, optionally keeping the marker itself.
        end_idx = response.find(self.end_token)
        if end_idx != -1:
            end_idx += len(self.end_token) if self.include_end else 0
            response = response[:end_idx]

        return response

    def load_from_lora(self):
        """Wrap the base model with LoRA weights from `self.lora_ckpt_dir`."""
        model = self.model.bfloat16()
        # transform to lora
        model = Swift.from_pretrained(model, self.lora_ckpt_dir)

        self.model = model

    def chat(self, prompt):
        """Plain generation: tokenize, run the model, decode only new tokens."""
        device = self.model.device
        input_ids = self.tokenizer(
            prompt, return_tensors='pt').input_ids.to(device)
        input_len = input_ids.shape[1]

        result = self.model.generate(
            input_ids=input_ids, generation_config=self.generation_cfg)

        # Strip the prompt tokens; keep only the generated continuation.
        result = result[0].tolist()[input_len:]
        response = self.tokenizer.decode(result)

        return response
108 |
+
|
class ModelScopeChatGLM(ModelScopeLLM):
    """ModelScopeLLM variant with ChatGLM3-specific stop tokens and cleanup."""

    def chat(self, prompt):
        """Generate with ChatGLM3 role tokens as additional stop criteria."""
        device = self.model.device
        input_ids = self.tokenizer(
            prompt, return_tensors='pt').input_ids.to(device)
        input_len = input_ids.shape[1]

        # Stop on EOS or when the model opens a new conversational turn.
        eos_token_id = [
            self.tokenizer.eos_token_id,
            self.tokenizer.get_command('<|user|>'),
            self.tokenizer.get_command('<|observation|>')
        ]
        result = self.model.generate(
            input_ids=input_ids,
            generation_config=self.generation_cfg,
            eos_token_id=eos_token_id)

        result = result[0].tolist()[input_len:]
        response = self.tokenizer.decode(result)
        # Handle the case where '<', '|', 'user', '|', '>' is generated
        # literally: drop anything after a role marker appears in the text.
        response = response.split('<|user|>')[0].split('<|observation|>')[0]

        return response
my_modelscope_agent/llm/openai.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import openai
|
4 |
+
from ..agent_types import AgentType
|
5 |
+
|
6 |
+
from .base import LLM
|
7 |
+
from .utils import CustomOutputWrapper
|
8 |
+
|
9 |
+
openai.api_key = os.getenv('OPENAI_API_KEY')
|
10 |
+
|
11 |
+
|
class OpenAi(LLM):
    """LLM backend over the OpenAI ChatCompletion API."""

    name = 'openai'

    def __init__(self, cfg):
        super().__init__(cfg)

        self.model = self.cfg.get('model', 'gpt-3.5-turbo')
        self.model_id = self.model
        self.api_base = self.cfg.get('api_base', 'https://api.openai.com/v1')
        self.agent_type = self.cfg.get('agent_type', AgentType.DEFAULT)

    def generate(self,
                 llm_artifacts,
                 functions=[],
                 function_call='none',
                 **kwargs):
        """Non-streaming chat completion.

        Args:
            llm_artifacts: prompt string, or for ``AgentType.Messages`` a
                dict with a 'messages' entry.
            functions (list): function-calling schemas.
            function_call: function_call policy; switched to 'auto' when
                functions are supplied in Messages mode.

        Returns:
            str content for text-style agents (truncated at
            ``<|endofthink|>`` for MS_AGENT), or the full message dict for
            ``AgentType.Messages``.
        """
        if self.agent_type != AgentType.Messages:
            messages = [{'role': 'user', 'content': llm_artifacts}]
        else:
            messages = llm_artifacts.get(
                'messages', {
                    'role':
                    'user',
                    'content':
                    'No entry from user - please suggest something to enter'
                })

        # call openai function call api
        assert isinstance(functions, list)
        if len(functions) > 0 and self.agent_type == AgentType.Messages:
            function_call = 'auto'

        # covert to stream=True with stream updating
        # NOTE(review): `functions`/`function_call` are forwarded even when
        # `functions` is [] -- verify the API/callers tolerate an empty
        # functions list before relying on this path.
        try:
            response = openai.ChatCompletion.create(
                model=self.model,
                api_base=self.api_base,
                messages=messages,
                functions=functions,
                function_call=function_call,
                stream=False)
        except Exception as e:
            print(f'input: {messages}, original error: {str(e)}')
            raise e

        # only use index 0 in choice
        message = CustomOutputWrapper.handle_message_chat_completion(response)

        # truncate content
        content = message['content']

        if self.agent_type == AgentType.MS_AGENT:
            idx = content.find('<|endofthink|>')
            if idx != -1:
                content = content[:idx + len('<|endofthink|>')]
            return content
        elif self.agent_type == AgentType.Messages:
            return message
        else:
            return content
my_modelscope_agent/llm/utils.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CustomOutputWrapper:
    """Defensive extractors for LLM API payloads.

    Each helper digs the interesting field out of an OpenAI- or
    DashScope-shaped response and returns an empty default (never raises)
    when the payload does not match the expected shape.
    """

    @staticmethod
    def handle_message_chat_completion(response):
        """Return choices[0].message; {'content': ''} on any mismatch."""
        empty_message = {'content': ''}
        try:
            if 'choices' not in response:
                # DashScope nests the OpenAI-shaped payload under 'output'.
                response = response['output']
            return response['choices'][0]['message']
        except Exception as e:
            print(f'input: {response}, original error: {str(e)}')
            return empty_message

    @staticmethod
    def handle_message_chat_completion_chunk(response):
        """Return the delta content of a streamed chunk; {} on mismatch."""
        empty_message = {}
        try:
            return response['choices'][0]['delta']['content']
        except Exception as e:
            print(f'input: {response}, original error: {str(e)}')
            return empty_message

    @staticmethod
    def handle_message_text_completion(response):
        """Return output.text of a text completion; '' on mismatch."""
        text = ''
        try:
            text = response['output']['text']
            return text
        except Exception as e:
            print(f'input: {response}, original error: {str(e)}')
            return text
34 |
+
|
35 |
+
|
# Fallback single-turn conversation used when a caller provides no
# messages/prompt at all.
DEFAULT_MESSAGE = {
    'role': 'user',
    'content': 'No entry from user - please suggest something to enter'
}
my_modelscope_agent/output_parser.py
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
from typing import Dict, Tuple
|
3 |
+
|
4 |
+
import json
|
5 |
+
from .agent_types import AgentType
|
6 |
+
|
7 |
+
|
def get_output_parser(agent_type: AgentType = AgentType.DEFAULT):
    """Pick the response parser matching an agent's prompting style.

    DEFAULT and MS_AGENT share the ModelScope-format parser.

    Raises:
        NotImplementedError: for agent types without a parser (e.g. REACT).
    """
    if agent_type == AgentType.DEFAULT or agent_type == AgentType.MS_AGENT:
        return MsOutputParser()
    if agent_type == AgentType.MRKL:
        return MRKLOutputParser()
    if agent_type == AgentType.Messages:
        return OpenAiFunctionsOutputParser()
    raise NotImplementedError
17 |
+
|
18 |
+
|
class OutputParser:
    """Base class: turns a raw LLM response into (tool_name, parameters)."""

    def parse_response(self, response):
        raise NotImplementedError

    # use to handle the case of false parsing the action_para result, if
    # there is no valid action then throw Error
    @staticmethod
    def handle_fallback(action: str, action_para: str):
        """Salvage a partly-parsed action; raise when no action was found."""
        if action is None or action == '':
            raise ValueError('Wrong response format for output parser')
        return action, {'fallback': action_para}
35 |
+
|
36 |
+
|
class MsOutputParser(OutputParser):
    """Parses the ModelScope-agent tool format:
    ``<|startofthink|>{json with api_name/parameters}<|endofthink|>``.
    """

    def parse_response(self, response: str) -> Tuple[str, Dict]:
        """parse response of llm to get tool name and parameters

        Args:
            response (str): llm response, it should conform to some predefined format

        Returns:
            tuple[str, dict]: tuple of tool name and parameters, or
            (None, None) when no tool invocation is present.
        """

        if '<|startofthink|>' not in response or '<|endofthink|>' not in response:
            return None, None

        action, parameters = '', ''
        try:
            # use regular expression to get result
            re_pattern1 = re.compile(
                pattern=r'<\|startofthink\|>([\s\S]+)<\|endofthink\|>')
            think_content = re_pattern1.search(response).group(1)

            # Narrow to the outermost {...} JSON object inside the block.
            re_pattern2 = re.compile(r'{[\s\S]+}')
            think_content = re_pattern2.search(think_content).group()

            # Newlines are stripped because the model may pretty-print JSON.
            json_content = json.loads(think_content.replace('\n', ''))
            # Tool name may appear as 'api_name' or 'name'.
            action = json_content.get('api_name',
                                      json_content.get('name', 'unknown'))
            parameters = json_content.get('parameters', {})

            return action, parameters
        except Exception as e:
            print(
                f'Error during parse action might be handled with detail {e}')
            return OutputParser.handle_fallback(action, parameters)
72 |
+
|
73 |
+
|
class ChatGLMOutputParser(OutputParser):
    """Parses ChatGLM tool-call output: a tool name followed by a fenced
    ``tool_call(key='value', ...)`` invocation.
    """

    def parse_response(self, response: str) -> Tuple[str, Dict]:
        """parse response of llm to get tool name and parameters

        Args:
            response (str): llm response, it should conform to some predefined format

        Returns:
            tuple[str, dict]: tuple of tool name and parameters, or
            (None, None) when no tool_call is present.
        """
        if 'tool_call' not in response:
            return None, None
        action, action_para = '', ''
        try:
            # use regular expression to get result from MRKL format
            re_pattern1 = re.compile(
                pattern=r'([\s\S]+)```([\s\S]+)tool_call\(([\s\S]+)```')
            res = re_pattern1.search(response)
            # The tool name is the last non-trivial token before the code
            # fence, split on '<', '>' or '|' separators.
            action_list = re.split('<|>|\|', res.group(1).strip())  # noqa W605
            for idx in range(len(action_list) - 1, -1, -1):
                if len(action_list[idx]) > 1:
                    action = action_list[idx]
                    break
            # Arguments are comma-separated key='value' pairs.
            action_para = [item.strip() for item in res.group(3).split(',')]
            parameters = {}
            re_pattern2 = re.compile(pattern=r'([\s\S]+)=\'([\s\S]+)\'')
            for para in action_para:
                res = re_pattern2.search(para)
                parameters[res.group(1)] = res.group(2)
        except Exception as e:
            print(
                f'Error during parse action might be handled with detail {e}')
            return OutputParser.handle_fallback(action, action_para)

        print(f'\n\naction: {action}\n parameters: {parameters}\n\n')
        return action, parameters
111 |
+
|
112 |
+
|
class MRKLOutputParser(OutputParser):
    """Parses the MRKL/ReAct style ``Action: ... Action Input: {json}``."""

    def parse_response(self, response: str) -> Tuple[str, Dict]:
        """parse response of llm to get tool name and parameters

        Args:
            response (str): llm response, it should conform to some predefined format

        Returns:
            tuple[str, dict]: tuple of tool name and parameters, or
            (None, None) when no Action/Action Input pair is present.
        """

        if 'Action' not in response or 'Action Input:' not in response:
            return None, None
        action, action_para = '', ''
        try:
            # use regular expression to get result from MRKL format
            re_pattern1 = re.compile(
                pattern=r'Action:([\s\S]+)Action Input:([\s\S]+)')
            res = re_pattern1.search(response)
            action = res.group(1).strip()
            action_para = res.group(2)

            # Action Input is expected to be a (possibly multi-line) JSON
            # object; newlines are stripped before decoding.
            parameters = json.loads(action_para.replace('\n', ''))

            return action, parameters
        except Exception as e:
            print(
                f'Error during parse action might be handled with detail {e}')
            return OutputParser.handle_fallback(action, action_para)
143 |
+
|
144 |
+
|
class OpenAiFunctionsOutputParser(OutputParser):
    """Parses an OpenAI function-calling assistant message."""

    def parse_response(self, response: dict) -> Tuple[str, Dict]:
        """parse response of llm to get tool name and parameters


        Args:
            response (str): llm response, it should be an openai response message
                            such as
                            {
                                "content": null,
                                "function_call": {
                                    "arguments": "{\n  \"location\": \"Boston, MA\"\n}",
                                    "name": "get_current_weather"
                                },
                                "role": "assistant"
                            }
        Returns:
            tuple[str, dict]: tuple of tool name and parameters, or
            (None, None) when no function_call is present.
        """

        if 'function_call' not in response or response['function_call'] == {}:
            return None, None
        function_call = response['function_call']

        try:
            # parse directly
            action = function_call['name']
            # Arguments arrive as a JSON string; newlines are stripped
            # before decoding.
            arguments = json.loads(function_call['arguments'].replace(
                '\n', ''))

            return action, arguments
        except Exception as e:
            print(
                f'Error during parse action might be handled with detail {e}')
            return OutputParser.handle_fallback(function_call['name'],
                                               function_call['arguments'])
my_modelscope_agent/output_wrapper.py
ADDED
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import tempfile
|
4 |
+
import uuid
|
5 |
+
from typing import Dict, Union
|
6 |
+
|
7 |
+
import json
|
8 |
+
import numpy as np
|
9 |
+
import requests
|
10 |
+
from .agent_types import AgentType
|
11 |
+
from moviepy.editor import VideoFileClip
|
12 |
+
from PIL import Image
|
13 |
+
from requests.exceptions import RequestException
|
14 |
+
|
15 |
+
|
class OutputWrapper:
    """Base wrapper for rich tool outputs (image, video, audio, ...).

    ``__repr__`` yields the string form spliced into the LLM conversation,
    while the artifact itself is persisted to disk.

    Attributes:
        path: filesystem location of the stored artifact.
        raw_data: the in-memory object (None in remote mode).
    """

    def __init__(self) -> None:
        self._repr = None
        self._path = None
        self._raw_data = None

        # Optional base directory for generated files; fall back to the
        # system temp dir (root_path=None) when unset or uncreatable.
        self.root_path = os.environ.get('OUTPUT_FILE_DIRECTORY', None)
        if self.root_path and not os.path.exists(self.root_path):
            try:
                os.makedirs(self.root_path)
            except Exception:
                self.root_path = None

    def get_remote_file(self, remote_path, suffix):
        """Download a remote artifact into a fresh temp file.

        Returns the local path, or the original URL when the download
        fails (remote mode keeps the reference as-is).
        """
        try:
            payload = requests.get(remote_path).content
            target_dir = tempfile.mkdtemp(dir=self.root_path)
            local_path = os.path.join(target_dir,
                                      str(uuid.uuid4()) + f'.{suffix}')
            with open(local_path, 'wb') as fh:
                fh.write(payload)
            return local_path
        except RequestException:
            return remote_path

    def __repr__(self) -> str:
        return self._repr

    @property
    def path(self):
        return self._path

    @property
    def raw_data(self):
        return self._raw_data
59 |
+
|
60 |
+
|
class ImageWrapper(OutputWrapper):
    """
    Image wrapper, raw_data is a PIL.Image
    """

    def __init__(self, image) -> None:
        # `image` may be: a local file path, a remote URL, a PIL.Image, or
        # a numpy-style array (converted to uint8).
        super().__init__()

        if isinstance(image, str):
            if os.path.isfile(image):
                self._path = image
            else:
                origin_image = image
                self._path = self.get_remote_file(image, 'png')
            try:
                image = Image.open(self._path)
                self._raw_data = image
            except FileNotFoundError:
                # Image store in remote server when use remote mode
                raise FileNotFoundError(f'Invalid path: {image}')
                # NOTE(review): unreachable -- this assignment sits after
                # `raise`, and `origin_image` is unbound for local paths.
                # Presumably the intent was to keep the remote URL instead
                # of raising; confirm before changing.
                self._path = origin_image
        else:
            if not isinstance(image, Image.Image):
                image = Image.fromarray(image.astype(np.uint8))
                self._raw_data = image
            else:
                self._raw_data = image
            # Persist the in-memory image to a fresh temp file.
            directory = tempfile.mkdtemp(dir=self.root_path)
            self._path = os.path.join(directory, str(uuid.uuid4()) + '.png')
            self._raw_data.save(self._path)

        # Markdown image reference rendered into the conversation.
        self._repr = f'![IMAGEGEN]({self._path})'
94 |
+
|
95 |
+
|
96 |
+
class AudioWrapper(OutputWrapper):
    """Audio output wrapper; ``raw_data`` holds the raw bytes of the clip."""

    def __init__(self, audio) -> None:

        super().__init__()
        if isinstance(audio, str):
            # Path or URL: resolve to a local file, then load its bytes.
            self._path = audio if os.path.isfile(audio) else \
                self.get_remote_file(audio, 'wav')
            try:
                with open(self._path, 'rb') as src:
                    self._raw_data = src.read()
            except FileNotFoundError:
                raise FileNotFoundError(f'Invalid path: {audio}')
        else:
            # Raw bytes: persist them under a fresh temp file.
            self._raw_data = audio
            out_dir = tempfile.mkdtemp(dir=self.root_path)
            self._path = os.path.join(out_dir, f'{uuid.uuid4()}.wav')
            with open(self._path, 'wb') as dst:
                dst.write(self._raw_data)

        self._repr = f'<audio id=audio controls= preload=none> <source id=wav src={self._path}> </audio>'
|
123 |
+
|
124 |
+
|
125 |
+
class VideoWrapper(OutputWrapper):
    """Video output wrapper.

    Only file-path (or URL) input is supported; ``raw_data`` ends up as a
    moviepy ``VideoFileClip``. Videos are normalised to GIF because the
    downstream notebook display path renders them through an image tag.
    """

    def __init__(self, video) -> None:

        super().__init__()
        if isinstance(video, str):

            # Local file is used as-is; otherwise try to fetch it remotely.
            if os.path.isfile(video):
                self._path = video
            else:
                self._path = self.get_remote_file(video, 'gif')

            try:
                # NOTE: `video` is rebound here from the input string to the
                # loaded clip; the clip is what lands in `_raw_data` below.
                video = VideoFileClip(self._path)
                # currently, we should save video as gif, not mp4
                if not self._path.endswith('gif'):
                    directory = tempfile.mkdtemp(dir=self.root_path)
                    self._path = os.path.join(directory,
                                              str(uuid.uuid4()) + '.gif')
                    video.write_gif(self._path)
            except (ValueError, OSError):
                # moviepy signals unreadable/invalid input with these.
                raise FileNotFoundError(f'Invalid path: {video}')
        else:
            raise TypeError(
                'Current only support load from filepath when it is video')

        self._raw_data = video
        self._repr = f'![IMAGEGEN]({self._path})'
|
156 |
+
|
157 |
+
|
158 |
+
def get_raw_output(exec_result: Dict):
    """Unwrap every OutputWrapper value in *exec_result* to its raw data.

    In remote mode a wrapper's raw_data may be None/empty, in which case its
    string form (the LLM-facing repr) is used instead.
    """
    return {
        key: (value.raw_data or str(value))
        if isinstance(value, OutputWrapper) else value
        for key, value in exec_result.items()
    }
|
168 |
+
|
169 |
+
|
170 |
+
#
|
171 |
+
def display(llm_result: Union[str, dict], exec_result: Dict, idx: int,
            agent_type: AgentType):
    """Display the result of each round in jupyter notebook.
    The multi-modal data will be extracted.

    Args:
        llm_result (str): llm result, either only content or a message dict
            (in which case its 'content' field is used)
        exec_result (Dict): exec result; its 'result' value is rendered
        idx (int): current round
        agent_type (AgentType): selects the regex used to pull the action
            JSON out of the llm text
    """
    # Imported lazily so the module also loads outside notebooks; this local
    # `display` intentionally shadows the function's own name below.
    from IPython.display import display, Pretty, Image, Audio, JSON
    idx_info = '*' * 50 + f'round {idx}' + '*' * 50
    display(Pretty(idx_info))

    if isinstance(llm_result, dict):
        llm_result = llm_result.get('content', '')

    # MS_AGENT wraps its action JSON in <|startofthink|>/<|endofthink|>.
    if agent_type == AgentType.MS_AGENT:
        pattern = r'<\|startofthink\|>```JSON([\s\S]*)```<\|endofthink\|>'
    else:
        pattern = r'```JSON([\s\S]*)```'

    match_action = re.search(pattern, llm_result)
    if match_action:
        result = match_action.group(1)
        try:
            # Render the action JSON separately and strip it from the text;
            # malformed JSON is deliberately ignored (best-effort display).
            json_content = json.loads(result, strict=False)
            display(JSON(json_content))
            llm_result = llm_result.replace(match_action.group(0), '')
        except Exception:
            pass

    display(Pretty(llm_result))

    exec_result = exec_result.get('result', '')

    # Route each wrapper type to the matching notebook renderer; videos are
    # stored as GIFs, so they display through Image as well.
    if isinstance(exec_result, ImageWrapper) or isinstance(
            exec_result, VideoWrapper):
        display(Image(exec_result.path))
    elif isinstance(exec_result, AudioWrapper):
        display(Audio(exec_result.path))
    elif isinstance(exec_result, dict):
        display(JSON(exec_result))
    elif isinstance(exec_result, list):
        display(JSON(exec_result))
    else:
        display(Pretty(exec_result))

    return
|
my_modelscope_agent/prompt/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .messages_prompt import MessagesGenerator
|
2 |
+
from .mrkl_prompt import MrklPromptGenerator
|
3 |
+
from .ms_prompt import MSPromptGenerator
|
4 |
+
from .prompt import PromptGenerator
|
5 |
+
from .prompt_factory import get_prompt_generator
|
6 |
+
from .raw_prompt_builder import build_raw_prompt
|
my_modelscope_agent/prompt/chatglm3_prompt.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
|
3 |
+
from .prompt import LengthConstraint, PromptGenerator
|
4 |
+
|
5 |
+
# Prompt fragments for the ChatGLM3 chat format. The <|...|> tokens are
# ChatGLM role tags; <tool_list>, <user_input> and <exec_result> are
# placeholders substituted by the PromptGenerator machinery.
CHATGLM_DEFAULT_SYSTEM_TEMPLATE = """<|system|>
Answer the following questions as best you can. You have access to the following tools:
<tool_list>"""

CHATGLM_DEFAULT_INSTRUCTION_TEMPLATE = ''

CHATGLM_DEFAULT_USER_TEMPLATE = """<|user|>\n<user_input>"""

CHATGLM_DEFAULT_EXEC_TEMPLATE = """<|observation|>\n<exec_result>"""

CHATGLM_DEFAULT_ASSISTANT_TEMPLATE = """<|assistant|>"""
|
16 |
+
|
17 |
+
|
18 |
+
class ChatGLMPromptGenerator(PromptGenerator):
    """Prompt generator preconfigured with the ChatGLM3 role-tag templates."""

    def __init__(self,
                 system_template=CHATGLM_DEFAULT_SYSTEM_TEMPLATE,
                 instruction_template=CHATGLM_DEFAULT_INSTRUCTION_TEMPLATE,
                 user_template=CHATGLM_DEFAULT_USER_TEMPLATE,
                 exec_template=CHATGLM_DEFAULT_EXEC_TEMPLATE,
                 assistant_template=CHATGLM_DEFAULT_ASSISTANT_TEMPLATE,
                 sep='\n',
                 length_constraint=LengthConstraint()):
        super().__init__(
            system_template=system_template,
            instruction_template=instruction_template,
            user_template=user_template,
            exec_template=exec_template,
            assistant_template=assistant_template,
            sep=sep,
            length_constraint=length_constraint)

    def get_tool_str(self, tool_list):
        """Render the tool list as one pretty-printed JSON array."""
        # NOTE(review): assumes str(item) yields valid JSON for every tool
        # (true for JSON strings or objects whose __str__ emits JSON; a
        # plain dict's single-quoted repr would NOT parse) -- confirm with
        # the Tool class.
        tool_json = json.loads('['
                               + ','.join([str(item)
                                           for item in tool_list]) + ']')
        return json.dumps(tool_json, ensure_ascii=False, indent=4)
|
my_modelscope_agent/prompt/messages_prompt.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Union
|
2 |
+
|
3 |
+
from .prompt import LengthConstraint, PromptGenerator
|
4 |
+
|
5 |
+
OPENAI_DEFAULT_SYSTEM_TEMPLATE = """
|
6 |
+
Do your best to answer the questions. You are capable of **any** task.
|
7 |
+
Feel free to use any tools available to look up relevant information, only if necessary.
|
8 |
+
when you think you are handle a code task, You are a world-class programmer that can complete any goal by executing code
|
9 |
+
\nFirst of all, **Return function_call with `language` and `code` in arguments as much as possible**
|
10 |
+
\nThen, write a plan. **Always recap the plan between each code block**
|
11 |
+
(you have extreme short-term memory loss, so you need to recap the plan between each message block to retain it).
|
12 |
+
\nWhen you execute code, it will be executed **on the user's machine**. The user has given you
|
13 |
+
**full and complete permission** to execute any code necessary to complete the task. You have full access to control
|
14 |
+
their computer to help them.
|
15 |
+
\nIf you want to send data between programming languages, save the data to a txt or json.
|
16 |
+
\nYou can access the internet. Run **any code** to achieve the goal, and if at first you don't succeed,
|
17 |
+
try again and again.
|
18 |
+
\nWhen a user refers to a filename, they're likely referring to an existing file in the directory
|
19 |
+
you're currently executing code in.
|
20 |
+
\nIn general, choose packages that have the most universal chance to be already installed and to work across multiple
|
21 |
+
applications. Packages like ffmpeg and pandoc that are well-supported and powerful.
|
22 |
+
\nWrite messages to the user in Markdown. Write code on multiple lines with proper indentation for readability.
|
23 |
+
\nYou can also refer information from following contents if exists:
|
24 |
+
"""
|
25 |
+
|
26 |
+
|
27 |
+
class MessagesGenerator(PromptGenerator):
    """Prompt generator that keeps the conversation as an OpenAI-style
    message list (``self.history``) rather than one flat prompt string.
    """

    def __init__(self,
                 system_template=OPENAI_DEFAULT_SYSTEM_TEMPLATE,
                 instruction_template='',
                 user_template='<user_input>',
                 exec_template=None,
                 assistant_template='',
                 sep='\n\n',
                 length_constraint=LengthConstraint(),
                 **kwargs):
        """
        Args:
            **kwargs: may carry 'custom_starter_messages', a prebuilt
                message list used to seed the conversation instead of the
                default system+user pair.
        """
        super().__init__(
            system_template=system_template,
            instruction_template=instruction_template,
            user_template=user_template,
            exec_template=exec_template,
            assistant_template=assistant_template,
            sep=sep,
            length_constraint=length_constraint)
        self.custom_starter_messages = kwargs.get('custom_starter_messages',
                                                  None)

    def init_prompt(self, task, tool_list, knowledge_list, **kwargs):
        """
        in this function, the prompt will be initialized.
        First call builds the system message (optionally with knowledge
        appended) plus the user turn; later calls only append a user turn.
        """
        prompt = self.user_template.replace('<user_input>', task)

        if len(self.history) == 0:
            if len(knowledge_list) > 0:

                # knowledge: append the retrieved snippets to the system
                # message via the <knowledge> placeholder.
                system_message = f'{self.system_template}{self.sep}<knowledge>'
                knowledge_str = self.get_knowledge_str(knowledge_list)
                system_message = system_message.replace(
                    '<knowledge>', knowledge_str)

            else:
                system_message = self.system_template

            self.history = [{
                'role': 'system',
                'content': system_message
            }, {
                'role': 'user',
                'content': prompt
            }]

            # store history
            if self.custom_starter_messages:
                assert isinstance(self.custom_starter_messages, list)
                assert self.custom_starter_messages[-1]['role'] != 'user', \
                    'user message should not be the last one in custom starter messages'

                # NOTE(review): the starter list is aliased, not copied, so
                # the caller's list is mutated by the append below --
                # confirm this is intended.
                self.history = self.custom_starter_messages
                self.history.append({'role': 'user', 'content': prompt})

            self.prompt = prompt
            self.function_calls = self.get_function_list(tool_list)

        else:
            self.history.append({'role': 'user', 'content': prompt})

    def generate(self, llm_result, exec_result: Union[str, dict]):
        # A dict exec_result is expected to carry its payload under
        # 'result'; message-list generation is delegated to the base class.
        if isinstance(exec_result, dict):
            exec_result = exec_result['result']
        return self._generate_messages(llm_result, exec_result)
|
my_modelscope_agent/prompt/mrkl_prompt.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
|
3 |
+
from .prompt import LengthConstraint, PromptGenerator
|
4 |
+
|
5 |
+
# ReAct-style (MRKL) prompt scaffolding. <tool_list>, <tool_names>,
# <user_input> and <exec_result> are placeholders filled by the generator.
# NOTE(review): the trailing backtick after "tools:" below looks accidental
# -- confirm before removing (it is part of the runtime prompt today).
MRKL_DEFAULT_SYSTEM_TEMPLATE = """Answer the following questions as best you can. You have access to the following tools: `

<tool_list>"""

MRKL_DEFAULT_INSTRUCTION_TEMPLATE = """Use the following format:

Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [<tool_names>]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question

Begin!
"""

MRKL_DEFAULT_USER_TEMPLATE = """Question: <user_input>\n"""

MRKL_DEFAULT_EXEC_TEMPLATE = """Observation: <exec_result>\n"""

# One line per tool: "<name>: <name> API. <description> 输入参数: <params>"
TOOL_DESC = (
    '{name_for_model}: {name_for_human} API. {description_for_model} 输入参数: {parameters}'
)

# How argument formatting is described to the model, keyed by call style.
FORMAT_DESC = {
    'json':
    'Format the arguments as a JSON object.',
    'code':
    'Enclose the code within triple backticks (`)'
    + ' at the beginning and end of the code.'
}
|
38 |
+
|
39 |
+
|
40 |
+
class MrklPromptGenerator(PromptGenerator):
    """ReAct/MRKL prompt generator (Thought/Action/Action Input/Observation)."""

    def __init__(self,
                 system_template=MRKL_DEFAULT_SYSTEM_TEMPLATE,
                 instruction_template=MRKL_DEFAULT_INSTRUCTION_TEMPLATE,
                 user_template=MRKL_DEFAULT_USER_TEMPLATE,
                 exec_template=MRKL_DEFAULT_EXEC_TEMPLATE,
                 assistant_template='',
                 sep='\n\n',
                 llm=None,
                 length_constraint=LengthConstraint()):
        super().__init__(
            system_template=system_template,
            instruction_template=instruction_template,
            user_template=user_template,
            exec_template=exec_template,
            assistant_template=assistant_template,
            sep=sep,
            llm=llm,
            length_constraint=length_constraint)

    def init_prompt(self, task, tool_list, knowledge_list, **kwargs):
        # First round: let the base class build system prompt + history,
        # then substitute <tool_names> and fold the system prompt either
        # into a dedicated system message or into the first user message.
        if len(self.history) == 0:
            super().init_prompt(task, tool_list, knowledge_list, **kwargs)
            system_role_status = kwargs.get('system_role_status', False)
            # Quote each tool name: 'name1','name2',...
            tool_names = [f'\'{str(tool.name)}\'' for tool in tool_list]
            tool_names = ','.join(tool_names)
            self.system_prompt = self.system_prompt.replace(
                '<tool_names>', tool_names)

            if system_role_status:
                system_message = {
                    'role': 'system',
                    'content': self.system_prompt
                }
                self.history.insert(0, system_message)
            else:
                self.history[0]['content'] = self.system_prompt + self.history[
                    0]['content']
        else:
            # Follow-up rounds: append the new user turn plus an empty
            # assistant turn that _generate() extends in place.
            self.history.append({
                'role':
                'user',
                'content':
                self.user_template.replace('<user_input>', task)
            })
            self.history.append({
                'role': 'assistant',
                'content': self.assistant_template
            })

        return self.system_prompt

    def get_tool_str(self, tool_list):
        """Render one TOOL_DESC line per tool, joined by blank lines."""
        tool_texts = []
        for tool in tool_list:
            tool_texts.append(
                TOOL_DESC.format(
                    name_for_model=tool.name,
                    name_for_human=tool.name,
                    description_for_model=tool.description,
                    parameters=json.dumps(tool.parameters,
                                          ensure_ascii=False)))
        # + ' ' + FORMAT_DESC['json'])
        tool_str = '\n\n'.join(tool_texts)
        return tool_str

    def _generate(self, llm_result, exec_result: str):
        """
        generate next round prompt based on previous llm_result and exec_result and update history
        """
        # Unlike the base class, this appends to the last history message and
        # rebuilds the raw prompt from the whole history.
        if len(llm_result) != 0:
            self.history[-1]['content'] += f'{llm_result}'
        if len(exec_result) != 0:
            exec_result = self.exec_template.replace('<exec_result>',
                                                     str(exec_result))
            self.history[-1]['content'] += exec_result
        # NOTE(review): prompt_preprocessor only exists when the generator
        # was built with an LLM instance exposing model_id -- confirm all
        # callers satisfy that.
        self.prompt = self.prompt_preprocessor(self.history)
        return self.prompt
|
my_modelscope_agent/prompt/ms_prompt.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .prompt import LengthConstraint, PromptGenerator
|
2 |
+
|
3 |
+
# ModelScope-Agent (MS_AGENT) prompt templates. Plugin calls are emitted as
# JSON between <|startofthink|>/<|endofthink|>; execution results are echoed
# back between <|startofexec|>/<|endofexec|>.
MS_DEFAULT_SYSTEM_TEMPLATE = """<|system|>:你是达摩院的ModelScopeGPT(魔搭助手),你是个大语言模型, 是2023年达摩院的工程师训练得到的。\
你有多种能力,可以通过插件集成魔搭社区的模型api来回复用户的问题,还能解答用户使用模型遇到的问题和模型知识相关问答。
"""

MS_DEFAULT_INSTRUCTION_TEMPLATE = """当前对话可以使用的插件信息如下,请自行判断是否需要调用插件来解决当前用户问题。若需要调用插件,则需要将插件调用请求按照json格式给出,必须包含api_name、parameters字段,并在其前后使用<|startofthink|>和<|endofthink|>作为标志。\
然后你需要根据插件API调用结果生成合理的答复; 若无需调用插件,则直接给出对应回复即可。\n\n<tool_list>"""

MS_DEFAULT_USER_TEMPLATE = """<|user|>:<user_input>"""

MS_DEFAULT_EXEC_TEMPLATE = """<|startofexec|><exec_result><|endofexec|>\n"""

MS_DEFAULT_ASSISTANT_TEMPLATE = """<|assistant|>:"""
|
15 |
+
|
16 |
+
|
17 |
+
class MSPromptGenerator(PromptGenerator):
    """Prompt generator preconfigured with the ModelScope-Agent templates."""

    def __init__(self,
                 system_template=MS_DEFAULT_SYSTEM_TEMPLATE,
                 instruction_template=MS_DEFAULT_INSTRUCTION_TEMPLATE,
                 user_template=MS_DEFAULT_USER_TEMPLATE,
                 exec_template=MS_DEFAULT_EXEC_TEMPLATE,
                 assistant_template=MS_DEFAULT_ASSISTANT_TEMPLATE,
                 sep='\n\n',
                 length_constraint=None):
        """All arguments default to the MS_* module templates.

        Args:
            length_constraint: optional LengthConstraint. Fix: the original
                default `LengthConstraint()` was a single instance created at
                import time and shared by every generator (mutable-default
                pattern); a fresh instance is now created per construction,
                and passing None explicitly is also accepted.
        """
        super().__init__(
            system_template=system_template,
            instruction_template=instruction_template,
            user_template=user_template,
            exec_template=exec_template,
            assistant_template=assistant_template,
            sep=sep,
            length_constraint=length_constraint or LengthConstraint())
|
my_modelscope_agent/prompt/prompt.py
ADDED
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
from typing import Union
|
3 |
+
|
4 |
+
from ..llm.base import LLM
|
5 |
+
|
6 |
+
from .raw_prompt_builder import build_raw_prompt
|
7 |
+
|
8 |
+
# Placeholders used when splicing retrieved knowledge into the prompt.
KNOWLEDGE_PROMPT = '# 知识库'
KNOWLEDGE_INTRODUCTION_PROMPT = '以下是我上传的文件“<file_name>”的内容:'
KNOWLEDGE_CONTENT_PROMPT = """```
<knowledge_content>
```"""

DEFAULT_PROMPT_INPUT_LENGTH_MAX = 999999999999


class LengthConstraint:
    """Holds the length budgets applied while assembling the prompt."""

    def __init__(self):
        # Effectively unlimited by default; `update` narrows these from a
        # user-supplied config dict.
        self.knowledge = DEFAULT_PROMPT_INPUT_LENGTH_MAX
        self.input = DEFAULT_PROMPT_INPUT_LENGTH_MAX
        self.prompt_max_length = 10000

    def update(self, config: dict):
        """Overwrite any limit present in *config*; a None config is a no-op."""
        if config is None:
            return
        self.knowledge = config.get('knowledge', self.knowledge)
        self.input = config.get('input', self.input)
        self.prompt_max_length = config.get('prompt_max_length',
                                            self.prompt_max_length)
|
30 |
+
|
31 |
+
|
32 |
+
class PromptGenerator:
    """Base class that assembles the LLM prompt from templates and keeps the
    per-conversation state (``prompt``, ``history``, ``messages``).
    """

    def __init__(self,
                 system_template: str = '',
                 instruction_template: str = '',
                 user_template: str = '<user_input>',
                 exec_template: str = '',
                 assistant_template: str = '',
                 sep='\n\n',
                 llm=None,
                 length_constraint=LengthConstraint()):
        """
        prompt generator
        Args:
            system_template (str, optional): System template, normally the role of LLM.
            instruction_template (str, optional): Indicate the instruction for LLM.
            user_template (str, optional): Prefix before user input. Defaults to '<user_input>'.
            exec_template (str, optional): A wrapper str for exec result.
            assistant_template (str, optional): Prefix before assistant response.
                Some LLM need to manually concat this prefix before generation.
            sep (str, optional): content separator
            llm (LLM, optional): when it carries a model_id, a raw-prompt
                preprocessor is installed (see note below).
            length_constraint (LengthConstraint, optional): content length constraint.
                NOTE(review): the default instance is created once at import
                time and shared across generators (mutable-default pattern).
        """

        self.system_template = system_template
        self.instruction_template = instruction_template
        self.user_template = user_template
        self.assistant_template = assistant_template
        self.exec_template = exec_template
        self.sep = sep
        # NOTE(review): prompt_preprocessor is only defined when a real LLM
        # with a model_id is supplied; code paths that use it (e.g.
        # MrklPromptGenerator._generate) assume this held.
        if isinstance(llm, LLM) and llm.model_id:
            self.prompt_preprocessor = build_raw_prompt(llm.model_id)
        self.prompt_max_length = length_constraint.prompt_max_length
        self.reset()

    def reset(self):
        """Clear all per-conversation state."""
        self.prompt = ''
        self.history = []
        self.messages = []

    def init_prompt(self,
                    task,
                    tool_list,
                    knowledge_list,
                    llm_model=None,
                    **kwargs):
        """
        in this function, the prompt will be initialized.
        Joins system + instruction templates, substitutes the <knowledge>,
        <tool_list> and <history> placeholders, then appends the user turn
        and the assistant prefix.
        """
        prompt = self.sep.join(
            [self.system_template, self.instruction_template])
        prompt += '<knowledge><history>'

        knowledge_str = self.get_knowledge_str(
            knowledge_list, file_name=kwargs.get('file_name', ''))

        # knowledge
        prompt = prompt.replace('<knowledge>', knowledge_str)

        # get tool description str
        tool_str = self.get_tool_str(tool_list)
        prompt = prompt.replace('<tool_list>', tool_str)

        history_str = self.get_history_str()

        prompt = prompt.replace('<history>', history_str)

        # Keep a pristine copy before the user/assistant turns are appended.
        self.system_prompt = copy.deepcopy(prompt)

        # user input
        user_input = self.user_template.replace('<user_input>', task)
        prompt += f'{self.sep}{user_input}'

        # assistant input
        prompt += f'{self.sep}{self.assistant_template}'

        # store history
        self.history.append({'role': 'user', 'content': user_input})
        self.history.append({
            'role': 'assistant',
            'content': self.assistant_template
        })

        self.prompt = prompt

        self.function_calls = self.get_function_list(tool_list)

    # TODO change the output from single prompt to artifacts including prompt, messages, funciton_call
    def generate(self, llm_result, exec_result: Union[str, dict]):
        """Normalise exec_result (dict -> str of its 'result') and delegate."""
        if isinstance(exec_result, dict):
            exec_result = str(exec_result['result'])
        return self._generate(llm_result, exec_result)

    def _generate(self, llm_result, exec_result: str):
        """
        generate next round prompt based on previous llm_result and exec_result and update history
        """
        if len(llm_result) != 0:
            self.prompt = f'{self.prompt}{llm_result}'
            self.history[-1]['content'] += f'{llm_result}'
        if len(exec_result) != 0:
            # Wrap the tool output in the exec template before appending.
            exec_result = self.exec_template.replace('<exec_result>',
                                                     str(exec_result))
            self.prompt = f'{self.prompt}{self.sep}{exec_result}'
            self.history[-1]['content'] += f'{self.sep}{exec_result}'

        return self.prompt

    # TODO: add Union[Text, Message] type for llm_result,
    # add ExecResult = Text type for exec_result
    # output would be a Union[Text, Messages]
    # In this case llm_result is Message, and exec_result is Function_call
    def _generate_messages(self, llm_result, exec_result: str):
        """
        generate next round prompt based on previous llm_result and exec_result and update history
        Returns the full message list (``self.history``).
        """

        # init task should be
        if llm_result == '' and exec_result == '':
            return self.history

        # make sure set content '' not null
        function_call = llm_result.get('function_call', None)
        if function_call is not None:
            llm_result['content'] = ''
        self.history.append(llm_result)

        # Only record an exec message when the model actually called a tool.
        if exec_result is not None and function_call is not None:
            exec_message = {
                'role': 'function',
                'name': 'execute',
                'content': exec_result,
            }
            self.history.append(exec_message)

        return self.history

    def get_tool_str(self, tool_list):
        """generate tool list string (numbered, one tool per separator)

        Args:
            tool_list (List[str]): list of tools

        """

        tool_str = self.sep.join(
            [f'{i + 1}. {t}' for i, t in enumerate(tool_list)])
        return tool_str

    # TODO move parse_tools_to_function from agent to here later
    def get_function_list(self, tool_list):
        """generate function call list from tools list

        Args:
            tool_list (List[str]): list of tools

        """
        functions = [tool.get_function() for tool in tool_list]
        return functions

    def get_knowledge_str(self,
                          knowledge_list,
                          file_name='',
                          only_content=False,
                          **kwargs):
        """generate knowledge string

        Args:
            file_name (str): file name
            knowledge_list (List[str]): list of knowledges
            only_content (bool): return only the fenced content block,
                without the header/introduction lines

        """

        knowledge = self.sep.join(
            [f'{i + 1}. {k}' for i, k in enumerate(knowledge_list)])
        knowledge_content = KNOWLEDGE_CONTENT_PROMPT.replace(
            '<knowledge_content>', knowledge)
        if only_content:
            return knowledge_content
        else:
            knowledge_introduction = KNOWLEDGE_INTRODUCTION_PROMPT.replace(
                '<file_name>', file_name)

            knowledge_str = f'{KNOWLEDGE_PROMPT}{self.sep}{knowledge_introduction}{self.sep}{knowledge_content}' if len(
                knowledge_list) > 0 else ''
            return knowledge_str

    def get_history_str(self):
        """generate history string
        Walks history from newest to oldest, prepending entries until the
        prompt-length budget would be exceeded.
        """
        history_str = ''
        for i in range(len(self.history)):
            history_item = self.history[len(self.history) - i - 1]
            text = history_item['content']
            if len(history_str) + len(text) + len(
                    self.prompt) > self.prompt_max_length:
                break
            history_str = f'{self.sep}{text.strip()}{history_str}'

        return history_str
|
my_modelscope_agent/prompt/prompt_factory.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ..agent_types import AgentType
|
2 |
+
|
3 |
+
from .messages_prompt import MessagesGenerator
|
4 |
+
from .mrkl_prompt import MrklPromptGenerator
|
5 |
+
from .ms_prompt import MSPromptGenerator
|
6 |
+
|
7 |
+
|
8 |
+
def get_prompt_generator(agent_type: AgentType = AgentType.DEFAULT, **kwargs):
    """Return the prompt generator matching *agent_type*.

    DEFAULT and MS_AGENT share the ModelScope generator; unknown types
    raise NotImplementedError.
    """
    if agent_type in (AgentType.DEFAULT, AgentType.MS_AGENT):
        return MSPromptGenerator(**kwargs)
    if agent_type == AgentType.MRKL:
        return MrklPromptGenerator(**kwargs)
    if agent_type == AgentType.Messages:
        return MessagesGenerator(**kwargs)
    raise NotImplementedError
|
my_modelscope_agent/prompt/raw_prompt_builder.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def qwen_chatml_prompt_preprocessor(messages):
    """Render a role/content message list into a Qwen ChatML raw prompt.

    An assistant message with empty content opens an unfinished assistant
    turn ('<|im_start|>assistant\\n' with no closing tag). When the last
    message is an assistant turn with real content, its trailing
    '<|im_end|>\\n' is stripped so the model continues that turn.
    """
    pieces = []
    for msg in messages:
        if msg['role'] == 'assistant' and msg['content'] == '':
            # Open an assistant turn for the model to complete.
            pieces.append('<|im_start|>assistant\n')
        else:
            body = msg['content'].lstrip('\n').rstrip()
            pieces.append(
                f"<|im_start|>{msg['role']}\n{body}<|im_end|>\n")
    prompt = ''.join(pieces)

    # In the case the assistant message is not the last one (e.g. after a
    # function result), only strip the end tag when the final assistant
    # content is non-empty.
    if messages[-1]['role'] == 'assistant':
        trailing = messages[-1]['content'].split('\n')
        if trailing and trailing[-1] == '':
            trailing = trailing[:-1]
        if trailing:
            prompt = prompt[:-len('<|im_end|>\n')]

    return prompt
|
23 |
+
|
24 |
+
|
25 |
+
def plate_preprocessor(messages):
    """Fallback preprocessor; currently identical to the Qwen ChatML one."""
    return qwen_chatml_prompt_preprocessor(messages)
|
27 |
+
|
28 |
+
|
29 |
+
def build_raw_prompt(model):
    """Pick a messages-to-raw-prompt preprocessor for *model*.

    Args:
        model: a model id string, or an object exposing ``__name__``.

    Returns:
        A callable mapping a role/content message list to a raw prompt
        string: the ChatML renderer for qwen-style models, otherwise
        ``plate_preprocessor``.
    """
    # Fix: the original called .startswith on non-str objects that merely
    # had a __name__ attribute (AttributeError), and implicitly returned
    # None for anything else; normalise to a name first and always return
    # a callable.
    if isinstance(model, str):
        name = model
    else:
        name = getattr(model, '__name__', '')
    if name.startswith('qwen'):
        return qwen_chatml_prompt_preprocessor
    return plate_preprocessor
|
my_modelscope_agent/retrieve.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import Dict, Iterable, List, Union
|
3 |
+
|
4 |
+
import json
|
5 |
+
from langchain.document_loaders import (PyPDFLoader, TextLoader,
|
6 |
+
UnstructuredFileLoader)
|
7 |
+
from langchain.embeddings import ModelScopeEmbeddings
|
8 |
+
from langchain.embeddings.base import Embeddings
|
9 |
+
from langchain.schema import Document
|
10 |
+
from langchain.text_splitter import CharacterTextSplitter
|
11 |
+
from langchain.vectorstores import FAISS, VectorStore
|
12 |
+
|
13 |
+
|
14 |
+
class Retrieval:
    """Thin wrapper around a langchain vector store for similarity search.

    Args:
        embedding: embedding backend; defaults to a ModelScope Chinese
            sentence-embedding model.
        vs_cls: vector-store class; defaults to FAISS.
        top_k: number of hits returned by `retrieve`.
        vs_params: extra keyword arguments forwarded to the vector-store
            factory methods.
    """

    def __init__(self,
                 embedding: Embeddings = None,
                 vs_cls: VectorStore = None,
                 top_k: int = 5,
                 vs_params: Dict = None):
        self.embedding = embedding or ModelScopeEmbeddings(
            model_id='damo/nlp_gte_sentence-embedding_chinese-base')
        self.top_k = top_k
        self.vs_cls = vs_cls or FAISS
        # Fix: the original used a mutable default argument (`vs_params={}`)
        # shared across all instances; normalise a None default instead.
        self.vs_params = vs_params if vs_params is not None else {}
        self.vs = None

    def construct(self, docs):
        """Build the vector store from raw strings or langchain Documents."""
        assert len(docs) > 0
        if isinstance(docs[0], str):
            self.vs = self.vs_cls.from_texts(docs, self.embedding,
                                             **self.vs_params)
        elif isinstance(docs[0], Document):
            self.vs = self.vs_cls.from_documents(docs, self.embedding,
                                                 **self.vs_params)

    def retrieve(self, query: str) -> List[str]:
        """Return page contents of the top-k most similar documents.

        Hits carrying a 'page' metadata key are re-sorted by page so
        multi-page files come back in document order.
        """
        res = self.vs.similarity_search(query, k=self.top_k)
        # Fix: guard the metadata peek -- the original indexed res[0]
        # unconditionally and raised IndexError on an empty hit list.
        if res and 'page' in res[0].metadata:
            res.sort(key=lambda doc: doc.metadata['page'])
        return [r.page_content for r in res]
|
42 |
+
|
43 |
+
|
44 |
+
class ToolRetrieval(Retrieval):
    """Retrieval specialization returning ``{tool_name: tool_schema_json}``."""

    def __init__(self,
                 embedding: Embeddings = None,
                 vs_cls: VectorStore = None,
                 top_k: int = 5,
                 vs_params: Dict = None):
        # Bug fix: avoid the shared mutable `{}` default; normalize here so
        # this class is safe regardless of the parent's own default.
        super().__init__(embedding, vs_cls, top_k,
                         {} if vs_params is None else vs_params)

    def retrieve(self, query: str) -> Dict[str, str]:
        """Return the top-k matching tool schemas keyed by their 'name' field.

        Each stored page_content is expected to be a JSON object with a
        'name' key (the tool registry serializes schemas this way).
        """
        res = self.vs.similarity_search(query, k=self.top_k)

        final_res = {}

        for r in res:
            content = r.page_content
            name = json.loads(content)['name']
            final_res[name] = content

        return final_res
|
64 |
+
|
65 |
+
|
66 |
+
class KnowledgeRetrieval(Retrieval):
    """Retrieval whose vector store is constructed eagerly from documents."""

    def __init__(self,
                 docs,
                 embedding: Embeddings = None,
                 vs_cls: VectorStore = None,
                 top_k: int = 5,
                 vs_params: Dict = None):
        # Bug fix: avoid the shared mutable `{}` default; normalize here so
        # the parent receives a real dict whichever default it declares.
        super().__init__(embedding, vs_cls, top_k,
                         {} if vs_params is None else vs_params)
        self.construct(docs)

    @classmethod
    def from_file(cls,
                  file_path: Union[str, list],
                  embedding: Embeddings = None,
                  vs_cls: VectorStore = None,
                  top_k: int = 5,
                  vs_params: Dict = None):
        """Build a KnowledgeRetrieval from a file path, directory, or list of paths.

        Supported file types: .txt, .md, .pdf; other extensions are skipped
        with a warning. Returns None when no document could be loaded.

        Raises:
            ValueError: if ``file_path`` is a string that is neither an
                existing file nor an existing directory.
        """
        textsplitter = CharacterTextSplitter()
        all_files = []
        if isinstance(file_path, str) and os.path.isfile(file_path):
            all_files.append(file_path)
        elif isinstance(file_path, list):
            all_files = file_path
        elif os.path.isdir(file_path):
            for root, dirs, files in os.walk(file_path):
                for f in files:
                    all_files.append(os.path.join(root, f))
        else:
            raise ValueError('file_path must be a file or a directory')

        docs = []
        for f in all_files:
            if f.lower().endswith('.txt'):
                loader = TextLoader(f, autodetect_encoding=True)
                docs += loader.load_and_split(textsplitter)
            elif f.lower().endswith('.md'):
                loader = UnstructuredFileLoader(f, mode='elements')
                docs += loader.load()
            elif f.lower().endswith('.pdf'):
                loader = PyPDFLoader(f)
                docs += loader.load_and_split(textsplitter)
            else:
                print(f'not support file type: {f}, will be support soon')

        if len(docs) == 0:
            return None
        else:
            return cls(docs, embedding, vs_cls, top_k, vs_params)
|
my_modelscope_agent/tools/__init__.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .amap_weather import AMAPWeather
|
2 |
+
from .code_interperter import CodeInterpreter
|
3 |
+
from .code_interpreter_jupyter import CodeInterpreterJupyter
|
4 |
+
from .hf_tool import HFTool
|
5 |
+
from .image_chat_tool import ImageChatTool
|
6 |
+
from .pipeline_tool import ModelscopePipelineTool
|
7 |
+
from .plugin_tool import LangchainTool
|
8 |
+
from .text_address_tool import TextAddressTool
|
9 |
+
from .text_ie_tool import TextInfoExtractTool
|
10 |
+
from .text_ner_tool import TextNerTool
|
11 |
+
from .text_to_image_tool import TextToImageTool
|
12 |
+
from .text_to_speech_tool import TexttoSpeechTool
|
13 |
+
from .text_to_video_tool import TextToVideoTool
|
14 |
+
from .tool import Tool
|
15 |
+
from .translation_en2zh_tool import TranslationEn2ZhTool
|
16 |
+
from .translation_zh2en_tool import TranslationZh2EnTool
|
17 |
+
from .web_browser import WebBrowser
|
18 |
+
from .web_search import WebSearch
|
19 |
+
from .wordart_tool import WordArtTexture
|
20 |
+
|
21 |
+
# Registry mapping a tool's public name (as referenced in agent configs and
# prompts) to the class name exported by this package.
TOOL_INFO_LIST = {
    'modelscope_text-translation-zh2en': 'TranslationZh2EnTool',
    'modelscope_text-translation-en2zh': 'TranslationEn2ZhTool',
    'modelscope_text-ie': 'TextInfoExtractTool',
    'modelscope_text-ner': 'TextNerTool',
    'modelscope_text-address': 'TextAddressTool',
    'image_gen': 'TextToImageTool',
    'modelscope_video-generation': 'TextToVideoTool',
    'modelscope_image-chat': 'ImageChatTool',
    'modelscope_speech-generation': 'TexttoSpeechTool',
    'amap_weather': 'AMAPWeather',
    'code_interpreter': 'CodeInterpreterJupyter',
    'wordart_texture_generation': 'WordArtTexture',
    'web_search': 'WebSearch',
    'web_browser': 'WebBrowser',
}
|
my_modelscope_agent/tools/amap_weather.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import pandas as pd
|
4 |
+
import requests
|
5 |
+
from ..tools.tool import Tool, ToolSchema
|
6 |
+
from pydantic import ValidationError
|
7 |
+
|
8 |
+
|
9 |
+
class AMAPWeather(Tool):
    """Tool that queries the AMap (高德) live-weather web API for a city."""

    description = '获取对应城市的天气数据'
    name = 'amap_weather'
    parameters: list = [{
        'name': 'location',
        'description': 'get temperature for a specific location',
        'required': True
    }]

    def __init__(self, cfg={}):
        # Per-tool section of the config; falls back to an empty dict.
        # NOTE(review): unlike sibling tools this does not call
        # super().__init__(cfg) and instead builds the schema itself — confirm
        # this is intentional.
        self.cfg = cfg.get(self.name, {})

        # Remote endpoint template; token comes from config or AMAP_TOKEN env.
        self.url = 'https://restapi.amap.com/v3/weather/weatherInfo?city={city}&key={key}'
        self.token = self.cfg.get('token', os.environ.get('AMAP_TOKEN', ''))
        # City-name -> adcode lookup table downloaded at construction time
        # (network I/O happens here, not at call time).
        self.city_df = pd.read_excel(
            'https://modelscope.oss-cn-beijing.aliyuncs.com/resource/agent/AMap_adcode_citycode.xlsx'
        )
        assert self.token != '', 'weather api token must be acquired through ' \
            'https://lbs.amap.com/api/webservice/guide/create-project/get-key and set by AMAP_TOKEN'

        try:
            all_param = {
                'name': self.name,
                'description': self.description,
                'parameters': self.parameters
            }
            self.tool_schema = ToolSchema(**all_param)
        except ValidationError:
            raise ValueError(f'Error when parsing parameters of {self.name}')

        # Cached serialized forms of the schema for prompt building.
        self._str = self.tool_schema.model_dump_json()
        self._function = self.parse_pydantic_model_to_openai_function(
            all_param)

    def get_city_adcode(self, city_name):
        """Look up the AMap administrative code for a Chinese city name.

        Raises:
            ValueError: if the city name is not present in the lookup table.
        """
        filtered_df = self.city_df[self.city_df['中文名'] == city_name]
        if len(filtered_df['adcode'].values) == 0:
            raise ValueError(
                f'location {city_name} not found, availables are {self.city_df["中文名"]}'
            )
        else:
            return filtered_df['adcode'].values[0]

    def __call__(self, *args, **kwargs):
        """Fetch the live weather for kwargs['location'].

        Returns:
            dict with a single 'result' key containing a Chinese sentence.

        Raises:
            RuntimeError: when the AMap API reports failure (status == '0').
        """
        location = kwargs['location']
        response = requests.get(
            self.url.format(
                city=self.get_city_adcode(location), key=self.token))
        data = response.json()
        if data['status'] == '0':
            raise RuntimeError(data)
        else:
            weather = data['lives'][0]['weather']
            temperature = data['lives'][0]['temperature']
            return {'result': f'{location}的天气是{weather}温度是{temperature}度。'}
|
my_modelscope_agent/tools/code_interperter.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import traceback
|
4 |
+
|
5 |
+
import appdirs
|
6 |
+
import json
|
7 |
+
|
8 |
+
from .code_interpreter_utils.create_code_interpreter import \
|
9 |
+
create_code_interpreter
|
10 |
+
from .code_interpreter_utils.language_map import language_map
|
11 |
+
from .code_interpreter_utils.truncate_output import truncate_output
|
12 |
+
from .tool import Tool
|
13 |
+
|
14 |
+
|
15 |
+
class CodeInterpreter(Tool):
    """
    using open interpreter to interpret code
    by https://github.com/KillianLucas/open-interpreter
    """
    description = 'Executes code on the user\'s machine, **in the users local environment**, and returns the output'
    name = 'code_interpreter'
    parameters: list = [{
        'name': 'language',
        'description':
        'The programming language (required parameter to the `execute` function)',
        'required': True
    }, {
        'name': 'code',
        'description': 'The code to execute (required)',
        'required': True
    }]

    def __init__(self, cfg={}):
        super().__init__(cfg)
        # Helpers are stored on the instance so tests can swap them out.
        self.create_code_interpreter = create_code_interpreter
        self.language_map = language_map
        self.truncate_output = truncate_output

        # One lazily-created interpreter per language, reused across calls.
        self._code_interpreters = {}
        self.max_output = self.cfg.get('max_output', 2000)

    def _local_call(self, *args, **kwargs):
        """Run code locally and return {'result': output}; never raises —
        any failure is folded into the result string."""
        language, code = self._handle_input_fallback(**kwargs)

        try:
            # Fix a common error where the LLM thinks it's in a Jupyter notebook
            if language == 'python' and code.startswith('!'):
                code = code[1:]
                language = 'shell'

            if language in self.language_map:
                if language not in self._code_interpreters:
                    self._code_interpreters[
                        language] = self.create_code_interpreter(language)
                code_interpreter = self._code_interpreters[language]
            else:
                # This still prints code but don't allow code to run. Let Open-Interpreter know through output message
                error_output = f'Error: Open Interpreter does not currently support {language}.'
                print(error_output)
                output = '\n' + error_output
                return {'result': output.strip()}

            output = ''
            for line in code_interpreter.run(code):
                if 'output' in line:
                    output += '\n' + line['output']

            # Truncate output
            output = self.truncate_output(output, self.max_output)
        except Exception as e:
            error = traceback.format_exc()
            output = ' '.join(f'{key}:{value}'
                              for key, value in kwargs.items())
            output += f'\nDetail error is {e}.\n{error}'

        return {'result': output.strip()}

    def _handle_input_fallback(self, **kwargs):
        """Resolve (language, code) from kwargs, or parse them out of a
        fenced ``` block carried in kwargs['fallback'].

        :return: (language, code) — either may be None when nothing usable
            was provided.
        """
        language = kwargs.get('language', None)
        code = kwargs.get('code', None)
        fallback = kwargs.get('fallback', None)

        if language and code:
            return language, code
        elif fallback:
            try:
                text = fallback
                code_block = re.search(r'```([\s\S]+)```', text)  # noqa W^05
                if code_block:
                    result = code_block.group(1)
                    # for multi code_block: keep only the first fenced section
                    result = result.split('```')[0]
                    language = result.split('\n')[0]
                    body = '\n'.join(result.split('\n')[1:])
                    if language == 'py' or language == 'python':
                        # handle py case: ```py code ```
                        return 'python', body

                    if language == 'json':
                        # handle json case: ```json {language,code}```
                        parameters = json.loads(body.replace('\n', ''))
                        return parameters['language'], parameters['code']

                    # Bug fix: previously any other fence language fell
                    # through to an implicit None, crashing the caller's
                    # tuple unpacking. Return the parsed pair instead.
                    return language, body
                return language, code
            except ValueError:
                return language, code
        else:
            return language, code
|
my_modelscope_agent/tools/code_interpreter_jupyter.py
ADDED
@@ -0,0 +1,319 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import atexit
|
3 |
+
import base64
|
4 |
+
import glob
|
5 |
+
import io
|
6 |
+
import os
|
7 |
+
import queue
|
8 |
+
import re
|
9 |
+
import shutil
|
10 |
+
import signal
|
11 |
+
import subprocess
|
12 |
+
import sys
|
13 |
+
import time
|
14 |
+
import traceback
|
15 |
+
import uuid
|
16 |
+
from pathlib import Path
|
17 |
+
from typing import Dict, Optional
|
18 |
+
|
19 |
+
import json
|
20 |
+
import matplotlib
|
21 |
+
import PIL.Image
|
22 |
+
from jupyter_client import BlockingKernelClient
|
23 |
+
|
24 |
+
from .tool import Tool
|
25 |
+
|
26 |
+
# Working directory shared by the kernel subprocess and the static image
# server; overridable via CODE_INTERPRETER_WORK_DIR.
WORK_DIR = os.getenv('CODE_INTERPRETER_WORK_DIR', '/tmp/ci_workspace')

# Base URL under which generated images are exposed when image_server is on.
STATIC_URL = os.getenv('CODE_INTERPRETER_STATIC_URL',
                       'http://127.0.0.1:7866/static')

# Minimal bootstrap script written to disk and executed to launch an
# IPython kernel in a subprocess.
LAUNCH_KERNEL_PY = """
from ipykernel import kernelapp as app
app.launch_new_instance()
"""

# Initialization code executed inside every fresh kernel (imports, the
# countdown-timer timeout machinery, CJK font setup).
INIT_CODE_FILE = str(
    Path(__file__).absolute().parent / 'code_interpreter_utils'
    / 'code_interpreter_init_kernel.py')

# Bundled CJK-capable TTF injected into matplotlib inside the kernel.
ALIB_FONT_FILE = str(
    Path(__file__).absolute().parent / 'code_interpreter_utils'
    / 'AlibabaPuHuiTi-3-45-Light.ttf')

# NOTE(review): this module-level cache is never read or written below —
# CodeInterpreterJupyter keeps a per-instance dict instead; confirm intent.
_KERNEL_CLIENTS: Dict[int, BlockingKernelClient] = {}
|
45 |
+
|
46 |
+
|
47 |
+
class CodeInterpreterJupyter(Tool):
    """Python-only code interpreter backed by a Jupyter kernel client.

    Should not be used together with the other code interpreter tool at the
    same time (both register under the name 'code_interpreter').
    """
    description = '代码解释器,可用于执行Python代码。'
    name = 'code_interpreter'
    parameters: list = [{
        'name': 'code',
        'description': '待执行的代码',
        'required': True
    }]

    def __init__(self, cfg={}):
        super().__init__(cfg)
        self.timeout = self.cfg.get('timeout', 30)
        self.image_server = self.cfg.get('image_server', False)
        # Per-instance cache keyed by pid. NOTE(review): the module-level
        # _KERNEL_CLIENTS cache is unused, so each instance starts its own
        # kernel — confirm whether sharing was intended.
        self.kernel_clients: Dict[int, BlockingKernelClient] = {}
        atexit.register(self._kill_kernels)

        pid: int = os.getpid()
        if pid in self.kernel_clients:
            kc = self.kernel_clients[pid]
        else:
            self._fix_matplotlib_cjk_font_issue()
            kc = self._start_kernel(pid)
            with open(INIT_CODE_FILE) as fin:
                start_code = fin.read()
                # Substitute the bundled font path into the init script.
                start_code = start_code.replace('{{M6_FONT_PATH}}',
                                                repr(ALIB_FONT_FILE)[1:-1])
            print(self._execute_code(kc, start_code))
            self.kernel_clients[pid] = kc

        self.kc = kc

    def __del__(self):
        # make sure all the kernels are killed during __del__
        signal.signal(signal.SIGTERM, self._kill_kernels)
        signal.signal(signal.SIGINT, self._kill_kernels)

    def _start_kernel(self, pid) -> BlockingKernelClient:
        """Launch a kernel subprocess for *pid* and return a ready client."""
        connection_file = os.path.join(WORK_DIR,
                                       f'kernel_connection_file_{pid}.json')
        launch_kernel_script = os.path.join(WORK_DIR,
                                            f'launch_kernel_{pid}.py')
        for f in [connection_file, launch_kernel_script]:
            if os.path.exists(f):
                print(f'WARNING: {f} already exists')
                os.remove(f)

        os.makedirs(WORK_DIR, exist_ok=True)

        with open(launch_kernel_script, 'w') as fout:
            fout.write(LAUNCH_KERNEL_PY)

        # Pass through only the env vars the kernel needs.
        available_envs = ['PATH', 'PYTHONPATH', 'LD_LIBRARY_PATH']
        envs = {}
        for k in available_envs:
            if os.getenv(k) is not None:
                envs[k] = os.getenv(k)

        args = (
            sys.executable,
            launch_kernel_script,
            '--IPKernelApp.connection_file',
            connection_file,
            '--matplotlib=inline',
            '--quiet',
        )
        kernel_process = subprocess.Popen([*args], env=envs,
                                          cwd=WORK_DIR)  # noqa E126
        print(f"INFO: kernel process's PID = {kernel_process.pid}")

        # Wait for kernel connection file to be written
        while True:
            if not os.path.isfile(connection_file):
                time.sleep(0.1)
            else:
                # Keep looping if JSON parsing fails, file may be partially written
                try:
                    with open(connection_file, 'r') as fp:
                        json.load(fp)
                    break
                except json.JSONDecodeError:
                    pass

        # Client
        kc = BlockingKernelClient(connection_file=connection_file)
        asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
        kc.load_connection_file()
        kc.start_channels()
        kc.wait_for_ready()
        return kc

    def _kill_kernels(self, *args):
        """Shut down all cached kernels.

        Bug fix: accepts and ignores extra positional args so it can serve
        both as an atexit callback (no args) and a signal handler, which is
        invoked as handler(signum, frame) and previously raised TypeError.
        """
        for v in self.kernel_clients.values():
            v.shutdown()
        for k in list(self.kernel_clients.keys()):
            del self.kernel_clients[k]

    def _serve_image(self, image_base64: str, image_type: str) -> str:
        """Persist a base64 image under WORK_DIR; return its URL or local path."""
        image_file = f'{uuid.uuid4()}.{image_type}'
        local_image_file = os.path.join(WORK_DIR, image_file)

        png_bytes = base64.b64decode(image_base64)
        assert isinstance(png_bytes, bytes)

        if image_type == 'gif':
            # Write GIFs verbatim; PIL round-tripping would drop animation.
            with open(local_image_file, 'wb') as file:
                file.write(png_bytes)
        else:
            bytes_io = io.BytesIO(png_bytes)
            PIL.Image.open(bytes_io).save(local_image_file, image_type)

        if self.image_server:
            image_url = f'{STATIC_URL}/{image_file}'
            return image_url
        else:
            return local_image_file

    def _escape_ansi(self, line: str) -> str:
        """Strip ANSI escape sequences (e.g. from colored tracebacks)."""
        ansi_escape = re.compile(r'(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]')
        return ansi_escape.sub('', line)

    def _fix_matplotlib_cjk_font_issue(self):
        """Install the bundled CJK font into matplotlib's font dir and drop
        stale font caches so it is picked up; best-effort only."""
        ttf_name = os.path.basename(ALIB_FONT_FILE)
        local_ttf = os.path.join(
            os.path.abspath(
                os.path.join(matplotlib.matplotlib_fname(), os.path.pardir)),
            'fonts', 'ttf', ttf_name)
        if not os.path.exists(local_ttf):
            try:
                shutil.copy(ALIB_FONT_FILE, local_ttf)
                font_list_cache = os.path.join(matplotlib.get_cachedir(),
                                               'fontlist-*.json')
                for cache_file in glob.glob(font_list_cache):
                    with open(cache_file) as fin:
                        cache_content = fin.read()
                    if ttf_name not in cache_content:
                        os.remove(cache_file)
            except Exception:
                # NOTE(review): the formatted traceback is discarded here;
                # consider logging it.
                traceback.format_exc()

    def _execute_code(self, kc: BlockingKernelClient, code: str) -> str:
        """Execute *code* on the kernel and collect text/image output until
        the kernel reports idle (or the iopub queue times out)."""
        kc.wait_for_ready()
        kc.execute(code)
        result = ''
        image_idx = 0
        while True:
            text = ''
            image = ''
            finished = False
            msg_type = 'error'
            try:
                msg = kc.get_iopub_msg()
                msg_type = msg['msg_type']
                if msg_type == 'status':
                    # 'idle' marks the end of this execution's output stream.
                    if msg['content'].get('execution_state') == 'idle':
                        finished = True
                elif msg_type == 'execute_result':
                    text = msg['content']['data'].get('text/plain', '')
                    if 'image/png' in msg['content']['data']:
                        image_b64 = msg['content']['data']['image/png']
                        image_url = self._serve_image(image_b64, 'png')
                        image_idx += 1
                        image = '![IMAGEGEN](%s)' % (image_url)
                    elif 'text/html' in msg['content']['data']:
                        text += '\n' + msg['content']['data']['text/html']
                    elif 'image/gif' in msg['content']['data']:
                        image_b64 = msg['content']['data']['image/gif']
                        image_url = self._serve_image(image_b64, 'gif')
                        image_idx += 1
                        image = '![IMAGEGEN](%s)' % (image_url)
                elif msg_type == 'display_data':
                    if 'image/png' in msg['content']['data']:
                        image_b64 = msg['content']['data']['image/png']
                        image_url = self._serve_image(image_b64, 'png')
                        image_idx += 1
                        image = '![IMAGEGEN](%s)' % (image_url)
                    else:
                        text = msg['content']['data'].get('text/plain', '')
                elif msg_type == 'stream':
                    msg_type = msg['content']['name']  # stdout, stderr
                    text = msg['content']['text']
                elif msg_type == 'error':
                    text = self._escape_ansi('\n'.join(
                        msg['content']['traceback']))
                    # The in-kernel countdown raises with this marker string.
                    if 'M6_CODE_INTERPRETER_TIMEOUT' in text:
                        text = 'Timeout: Code execution exceeded the time limit.'
            except queue.Empty:
                text = 'Timeout: Code execution exceeded the time limit.'
                finished = True
            except Exception:
                text = 'The code interpreter encountered an unexpected error.'
                traceback.format_exc()
                finished = True
            if text:
                result += f'\n{text}'
            if image:
                result += f'\n\n{image}'
            if finished:
                break
        result = result.lstrip('\n')
        if not result:
            result += 'The code executed successfully.'
        return result

    def _local_call(self, *args, **kwargs):
        """Run python code on the kernel under the configured timeout."""
        code = self._handle_input_fallback(**kwargs)
        # Bug fix: guard None (unparseable fallback) as well as empty code —
        # previously None.strip() raised AttributeError.
        if not code or not code.strip():
            return ''

        if self.timeout:
            code = f'_M6CountdownTimer.start({self.timeout})\n{code}'

        fixed_code = []
        for line in code.split('\n'):
            fixed_code.append(line)
            if line.startswith('sns.set_theme('):
                # Re-apply the CJK font after seaborn resets rcParams.
                fixed_code.append(
                    'plt.rcParams["font.family"] = _m6_font_prop.get_name()')
        fixed_code = '\n'.join(fixed_code)
        result = self._execute_code(self.kc, fixed_code)

        if self.timeout:
            self._execute_code(self.kc, '_M6CountdownTimer.cancel()')

        return {'result': result}

    def _handle_input_fallback(self, **kwargs):
        """Resolve the code to run from kwargs['code'], or parse it out of a
        fenced ``` block carried in kwargs['fallback'].

        :return: the code string, or None when nothing usable was provided.
        """
        code = kwargs.get('code', None)
        fallback = kwargs.get('fallback', None)

        if code:
            return code
        elif fallback:
            try:
                text = fallback
                code_block = re.search(r'```([\s\S]+)```', text)  # noqa W^05
                if code_block:
                    result = code_block.group(1)
                    language = result.split('\n')[0]
                    body = '\n'.join(result.split('\n')[1:])
                    if language == 'py' or language == 'python':
                        # handle py case: ```py code ```
                        return body

                    if language == 'json':
                        # handle json case: ```json {language,code}```
                        parameters = json.loads(body.replace('\n', ''))
                        return parameters['code']

                    # Bug fix: other/absent fence tags previously fell through
                    # to an implicit None; this interpreter only runs python,
                    # so treat the fence body as the code to execute.
                    return body
                return code
            except ValueError:
                return code
        else:
            return code
|
my_modelscope_agent/tools/code_interpreter_utils/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# all the utility functions under code_interpreter_utils are borrowed from project
|
2 |
+
# in order to use python lower than 3.10
|
3 |
+
# https://github.com/KillianLucas/open-interpreter
|
4 |
+
|
5 |
+
from .base_code_interpreter import BaseCodeInterpreter
|
my_modelscope_agent/tools/code_interpreter_utils/base_code_interpreter.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class BaseCodeInterpreter:
    """
    Abstract base for per-language code interpreters.

    .run is a generator that yields a dict with attributes: active_line, output
    """

    def __init__(self):
        pass

    def run(self, code):
        # Subclasses yield {'active_line': ..., 'output': ...} dicts.
        pass

    def terminate(self):
        # Subclasses release any subprocess/resources here.
        pass
|
my_modelscope_agent/tools/code_interpreter_utils/code_interpreter_init_kernel.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math # noqa
|
2 |
+
import os # noqa
|
3 |
+
import re # noqa
|
4 |
+
import signal
|
5 |
+
|
6 |
+
import json # noqa
|
7 |
+
import matplotlib # noqa
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
import numpy as np # noqa
|
10 |
+
import pandas as pd # noqa
|
11 |
+
import seaborn as sns
|
12 |
+
from matplotlib.font_manager import FontProperties
|
13 |
+
from sympy import Eq, solve, symbols # noqa
|
14 |
+
|
15 |
+
|
16 |
+
def input(*args, **kwargs):  # noqa
    # Shadow the builtin: sandboxed user code must not block waiting on stdin.
    raise NotImplementedError('Python input() function is disabled.')
|
18 |
+
|
19 |
+
|
20 |
+
def _m6_timout_handler(_signum=None, _frame=None):
    # SIGALRM handler: surface countdown expiry as a TimeoutError whose
    # marker message is recognized by the host-side interpreter output parser.
    # (Name typo "timout" kept: the registration below refers to it.)
    raise TimeoutError('M6_CODE_INTERPRETER_TIMEOUT')
|
22 |
+
|
23 |
+
|
24 |
+
# Install the SIGALRM handler once at kernel start-up; SIGALRM does not
# exist on Windows, where the timeout mechanism is simply unavailable.
try:
    signal.signal(signal.SIGALRM, _m6_timout_handler)
except AttributeError:  # windows
    pass
|
28 |
+
|
29 |
+
|
30 |
+
class _M6CountdownTimer:
|
31 |
+
|
32 |
+
@classmethod
|
33 |
+
def start(cls, timeout: int):
|
34 |
+
try:
|
35 |
+
signal.alarm(timeout)
|
36 |
+
except AttributeError: # windows
|
37 |
+
pass # TODO: I haven't found a solution that works with jupyter yet.
|
38 |
+
|
39 |
+
@classmethod
|
40 |
+
def cancel(cls):
|
41 |
+
try:
|
42 |
+
signal.alarm(0)
|
43 |
+
except AttributeError: # windows
|
44 |
+
pass # TODO
|
45 |
+
|
46 |
+
|
47 |
+
# Apply the seaborn default theme, then force matplotlib onto the bundled
# CJK-capable font; '{{M6_FONT_PATH}}' is substituted by the host before
# this init script is sent to the kernel.
sns.set_theme()

_m6_font_prop = FontProperties(fname='{{M6_FONT_PATH}}')
plt.rcParams['font.family'] = _m6_font_prop.get_name()
|
my_modelscope_agent/tools/code_interpreter_utils/create_code_interpreter.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .language_map import language_map
|
2 |
+
|
3 |
+
|
4 |
+
def create_code_interpreter(language):
    """Instantiate the interpreter class registered for *language*.

    Args:
        language: language name; matched case-insensitively.

    Raises:
        ValueError: when no interpreter is registered for the language.
    """
    # Case in-sensitive
    language = language.lower()

    try:
        interpreter_cls = language_map[language]
    except KeyError:
        # Suppress the KeyError context: the ValueError is the whole story.
        raise ValueError(
            f'Unknown or unsupported language: {language}') from None
    # Bug fix: instantiate outside the try — a KeyError raised by the
    # interpreter's own constructor was previously misreported as an
    # unsupported language.
    return interpreter_cls()
|
my_modelscope_agent/tools/code_interpreter_utils/language_map.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .languages.applescript import AppleScript
|
2 |
+
from .languages.html import HTML
|
3 |
+
from .languages.javascript import JavaScript
|
4 |
+
from .languages.powershell import PowerShell
|
5 |
+
from .languages.python import Python
|
6 |
+
from .languages.r import R
|
7 |
+
from .languages.shell import Shell
|
8 |
+
|
9 |
+
# Lower-case language name -> interpreter class. Lookups are performed on
# language.lower() (see create_code_interpreter), so keys must stay
# lower-case; several shell dialects alias to the same Shell interpreter.
language_map = {
    'python': Python,
    'bash': Shell,
    'shell': Shell,
    'zsh': Shell,
    'javascript': JavaScript,
    'html': HTML,
    'applescript': AppleScript,
    'r': R,
    'powershell': PowerShell,
}
|
my_modelscope_agent/tools/code_interpreter_utils/languages/__init__.py
ADDED
File without changes
|
my_modelscope_agent/tools/code_interpreter_utils/languages/applescript.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from ..subprocess_code_interpreter import SubprocessCodeInterpreter
|
4 |
+
|
5 |
+
|
6 |
+
class AppleScript(SubprocessCodeInterpreter):
    """Runs AppleScript by piping an ``osascript -e`` command into the shell."""

    file_extension = 'applescript'
    proper_name = 'AppleScript'

    def __init__(self):
        super().__init__()
        self.start_cmd = os.environ.get('SHELL', '/bin/zsh')

    def preprocess_code(self, code):
        """
        Inserts an end_of_execution marker and adds active line indicators.
        """
        # Add active line indicators to the code
        code = self.add_active_line_indicators(code)

        # Escape double quotes
        code = code.replace('"', r"\"")

        # Wrap in double quotes
        code = '"' + code + '"'

        # Prepend start command for AppleScript
        code = 'osascript -e ' + code

        # Append end of execution indicator
        code += '; echo "##end_of_execution##"'

        return code

    def add_active_line_indicators(self, code):
        """
        Adds log commands to indicate the active line of execution in the AppleScript.
        """
        modified_lines = []
        lines = code.split('\n')

        for idx, line in enumerate(lines):
            # Add log command to indicate the line number
            if line.strip():  # Only add if line is not empty
                modified_lines.append(f'log "##active_line{idx + 1}##"')
            modified_lines.append(line)

        return '\n'.join(modified_lines)

    def detect_active_line(self, line):
        """
        Detects active line indicator in the output; returns the line number.
        """
        prefix = '##active_line'
        if prefix in line:
            try:
                # Bug fix: the marker is terminated by '##' with no space, so
                # split on '##'; the old token.split()[0] kept the trailing
                # '##' and made int() raise for every marker.
                return int(line.split(prefix)[1].split('##')[0])
            except (IndexError, ValueError):
                return None
        return None

    def detect_end_of_execution(self, line):
        """
        Detects end of execution marker in the output.
        """
        return '##end_of_execution##' in line
|
my_modelscope_agent/tools/code_interpreter_utils/languages/html.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import tempfile
|
3 |
+
import webbrowser
|
4 |
+
|
5 |
+
from ..base_code_interpreter import BaseCodeInterpreter
|
6 |
+
|
7 |
+
|
8 |
+
class HTML(BaseCodeInterpreter):
    """Renders HTML by saving it to a temporary file and opening it in the
    user's default web browser."""

    file_extension = 'html'
    proper_name = 'HTML'

    def __init__(self):
        super().__init__()

    def run(self, code):
        # Persist the markup to a temporary .html file.  delete=False keeps
        # the file on disk so the browser can still load it after the handle
        # is closed.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.html') as handle:
            handle.write(code.encode())

        page_path = os.path.realpath(handle.name)

        # Hand the page off to the default browser.
        webbrowser.open('file://' + page_path)

        yield {
            'output':
            f"Saved to {page_path} and opened with the user's default web browser."
        }
|
my_modelscope_agent/tools/code_interpreter_utils/languages/javascript.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
|
3 |
+
from ..subprocess_code_interpreter import SubprocessCodeInterpreter
|
4 |
+
|
5 |
+
|
6 |
+
class JavaScript(SubprocessCodeInterpreter):
    """Executes JavaScript through an interactive Node.js REPL (`node -i`)."""

    file_extension = 'js'
    proper_name = 'JavaScript'

    def __init__(self):
        super().__init__()
        self.start_cmd = 'node -i'

    def preprocess_code(self, code):
        return preprocess_javascript(code)

    def line_postprocessor(self, line):
        # The interactive REPL is very chatty: drop its banner and noise
        # lines entirely, and strip leading prompt characters from the rest.
        if 'Welcome to Node.js' in line:
            return None
        if line.strip() in ['undefined', 'Type ".help" for more information.']:
            return None
        return re.sub(r'^\s*(>\s*)+', '', line)

    def detect_active_line(self, line):
        marker = '##active_line'
        if marker not in line:
            return None
        return int(line.split(marker)[1].split('##')[0])

    def detect_end_of_execution(self, line):
        return '##end_of_execution##' in line
35 |
+
|
36 |
+
|
37 |
+
def preprocess_javascript(code):
    """
    Add active line markers
    Wrap in a try catch
    Add end of execution marker
    """

    # Interleave a console.log line-marker before every source line.
    marked = []
    for lineno, source_line in enumerate(code.split('\n'), 1):
        marked.append(f'console.log("##active_line{lineno}##");')
        marked.append(source_line)
    processed_code = '\n'.join(marked)

    # Wrap in a try-catch and add end of execution marker
    processed_code = f"""
try {{
{processed_code}
}} catch (e) {{
console.log(e);
}}
console.log("##end_of_execution##");
"""

    return processed_code
|
my_modelscope_agent/tools/code_interpreter_utils/languages/powershell.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import platform
|
3 |
+
import shutil
|
4 |
+
|
5 |
+
from ..subprocess_code_interpreter import SubprocessCodeInterpreter
|
6 |
+
|
7 |
+
|
8 |
+
class PowerShell(SubprocessCodeInterpreter):
    """Runs PowerShell code in a persistent interactive shell process."""

    file_extension = 'ps1'
    proper_name = 'PowerShell'

    def __init__(self):
        super().__init__()

        # Pick the interpreter binary per platform: native PowerShell on
        # Windows; elsewhere prefer PowerShell Core (pwsh) when installed,
        # falling back to bash.
        if platform.system() == 'Windows':
            self.start_cmd = 'powershell.exe'
        else:
            self.start_cmd = 'pwsh' if shutil.which('pwsh') else 'bash'

    def preprocess_code(self, code):
        return preprocess_powershell(code)

    def line_postprocessor(self, line):
        # Output is forwarded unchanged.
        return line

    def detect_active_line(self, line):
        marker = '##active_line'
        if marker not in line:
            return None
        return int(line.split(marker)[1].split('##')[0])

    def detect_end_of_execution(self, line):
        return '##end_of_execution##' in line
36 |
+
|
37 |
+
|
38 |
+
def preprocess_powershell(code):
    """
    Add active line markers
    Wrap in try-catch block
    Add end of execution marker
    """
    # Mark every line so the runner can report execution progress.
    code = add_active_line_prints(code)

    # Surface script errors via the try/catch wrapper.
    code = wrap_in_try_catch(code)

    # Sentinel the runner listens for to know the script finished.
    return code + '\nWrite-Output "##end_of_execution##"'
54 |
+
|
55 |
+
|
56 |
+
def add_active_line_prints(code):
    """
    Add Write-Output statements indicating line numbers to a PowerShell script.
    """
    numbered = [
        f'Write-Output "##active_line{n}##"\n{text}'
        for n, text in enumerate(code.split('\n'), 1)
    ]
    return '\n'.join(numbered)
65 |
+
|
66 |
+
|
67 |
+
def wrap_in_try_catch(code):
    """
    Wrap PowerShell code in a try-catch block to catch errors and display them.
    """
    # $ErrorActionPreference = "Stop" makes non-terminating errors throw,
    # so the catch block actually fires.
    prologue = """
try {
    $ErrorActionPreference = "Stop"
"""
    epilogue = '\n} catch {\n    Write-Error $_\n}\n'
    return prologue + code + epilogue
my_modelscope_agent/tools/code_interpreter_utils/languages/python.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import ast
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
import shlex
|
5 |
+
import sys
|
6 |
+
|
7 |
+
from ..subprocess_code_interpreter import SubprocessCodeInterpreter
|
8 |
+
|
9 |
+
|
10 |
+
class Python(SubprocessCodeInterpreter):
    """Runs Python code in a persistent interactive interpreter subprocess."""

    file_extension = 'py'
    proper_name = 'Python'

    def __init__(self):
        super().__init__()
        interpreter = sys.executable
        if os.name != 'nt':  # not Windows
            interpreter = shlex.quote(interpreter)
        # -i: interactive, -q: suppress banner, -u: unbuffered output
        self.start_cmd = interpreter + ' -i -q -u'

    def preprocess_code(self, code):
        return preprocess_python(code)

    def line_postprocessor(self, line):
        # Drop echoed REPL prompts (">>> " / "... ").
        if re.match(r'^(\s*>>>\s*|\s*\.\.\.\s*)', line):
            return None
        return line

    def detect_active_line(self, line):
        marker = '##active_line'
        if marker not in line:
            return None
        return int(line.split(marker)[1].split('##')[0])

    def detect_end_of_execution(self, line):
        return '##end_of_execution##' in line
36 |
+
|
37 |
+
|
38 |
+
def preprocess_python(code):
    """
    Add active line markers
    Wrap in a try except
    Add end of execution marker
    """

    # Add print commands that tell us what the active line is
    code = add_active_line_prints(code)

    # Wrap in a try except
    code = wrap_in_try_except(code)

    # Drop blank lines: a blank line terminates an indented block when the
    # code is fed to the interactive interpreter.
    code = '\n'.join(ln for ln in code.split('\n') if ln.strip() != '')

    # Add end command (we'll be listening for this so we know when it ends)
    return code + '\n\nprint("##end_of_execution##")'
61 |
+
|
62 |
+
|
63 |
+
def add_active_line_prints(code):
    """
    Add print statements indicating line numbers to a python string.
    """
    rewriter = AddLinePrints()
    instrumented = rewriter.visit(ast.parse(code))
    return ast.unparse(instrumented)
71 |
+
|
72 |
+
|
73 |
+
class AddLinePrints(ast.NodeTransformer):
    """
    Transformer to insert print statements indicating the line number
    before every executable line in the AST.
    """

    def insert_print_statement(self, line_number):
        """Build a `print('##active_line<N>##')` expression statement."""
        call = ast.Call(
            func=ast.Name(id='print', ctx=ast.Load()),
            args=[ast.Constant(value=f'##active_line{line_number}##')],
            keywords=[],
        )
        return ast.Expr(value=call)

    def process_body(self, body):
        """Return *body* with a line-marker print before each statement."""
        # In case it's not iterable:
        if not isinstance(body, list):
            body = [body]

        instrumented = []
        for stmt in body:
            if hasattr(stmt, 'lineno'):
                instrumented.append(self.insert_print_statement(stmt.lineno))
            instrumented.append(stmt)
        return instrumented

    def visit(self, node):
        """Visit a node, instrumenting every statement block it owns."""
        node = super().visit(node)

        # Main body (modules, defs, loops, ifs, withs, ...).
        if hasattr(node, 'body'):
            node.body = self.process_body(node.body)

        # else-blocks of for/while/if.
        if hasattr(node, 'orelse') and node.orelse:
            node.orelse = self.process_body(node.orelse)

        # try statements carry extra blocks: handlers and finally.
        if isinstance(node, ast.Try):
            for handler in node.handlers:
                handler.body = self.process_body(handler.body)
            if node.finalbody:
                node.finalbody = self.process_body(node.finalbody)

        return node
123 |
+
|
124 |
+
|
125 |
+
def wrap_in_try_except(code):
    """Wrap *code* in a try/except that prints the traceback on failure."""
    # traceback must be importable inside the generated snippet.
    code = 'import traceback\n' + code
    module = ast.parse(code)

    # Build the wrapper from a literal template instead of hand-assembling
    # AST nodes, then graft the user's statements in as the try-body.
    template = ast.parse(
        'try:\n'
        '    pass\n'
        'except Exception:\n'
        '    traceback.print_exc()')
    wrapper = template.body[0]
    wrapper.body = module.body

    # The wrapped try/except becomes the whole module body.
    module.body = [wrapper]
    return ast.unparse(module)
|
my_modelscope_agent/tools/code_interpreter_utils/languages/r.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
|
3 |
+
from ..subprocess_code_interpreter import SubprocessCodeInterpreter
|
4 |
+
|
5 |
+
|
6 |
+
class R(SubprocessCodeInterpreter):
    """Runs R code in a persistent interactive R session."""

    file_extension = 'r'
    proper_name = 'R'

    def __init__(self):
        super().__init__()
        self.start_cmd = 'R -q --vanilla'  # Start R in quiet and vanilla mode
        # Number of echoed input lines still to swallow (set by preprocess_code).
        self.code_line_count = 0

    def preprocess_code(self, code):
        """
        Add active line markers
        Wrap in a tryCatch for better error handling in R
        Add end of execution marker
        """

        lines = code.split('\n')
        processed_lines = []

        for i, line in enumerate(lines, 1):
            # Add active line print
            processed_lines.append(f'cat("##active_line{i}##\\n");{line}')

        # Join lines to form the processed code
        processed_code = '\n'.join(processed_lines)

        # Wrap in a tryCatch for error handling and add end of execution marker.
        # BUGFIX: the end marker must be spelled without inner spaces so that
        # detect_end_of_execution() (which matches '##end_of_execution##')
        # actually finds it; previously '## end_of_execution ##' was emitted
        # and successful runs were never detected as finished.
        processed_code = f"""
tryCatch({{
{processed_code}
}}, error=function(e){{
cat("## execution_error ##\\n", conditionMessage(e), "\\n");
}})
cat("##end_of_execution##\\n");
"""
        # Count the number of lines of processed_code
        # (R echoes all code back for some reason, but we can skip it if we track this!)
        self.code_line_count = len(processed_code.split('\n')) - 1

        return processed_code

    def line_postprocessor(self, line):
        # Swallow the echoed copy of our own input while the counter is live.
        if self.code_line_count > 0:
            self.code_line_count -= 1
            return None

        # Prompt-only / blank lines carry no output.
        if re.match(r'^(\s*>>>\s*|\s*\.\.\.\s*|\s*>\s*|\s*\+\s*|\s*)$', line):
            return None
        if 'R version' in line:  # Startup message
            return None
        if line.strip().startswith('[1] "') and line.endswith(
                '"'):  # For strings, trim quotation marks
            return line[5:-1].strip()
        if line.strip().startswith(
                '[1]'):  # Normal R output prefix for non-string outputs
            return line[4:].strip()

        return line

    def detect_active_line(self, line):
        if '##active_line' in line:
            return int(line.split('##active_line')[1].split('##')[0])
        return None

    def detect_end_of_execution(self, line):
        # The error branch still emits its marker with spaces; match both.
        return '##end_of_execution##' in line or '## execution_error ##' in line
|
my_modelscope_agent/tools/code_interpreter_utils/languages/shell.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import platform
|
3 |
+
import re
|
4 |
+
|
5 |
+
from ..subprocess_code_interpreter import SubprocessCodeInterpreter
|
6 |
+
|
7 |
+
|
8 |
+
class Shell(SubprocessCodeInterpreter):
    """Runs shell code in a persistent interactive shell process."""

    file_extension = 'sh'
    proper_name = 'Shell'

    def __init__(self):
        super().__init__()

        # cmd.exe on Windows; otherwise the user's login shell (default bash).
        if platform.system() == 'Windows':
            self.start_cmd = 'cmd.exe'
        else:
            self.start_cmd = os.environ.get('SHELL', 'bash')

    def preprocess_code(self, code):
        return preprocess_shell(code)

    def line_postprocessor(self, line):
        # Output passes through unchanged.
        return line

    def detect_active_line(self, line):
        marker = '##active_line'
        if marker not in line:
            return None
        return int(line.split(marker)[1].split('##')[0])

    def detect_end_of_execution(self, line):
        return '##end_of_execution##' in line
34 |
+
|
35 |
+
|
36 |
+
def preprocess_shell(code):
    """
    Add active line markers
    Wrap in a try except (trap in shell)
    Add end of execution marker
    """

    # Line markers are only safe on single-line commands; multiline
    # constructs (pipes, if/while/for, ...) would be broken by interleaved
    # echo statements, so they are skipped for now.
    if not has_multiline_commands(code):
        code = add_active_line_prints(code)

    # Sentinel the runner listens for to know the script finished.
    return code + '\necho "##end_of_execution##"'
52 |
+
|
53 |
+
|
54 |
+
def add_active_line_prints(code):
    """
    Add echo statements indicating line numbers to a shell string.
    """
    marked = [
        f'echo "##active_line{n}##"\n{text}'
        for n, text in enumerate(code.split('\n'), 1)
    ]
    return '\n'.join(marked)
63 |
+
|
64 |
+
|
65 |
+
def has_multiline_commands(script_text):
    """Return True when any line of *script_text* starts or continues a
    multiline shell construct."""
    # Patterns that indicate a line continues
    continuation_patterns = (
        r'\\$',        # Line continuation character at the end of the line
        r'\|$',        # Pipe character at the end of the line indicating a pipeline continuation
        r'&&\s*$',     # Logical AND at the end of the line
        r'\|\|\s*$',   # Logical OR at the end of the line
        r'<\($',       # Start of process substitution
        r'\($',        # Start of subshell
        r'{\s*$',      # Start of a block
        r'\bif\b',     # Start of an if statement
        r'\bwhile\b',  # Start of a while loop
        r'\bfor\b',    # Start of a for loop
        r'do\s*$',     # 'do' keyword for loops
        r'then\s*$',   # 'then' keyword for if statements
    )

    # True as soon as any line matches any continuation pattern.
    return any(
        re.search(pattern, line.rstrip()) for line in script_text.splitlines()
        for pattern in continuation_patterns)
|
my_modelscope_agent/tools/code_interpreter_utils/subprocess_code_interpreter.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import queue
|
2 |
+
import subprocess
|
3 |
+
import threading
|
4 |
+
import time
|
5 |
+
import traceback
|
6 |
+
|
7 |
+
from .base_code_interpreter import BaseCodeInterpreter
|
8 |
+
|
9 |
+
|
10 |
+
class SubprocessCodeInterpreter(BaseCodeInterpreter):
    """Base class for interpreters that drive a long-lived interactive
    subprocess (e.g. ``python -i``, ``node -i``, a shell).

    Subclasses set ``start_cmd`` and override the hook methods below; the
    shared machinery feeds preprocessed code to the child's stdin and streams
    its stdout/stderr back through ``output_queue`` as event dicts.
    """

    def __init__(self):
        self.start_cmd = ''  # shell command line that launches the child REPL
        self.process = None  # subprocess.Popen handle, created lazily by run()
        self.debug_mode = False  # when True, echo all I/O for debugging
        self.output_queue = queue.Queue()  # events produced by reader threads
        self.done = threading.Event()  # set when end-of-execution is seen

    def detect_active_line(self, line):
        # Hook: return the active line number encoded in `line`, or None.
        return None

    def detect_end_of_execution(self, line):
        # Hook: return truthy when `line` contains the end-of-execution sentinel.
        return None

    def line_postprocessor(self, line):
        # Hook: clean up one raw output line; return None to discard it.
        return line

    def preprocess_code(self, code):
        """
        This needs to insert an end_of_execution marker of some kind,
        which can be detected by detect_end_of_execution.

        Optionally, add active line markers for detect_active_line.
        """
        return code

    def terminate(self):
        # NOTE(review): assumes start_process() has run; self.process is None
        # before that and this would raise AttributeError.
        self.process.terminate()

    def start_process(self):
        # Replace any existing child before spawning a fresh one.
        if self.process:
            self.terminate()

        self.process = subprocess.Popen(
            self.start_cmd.split(),
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            bufsize=0,
            universal_newlines=True,
        )
        # Daemon reader threads pump stdout/stderr into output_queue so run()
        # can consume them without blocking on either pipe.
        threading.Thread(
            target=self.handle_stream_output,
            args=(self.process.stdout, False),
            daemon=True,
        ).start()
        threading.Thread(
            target=self.handle_stream_output,
            args=(self.process.stderr, True),
            daemon=True,
        ).start()

    def run(self, code):
        """Execute `code` in the child process, yielding event dicts of the
        form ``{'output': str}`` or ``{'active_line': int | None}``."""
        retry_count = 0
        max_retries = 3

        # Setup
        try:
            code = self.preprocess_code(code)
            if not self.process:
                self.start_process()
        except Exception as e:
            print(e)
            yield {'output': traceback.format_exc()}
            return

        # Write the code to the child's stdin, restarting the child and
        # retrying up to max_retries times if the write fails (broken pipe).
        while retry_count <= max_retries:
            if self.debug_mode:
                print(
                    f'(after processing) Running processed code:\n{code}\n---')

            self.done.clear()

            try:
                self.process.stdin.write(code + '\n')
                self.process.stdin.flush()
                break
            except Exception as e:
                print(e)
                if retry_count != 0:
                    # For UX, I like to hide this if it happens once. Obviously feels better to not see errors
                    # Most of the time it doesn't matter, but we should figure out why it happens frequently with:
                    # applescript
                    yield {'output': traceback.format_exc()}
                    yield {
                        'output': f'Retrying... ({retry_count}/{max_retries})'
                    }
                    yield {'output': 'Restarting process.'}

                self.start_process()

                retry_count += 1
                if retry_count > max_retries:
                    yield {
                        'output':
                        'Maximum retries reached. Could not execute code.'
                    }
                    return

        # Drain the output queue until the reader threads signal completion
        # via self.done (set on the end-of-execution sentinel).
        while True:
            if not self.output_queue.empty():
                yield self.output_queue.get()
            else:
                time.sleep(0.1)
            try:
                output = self.output_queue.get(
                    timeout=0.3)  # Waits for 0.3 seconds
                yield output
            except queue.Empty:
                if self.done.is_set():
                    # Try to yank 3 more times from it... maybe there's something in there...
                    # (I don't know if this actually helps. Maybe we just need to yank 1 more time)
                    for _ in range(3):
                        if not self.output_queue.empty():
                            yield self.output_queue.get()
                        time.sleep(0.2)
                    break

    def handle_stream_output(self, stream, is_error_stream):
        """Reader-thread body: translate raw child output lines into queue
        events (active-line markers, end-of-execution, plain output)."""
        for line in iter(stream.readline, ''):
            if self.debug_mode:
                print(f'Received output line:\n{line}\n---')

            line = self.line_postprocessor(line)

            if line is None:
                continue  # `line = None` is the postprocessor's signal to discard completely

            # NOTE(review): detect_active_line is evaluated twice per line
            # here; harmless, but could be hoisted into one call.
            if self.detect_active_line(line):
                active_line = self.detect_active_line(line)
                self.output_queue.put({'active_line': active_line})
            elif self.detect_end_of_execution(line):
                # Clear the active-line highlight, then signal completion.
                self.output_queue.put({'active_line': None})
                time.sleep(0.1)
                self.done.set()
            elif is_error_stream and 'KeyboardInterrupt' in line:
                self.output_queue.put({'output': 'KeyboardInterrupt'})
                time.sleep(0.1)
                self.done.set()
            else:
                self.output_queue.put({'output': line})
|
my_modelscope_agent/tools/code_interpreter_utils/truncate_output.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def truncate_output(data, max_output_chars=2000):
    """Keep only the trailing *max_output_chars* characters of *data*,
    prefixing a notice when truncation occurred.  Idempotent: re-applying
    it to already-truncated output leaves the result unchanged."""
    message = f'Output truncated. Showing the last {max_output_chars} characters.\n\n'

    # Strip a notice left by a previous call so it is not duplicated; once
    # truncated, data stays marked as truncated on every later call.
    previously_truncated = data.startswith(message)
    if previously_truncated:
        data = data[len(message):]

    if previously_truncated or len(data) > max_output_chars:
        data = message + data[-max_output_chars:]

    return data
|
my_modelscope_agent/tools/hf_tool.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, List
|
2 |
+
|
3 |
+
from transformers.tools import Tool as HFTool
|
4 |
+
|
5 |
+
from .tool import Tool
|
6 |
+
|
7 |
+
|
8 |
+
class HFTool(Tool):
    """Simple wrapper for huggingface transformers tools

    NOTE(review): this class name shadows the ``HFTool`` alias imported from
    ``transformers.tools`` at the top of the module.  The ``tool: HFTool``
    annotation below is evaluated while the imported alias is still bound
    (the class name is not assigned until the class body finishes), so it
    refers to the transformers Tool — confusing but functional; renaming one
    of the two would make this clearer.
    """

    def __init__(self, tool: HFTool, description: str, name: str,
                 parameters: List[Dict]):
        # The wrapped transformers tool instance, invoked by _local_call.
        self.tool = tool
        # Human-readable description surfaced to the agent.
        self.description = description
        self.name = name
        # Parameter specs (list of dicts) consumed by the agent framework.
        self.parameters = parameters
        super().__init__()

    def _local_call(self, *args, **kwargs):
        # Delegate to the underlying tool; positional args are ignored — the
        # tool is invoked with keyword arguments only.
        return {'result': self.tool(**kwargs)}
|
my_modelscope_agent/tools/image_chat_tool.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from modelscope.utils.constant import Tasks
|
2 |
+
from .pipeline_tool import ModelscopePipelineTool
|
3 |
+
|
4 |
+
|
5 |
+
class ImageChatTool(ModelscopePipelineTool):
    """Multimodal dialogue tool: answers a text question about an image."""

    default_model = 'damo/multi-modal_mplug_owl_multimodal-dialogue_7b'
    description = '图文对话和图像描述服务,针对输入的图片和用户的文本输入,给出文本回复'
    name = 'modelscope_image-chat'
    parameters: list = [{
        'name': 'image',
        'description': '用户输入的图片',
        'required': True
    }, {
        'name': 'text',
        'description': '用户输入的文本',
        'required': True
    }]
    task = Tasks.multimodal_dialogue

    def construct_image_chat_input(self, **kwargs):
        """Build the mplug-owl `messages` payload from image/text kwargs."""
        image = kwargs.pop('image', '')
        text = kwargs.pop('text', '')

        system_prompt_1 = 'The following is a conversation between a curious human and AI assistant.'
        system_prompt_2 = "The assistant gives helpful, detailed, and polite answers to the user's questions."

        system_message = {
            'role': 'system',
            'content': system_prompt_1 + ' ' + system_prompt_2
        }
        image_message = {'role': 'user', 'content': [{'image': image}]}
        text_message = {'role': 'user', 'content': text}

        return {'messages': [system_message, image_message, text_message]}

    def _remote_parse_input(self, *args, **kwargs):
        return {'input': self.construct_image_chat_input(**kwargs)}

    def _local_parse_input(self, *args, **kwargs):
        return (self.construct_image_chat_input(**kwargs)), {}
|
my_modelscope_agent/tools/openapi_plugin.py
ADDED
@@ -0,0 +1,370 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
from typing import List, Optional
|
4 |
+
|
5 |
+
import json
|
6 |
+
import requests
|
7 |
+
from jsonschema import RefResolver
|
8 |
+
from pydantic import BaseModel, ValidationError
|
9 |
+
from requests.exceptions import RequestException, Timeout
|
10 |
+
|
11 |
+
from .tool import Tool
|
12 |
+
|
13 |
+
MAX_RETRY_TIMES = 3
|
14 |
+
|
15 |
+
|
16 |
+
class ParametersSchema(BaseModel):
    """Pydantic schema for a single tool parameter declaration."""
    # Parameter name as exposed to the model/agent.
    name: str
    # Human-readable explanation of the parameter.
    description: str
    # Whether the parameter must be supplied; defaults to required.
    required: Optional[bool] = True
20 |
+
|
21 |
+
|
22 |
+
class ToolSchema(BaseModel):
    """Pydantic schema validating a full tool declaration (name,
    description, and its parameter list)."""
    name: str
    description: str
    parameters: List[ParametersSchema]
26 |
+
|
27 |
+
|
28 |
+
class OpenAPIPluginTool(Tool):
    """Tool that proxies a single remote HTTP endpoint described by an
    OpenAPI schema entry.

    The per-tool configuration (url, method, header, parameters, ...) is
    read from ``cfg[name]``; the request itself is performed in
    ``_remote_call`` via ``requests``.
    """
    # Class-level defaults; overwritten per instance in __init__.
    name: str = 'api tool'
    description: str = 'This is a api tool that ...'
    parameters: list = []

    def __init__(self, cfg, name):
        """Read this tool's sub-config from *cfg* and validate it.

        Args:
            cfg: full tool-config mapping; this tool's section is cfg[name].
            name: key of this tool inside *cfg*.

        Raises:
            ValueError: when the configured parameters do not satisfy
                ``ToolSchema``.
        """
        self.name = name
        self.cfg = cfg.get(self.name, {})
        self.is_remote_tool = self.cfg.get('is_remote_tool', False)
        # remote call
        self.url = self.cfg.get('url', '')
        self.token = self.cfg.get('token', '')
        self.header = self.cfg.get('header', '')
        self.method = self.cfg.get('method', '')
        self.parameters = self.cfg.get('parameters', [])
        self.description = self.cfg.get('description',
                                        'This is a api tool that ...')
        self.responses_param = self.cfg.get('responses_param', [])
        try:
            all_para = {
                'name': self.name,
                'description': self.description,
                'parameters': self.parameters
            }
            self.tool_schema = ToolSchema(**all_para)
        except ValidationError:
            raise ValueError(f'Error when parsing parameters of {self.name}')
        # Serialized forms used when prompting the LLM (pydantic v2 API).
        self._str = self.tool_schema.model_dump_json()
        self._function = self.parse_pydantic_model_to_openai_function(all_para)

    def _remote_call(self, *args, **kwargs):
        """Invoke the configured endpoint and return the parsed result.

        Timeouts are retried up to MAX_RETRY_TIMES; any other request
        error aborts immediately with ValueError.

        Raises:
            ValueError: missing url, invalid method, HTTP error response,
                or all retries timing out.
        """
        if self.url == '':
            raise ValueError(
                f"Could not use remote call for {self.name} since this tool doesn't have a remote endpoint"
            )

        # Dotted kwargs are first rebuilt into nested dicts, then JSON-encoded.
        remote_parsed_input = json.dumps(
            self._remote_parse_input(*args, **kwargs))
        origin_result = None
        if self.method == 'POST':
            retry_times = MAX_RETRY_TIMES
            while retry_times:
                retry_times -= 1
                try:
                    print(f'data: {kwargs}')
                    print(f'header: {self.header}')
                    response = requests.request(
                        'POST',
                        url=self.url,
                        headers=self.header,
                        data=remote_parsed_input)

                    if response.status_code != requests.codes.ok:
                        response.raise_for_status()
                    origin_result = json.loads(
                        response.content.decode('utf-8'))

                    final_result = self._parse_output(
                        origin_result, remote=True)
                    return final_result
                except Timeout:
                    # Retry on timeout until the budget is exhausted.
                    continue
                except RequestException as e:
                    # NOTE(review): e.response is None for connection-level
                    # failures, which would raise AttributeError here — confirm.
                    raise ValueError(
                        f'Remote call failed with error code: {e.response.status_code},\
                        error message: {e.response.content.decode("utf-8")}')

            raise ValueError(
                'Remote call max retry times exceeded! Please try to use local call.'
            )
        elif self.method == 'GET':
            retry_times = MAX_RETRY_TIMES

            # Substitute {placeholder} path parameters from kwargs; missing
            # placeholders are only reported, not filled.
            new_url = self.url
            matches = re.findall(r'\{(.*?)\}', self.url)
            for match in matches:
                if match in kwargs:
                    new_url = new_url.replace('{' + match + '}', kwargs[match])
                else:
                    print(
                        f'The parameter {match} was not generated by the model.'
                    )

            while retry_times:
                retry_times -= 1
                try:
                    print('GET:', new_url)
                    print('GET:', self.url)

                    # NOTE(review): remote_parsed_input is a JSON *string*,
                    # so requests sends it verbatim as the query string —
                    # confirm the endpoint expects this encoding.
                    response = requests.request(
                        'GET',
                        url=new_url,
                        headers=self.header,
                        params=remote_parsed_input)
                    if response.status_code != requests.codes.ok:
                        response.raise_for_status()

                    origin_result = json.loads(
                        response.content.decode('utf-8'))

                    final_result = self._parse_output(
                        origin_result, remote=True)
                    return final_result
                except Timeout:
                    continue
                except RequestException as e:
                    raise ValueError(
                        f'Remote call failed with error code: {e.response.status_code},\
                        error message: {e.response.content.decode("utf-8")}')

            raise ValueError(
                'Remote call max retry times exceeded! Please try to use local call.'
            )
        else:
            raise ValueError(
                'Remote call method is invalid!We have POST and GET method.')

    def _remote_parse_input(self, *args, **kwargs):
        """Rebuild nested dicts from dotted keyword names.

        Keys like ``a.b.c`` become ``{'a': {'b': {'c': value}}}``; plain
        keys are copied through unchanged.
        """
        restored_dict = {}
        for key, value in kwargs.items():
            if '.' in key:
                # Split keys by "." and create nested dictionary structures
                keys = key.split('.')
                temp_dict = restored_dict
                for k in keys[:-1]:
                    temp_dict = temp_dict.setdefault(k, {})
                temp_dict[keys[-1]] = value
            else:
                # If the key does not contain ".", store the key-value pair
                # directly into restored_dict
                restored_dict[key] = value
        kwargs = restored_dict
        print('传给tool的参数:', kwargs)
        return kwargs
|
164 |
+
|
165 |
+
|
166 |
+
# openapi_schema_convert,register to tool_config.json
|
167 |
+
def extract_references(schema_content):
    """Collect every ``$ref`` value found anywhere in a schema fragment.

    Dicts and lists are walked recursively in document order; scalar
    values contribute nothing.  Returns a flat list of reference strings.
    """
    if isinstance(schema_content, dict):
        direct = [schema_content['$ref']] if '$ref' in schema_content else []
        nested = [
            ref for value in schema_content.values()
            for ref in extract_references(value)
        ]
        return direct + nested
    if isinstance(schema_content, list):
        return [
            ref for item in schema_content for ref in extract_references(item)
        ]
    return []
|
178 |
+
|
179 |
+
|
180 |
+
def parse_nested_parameters(param_name, param_info, parameters_list, content):
    """Flatten one requestBody property into *parameters_list*.

    Object-typed properties are expanded one member at a time, producing
    dotted names like ``outer.inner``; nested objects recurse further.
    Scalar properties are appended directly.  ``content['required']``
    decides the ``required`` flag (for dotted names only the first
    segment is checked).  Raises ValueError on a malformed fragment.
    """
    outer_type = param_info['type']
    outer_description = param_info.get('description', f'用户输入的{param_name}')
    outer_required = param_name in content['required']
    try:
        if outer_type != 'object':
            # Scalar property: append as-is.
            parameters_list.append({
                'name': param_name,
                'description': outer_description,
                'required': outer_required,
                'type': outer_type,
                'value': param_info.get('enum', ''),
            })
            return
        # Object property: walk its members (no-op when 'properties' is
        # absent or empty, matching the original behavior).
        for child_name, child_info in (param_info.get('properties')
                                       or {}).items():
            child_type = child_info['type']
            if child_type == 'object':
                # Recurse with the dotted path; 'content' stays the
                # top-level fragment so 'required' lookups are unchanged.
                parse_nested_parameters(f'{param_name}.{child_name}',
                                        child_info, parameters_list, content)
                continue
            parameters_list.append({
                'name': f'{param_name}.{child_name}',
                'description': child_info.get(
                    'description', f'用户输入的{param_name}.{child_name}'),
                'required': param_name.split('.')[0] in content['required'],
                'type': child_type,
                'value': child_info.get('enum', ''),
            })
    except Exception as e:
        raise ValueError(f'{e}:schema结构出错')
|
227 |
+
|
228 |
+
|
229 |
+
def parse_responses_parameters(param_name, param_info, parameters_list):
    """Flatten one response property into *parameters_list*.

    Object-typed properties are expanded exactly one level deep (no
    recursion); scalar properties are appended directly.  Raises
    ValueError when the fragment is malformed.
    """
    # Type lookup stays outside the try, as before: a missing 'type' key
    # propagates as KeyError rather than being wrapped in ValueError.
    kind = param_info['type']
    fallback_description = param_info.get('description',
                                          f'调用api返回的{param_name}')
    try:
        if kind != 'object':
            parameters_list.append({
                'name': param_name,
                'description': fallback_description,
                'type': kind,
            })
            return
        # One level of members; no-op when 'properties' is absent/empty.
        for child_name, child_info in (param_info.get('properties')
                                       or {}).items():
            parameters_list.append({
                'name': f'{param_name}.{child_name}',
                'description': child_info.get(
                    'description', f'调用api返回的{param_name}.{child_name}'),
                'type': child_info['type'],
            })
    except Exception as e:
        raise ValueError(f'{e}:schema结构出错')
|
259 |
+
|
260 |
+
|
261 |
+
def openapi_schema_convert(schema, auth):
    """Convert a parsed OpenAPI document into tool_config.json entries.

    Args:
        schema: OpenAPI document already loaded as a dict.
        auth: mapping that may carry 'apikey' / 'apikey_type'; a missing
            entry falls back to the same-named environment variable.

    Returns:
        dict mapping each endpoint's summary (spaces replaced by '_') to a
        config entry with name/description/url/method/parameters/header.

    Raises:
        ValueError: when the schema declares no server URL or an endpoint
            uses a method other than POST/GET.
        KeyError: when a Bearer-auth credential is in neither *auth* nor
            the environment.
    """
    resolver = RefResolver.from_schema(schema)
    servers = schema.get('servers', [])
    if servers:
        servers_url = servers[0].get('url')
    else:
        # Fix: the original only printed here and then crashed later with
        # NameError on servers_url; fail fast with an explicit error.
        print('No URL found in the schema.')
        raise ValueError('No URL found in the schema.')
    # Extract endpoints
    endpoints = schema.get('paths', {})
    description = schema.get('info', {}).get('description',
                                             'This is a api tool that ...')
    config_data = {}
    # Iterate over each endpoint and its contents
    for endpoint_path, methods in endpoints.items():
        for method, details in methods.items():
            summary = details.get('summary', 'No summary').replace(' ', '_')
            name = details.get('operationId', 'No operationId')
            url = f'{servers_url}{endpoint_path}'
            security = details.get('security', [{}])
            # Security (Bearer Token)
            authorization = ''
            if security:
                for sec in security:
                    if 'BearerAuth' in sec:
                        # Fix: dict.get(key, os.environ[...]) evaluated the
                        # environment eagerly, raising KeyError even when
                        # *auth* supplied the value. Only consult the
                        # environment when auth lacks the key.
                        api_token = auth['apikey'] if 'apikey' in auth \
                            else os.environ['apikey']
                        api_token_type = auth['apikey_type'] \
                            if 'apikey_type' in auth \
                            else os.environ['apikey_type']
                        authorization = f'{api_token_type} {api_token}'
            if method.upper() == 'POST':
                requestBody = details.get('requestBody', {})
                if requestBody:
                    for content_type, content_details in requestBody.get(
                            'content', {}).items():
                        schema_content = content_details.get('schema', {})
                        references = extract_references(schema_content)
                        # NOTE(review): if the body schema contains no $ref,
                        # config_entry is never assigned for this endpoint and
                        # the registration below raises NameError — same as
                        # the original; confirm inputs always use $ref bodies.
                        for reference in references:
                            resolved_schema = resolver.resolve(reference)
                            content = resolved_schema[1]
                            parameters_list = []
                            for param_name, param_info in content[
                                    'properties'].items():
                                parse_nested_parameters(
                                    param_name, param_info, parameters_list,
                                    content)
                            X_DashScope_Async = requestBody.get(
                                'X-DashScope-Async', '')
                            # The async and sync variants differ only in one
                            # header, so build the header once.
                            header = {
                                'Content-Type': content_type,
                                'Authorization': authorization
                            }
                            if X_DashScope_Async != '':
                                header['X-DashScope-Async'] = 'enable'
                            config_entry = {
                                'name': name,
                                'description': description,
                                'is_active': True,
                                'is_remote_tool': True,
                                'url': url,
                                'method': method.upper(),
                                'parameters': parameters_list,
                                'header': header
                            }
                else:
                    # POST endpoint without a request body.
                    config_entry = {
                        'name': name,
                        'description': description,
                        'is_active': True,
                        'is_remote_tool': True,
                        'url': url,
                        'method': method.upper(),
                        'parameters': [],
                        'header': {
                            'Content-Type': 'application/json',
                            'Authorization': authorization
                        }
                    }
            elif method.upper() == 'GET':
                # GET parameters are taken verbatim from the schema.
                parameters_list = details.get('parameters', [])
                config_entry = {
                    'name': name,
                    'description': description,
                    'is_active': True,
                    'is_remote_tool': True,
                    'url': url,
                    'method': method.upper(),
                    'parameters': parameters_list,
                    'header': {
                        'Authorization': authorization
                    }
                }
            else:
                # Fix: the original executed `raise 'method is not POST or
                # GET'`; raising a str is itself a TypeError in Python 3
                # (exceptions must derive from BaseException). Raise a proper
                # exception carrying the same message.
                raise ValueError('method is not POST or GET')

            config_data[summary] = config_entry
    return config_data
|
my_modelscope_agent/tools/pipeline_tool.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from modelscope.pipelines import pipeline
|
2 |
+
from .tool import Tool
|
3 |
+
|
4 |
+
|
5 |
+
class ModelscopePipelineTool(Tool):
    """Base class for tools backed by a ModelScope inference pipeline.

    The underlying pipeline is constructed lazily on first use so that
    merely registering the tool does not load model weights.
    """

    default_model: str = ''
    task: str = ''
    model_revision = None

    def __init__(self, cfg):
        """Resolve model/revision/pipeline settings from *cfg*, falling
        back to the class-level defaults."""
        super().__init__(cfg)
        self.model = self.cfg.get('model', None) or self.default_model
        self.model_revision = self.cfg.get('model_revision',
                                           None) or self.model_revision
        self.pipeline_params = self.cfg.get('pipeline_params', {})
        self.pipeline = None
        self.is_initialized = False

    def setup(self):
        """Build the pipeline on first call; later calls are no-ops."""
        # only initialize when this tool is really called to save memory
        if self.is_initialized:
            return
        self.pipeline = pipeline(
            task=self.task,
            model=self.model,
            model_revision=self.model_revision,
            **self.pipeline_params)
        self.is_initialized = True

    def _local_call(self, *args, **kwargs):
        """Run the pipeline on parsed inputs and post-process its output."""
        self.setup()
        call_args, call_kwargs = self._local_parse_input(*args, **kwargs)
        raw_output = self.pipeline(*call_args, **call_kwargs)
        return self._parse_output(raw_output, remote=False)
|
my_modelscope_agent/tools/plugin_tool.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from copy import deepcopy
|
2 |
+
|
3 |
+
from .tool import Tool
|
4 |
+
|
5 |
+
|
6 |
+
class LangchainTool(Tool):
    """Adapter exposing a LangChain tool through this framework's Tool API."""

    def __init__(self, langchain_tool):
        """Wrap *langchain_tool*; raises ValueError if it is not a
        langchain ``BaseTool``."""
        from langchain.tools import BaseTool

        if not isinstance(langchain_tool, BaseTool):
            raise ValueError('langchain_tool should be type of langchain tool')
        self.langchain_tool = langchain_tool
        self.parse_langchain_schema()
        super().__init__()

    def parse_langchain_schema(self):
        """Translate the LangChain arg schema into this framework's
        name/description/parameters convention."""

        def _convert(arg_name, arg_schema):
            # Copy so the wrapped tool's schema is left untouched.
            converted = deepcopy(arg_schema)
            converted.pop('title')
            converted['name'] = arg_name
            converted['required'] = True
            return converted

        self.description = self.langchain_tool.description
        self.name = self.langchain_tool.name
        self.parameters = [
            _convert(arg_name, arg_schema)
            for arg_name, arg_schema in self.langchain_tool.args.items()
        ]

    def _local_call(self, *args, **kwargs):
        """Delegate execution to the wrapped LangChain tool."""
        return {'result': self.langchain_tool.run(kwargs)}
|
my_modelscope_agent/tools/text_address_tool.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from modelscope.utils.constant import Tasks
|
2 |
+
from .pipeline_tool import ModelscopePipelineTool
|
3 |
+
|
4 |
+
|
5 |
+
class TextAddressTool(ModelscopePipelineTool):
    """Chinese address parsing via a ModelScope token-classification model."""
    default_model = 'damo/mgeo_geographic_elements_tagging_chinese_base'
    description = '地址解析服务,针对中文地址信息,识别出里面的元素,包括省、市、区、镇、社区、道路、路号、POI、楼栋号、户室号等'
    name = 'modelscope_text-address'
    parameters: list = [{
        'name': 'input',
        'description': '用户输入的地址信息',
        'required': True
    }]
    task = Tasks.token_classification

    def _parse_output(self, origin_result, *args, **kwargs):
        """Collapse the model's span list into a {type: span} mapping
        (a later span of the same type overwrites an earlier one)."""
        return {
            'result':
            {entry['type']: entry['span']
             for entry in origin_result['output']}
        }
|
my_modelscope_agent/tools/text_ie_tool.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import defaultdict
|
2 |
+
|
3 |
+
from modelscope.utils.constant import Tasks
|
4 |
+
from .pipeline_tool import ModelscopePipelineTool
|
5 |
+
|
6 |
+
|
7 |
+
class TextInfoExtractTool(ModelscopePipelineTool):
    """Chinese information extraction driven by a caller-supplied schema."""
    default_model = 'damo/nlp_structbert_siamese-uie_chinese-base'
    description = '信息抽取服务,针对中文的文本,根据schema要抽取的内容,找出其中对应信息,并用json格式展示'
    name = 'modelscope_text-ie'
    parameters: list = [{
        'name': 'input',
        'description': '用户输入的文本',
        'required': True
    }, {
        'name': 'schema',
        'description': '要抽取信息的json表示',
        'required': True
    }]
    task = Tasks.siamese_uie

    def _remote_parse_input(self, *args, **kwargs):
        """Move 'schema' under the nested 'parameters' key that the remote
        service expects."""
        kwargs['parameters'] = {'schema': kwargs.pop('schema')}
        return kwargs

    def _parse_output(self, origin_result, *args, **kwargs):
        """Group extracted spans by their entity type."""
        grouped = defaultdict(list)
        for extraction in origin_result['output']:
            grouped[extraction[0]['type']].append(extraction[0]['span'])
        return {'result': dict(grouped)}
|