Spaces:
Sleeping
Sleeping
File size: 5,799 Bytes
09321b6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 |
import os
from typing import List, Optional
import json
import requests
from pydantic import BaseModel, ValidationError
from requests.exceptions import RequestException, Timeout
# API token read from the environment; None when the variable is not set.
MODELSCOPE_API_TOKEN = os.environ.get('MODELSCOPE_API_TOKEN')

# Upper bound on the number of attempts made for a single remote call.
MAX_RETRY_TIMES = 3
class ParametersSchema(BaseModel):
    # Schema describing a single input parameter of a tool.

    # Identifier of the parameter.
    name: str

    # Human-readable explanation of the parameter.
    description: str

    # Whether the parameter must be supplied; treated as required when omitted.
    required: Optional[bool] = True
class ToolSchema(BaseModel):
    # Schema describing a tool: its name, its purpose and its parameters.

    # Identifier of the tool.
    name: str

    # Human-readable explanation of what the tool does.
    description: str

    # One ParametersSchema entry per input parameter of the tool.
    parameters: List[ParametersSchema]
class Tool:
    """Base class for tools that run locally or through a remote HTTP endpoint.

    Subclasses should provide ``name``, ``description`` and ``parameters``
    conforming to ``ToolSchema``. A tool may offer two call paths:
    ``_local_call`` (execute the tool in the local environment) and
    ``_remote_call`` (send an HTTP request to a remote server). The matching
    pre/post-processing hooks (``_remote_parse_input``, ``_local_parse_input``
    and ``_parse_output``) may need to be overridden to get correct results.
    """

    name: str = 'tool'
    description: str = 'This is a tool that ...'
    parameters: list = []

    def __init__(self, cfg: Optional[dict] = None):
        """Read this tool's configuration and validate its schema.

        Args:
            cfg: Mapping from tool name to that tool's settings. Only the
                entry stored under ``self.name`` is used. Recognised keys are
                ``is_remote_tool``, ``url``, ``token`` and ``timeout``
                (passed to ``requests``; ``None`` means no timeout).

        Raises:
            ValueError: If ``name``, ``description`` and ``parameters`` do not
                validate against ``ToolSchema``.
        """
        # Build a fresh dict per call instead of sharing one mutable default.
        cfg = {} if cfg is None else cfg
        self.cfg = cfg.get(self.name, {})
        self.is_remote_tool = self.cfg.get('is_remote_tool', False)

        # Settings that are only used by the remote call path.
        self.url = self.cfg.get('url', '')
        self.token = self.cfg.get('token', '')
        self.timeout = self.cfg.get('timeout')
        # An explicitly configured token wins over the module-level token.
        self.header = {
            'Authorization': self.token or f'Bearer {MODELSCOPE_API_TOKEN}'
        }

        all_para = {
            'name': self.name,
            'description': self.description,
            'parameters': self.parameters
        }
        try:
            self.tool_schema = ToolSchema(**all_para)
        except ValidationError as e:
            raise ValueError(
                f'Error when parsing parameters of {self.name}') from e

        self._str = self.tool_schema.model_dump_json()
        self._function = self.parse_pydantic_model_to_openai_function(all_para)

    def __call__(self, remote=False, *args, **kwargs):
        """Run the tool, remotely or locally.

        The remote path is taken when the tool is configured as remote or
        when ``remote`` is truthy. Note that the first positional argument
        binds to ``remote``, so tool inputs should be passed as keyword
        arguments.
        """
        if self.is_remote_tool or remote:
            return self._remote_call(*args, **kwargs)
        return self._local_call(*args, **kwargs)

    def _remote_call(self, *args, **kwargs):
        """POST the parsed input to ``self.url`` and return the parsed output.

        Timeouts are retried up to ``MAX_RETRY_TIMES`` attempts in total; any
        other request failure aborts immediately.

        Raises:
            ValueError: If no URL is configured, if the request fails, or if
                every attempt timed out.
        """
        if self.url == '':
            raise ValueError(
                f"Could not use remote call for {self.name} since this tool doesn't have a remote endpoint"
            )

        remote_parsed_input = json.dumps(
            self._remote_parse_input(*args, **kwargs))

        for _ in range(MAX_RETRY_TIMES):
            try:
                response = requests.request(
                    'POST',
                    self.url,
                    headers=self.header,
                    data=remote_parsed_input,
                    timeout=self.timeout)
                if response.status_code != requests.codes.ok:
                    response.raise_for_status()
                origin_result = json.loads(
                    response.content.decode('utf-8'))['Data']
                return self._parse_output(origin_result, remote=True)
            except Timeout:
                # Only timeouts are considered transient and retried.
                continue
            except RequestException as e:
                failed_response = e.response
                # Connection-level failures carry no response object, so
                # there is no status code or body to report. Compare with
                # ``is None``: a Response is falsy for error status codes.
                if failed_response is None:
                    raise ValueError(
                        f'Remote call failed with error: {e}') from e
                error_body = failed_response.content.decode('utf-8')
                raise ValueError(
                    'Remote call failed with error code: '
                    f'{failed_response.status_code}, '
                    f'error message: {error_body}') from e

        raise ValueError(
            'Remote call max retry times exceeded! Please try to use local call.'
        )

    def _local_call(self, *args, **kwargs):
        """Execute the tool locally; the base implementation returns None."""
        return

    def _remote_parse_input(self, *args, **kwargs):
        """Return the payload for a remote call: the keyword arguments."""
        return kwargs

    def _local_parse_input(self, *args, **kwargs):
        """Return the positional and keyword arguments as an (args, kwargs) pair."""
        return args, kwargs

    def _parse_output(self, origin_result, *args, **kwargs):
        """Wrap the raw result in a dict under the ``'result'`` key."""
        return {'result': origin_result}

    def __str__(self):
        """Return the JSON dump of the tool schema computed in ``__init__``."""
        return self._str

    def get_function(self):
        """Return the function schema computed in ``__init__``."""
        return self._function

    def parse_pydantic_model_to_openai_function(self, all_para: dict):
        """Convert a tool description into an OpenAI-style function schema.

        For example::

            all_para = {
                'name': 'get_current_weather',
                'description': 'Get the current weather in a given location',
                'parameters': [{
                    'name': 'image',
                    'description': 'image provided by the user',
                    'required': True
                }, {
                    'name': 'text',
                    'description': 'text provided by the user',
                    'required': True
                }]
            }

        is converted to::

            {
                'name': 'get_current_weather',
                'description': 'Get the current weather in a given location',
                'parameters': {
                    'type': 'object',
                    'properties': {
                        'image': {
                            'type': 'string',
                            'description': 'image provided by the user',
                        },
                        'text': {
                            'type': 'string',
                            'description': 'text provided by the user',
                        },
                    },
                    'required': ['image', 'text'],
                },
            }

        Args:
            all_para: Dict with ``name``, ``description`` and ``parameters``,
                where ``parameters`` is a list of dicts holding ``name``,
                ``description`` and optionally ``required``.

        Returns:
            The function schema dict. Every parameter is typed as a string.
        """
        parameters = all_para['parameters']
        properties = {
            para['name']: {
                'type': 'string',
                'description': para['description']
            }
            for para in parameters
        }
        # A missing 'required' key means required, matching the default
        # declared on ParametersSchema.
        required = [
            para['name'] for para in parameters
            if para.get('required', True)
        ]
        return {
            'name': all_para['name'],
            'description': all_para['description'],
            'parameters': {
                'type': 'object',
                'properties': properties,
                'required': required,
            },
        }
|