File size: 13,028 Bytes
4049862 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 |
# This code is mainly derived from https://github.com/Shawn-Inspur/Yuan-1.0/blob/main/yuan_api/inspurai.py
import hashlib
import json
import os
import time
import uuid
from datetime import datetime
import pytz
import requests
from modules.presets import NO_APIKEY_MSG
from modules.models.base_model import BaseLLMModel
class Example:
    """One few-shot sample: an (input, output) pair used to prime the model."""

    def __init__(self, inp, out):
        """Keep the pair and tag it with a freshly generated hex identifier."""
        self.input = inp
        self.output = out
        self.id = uuid.uuid4().hex

    def get_input(self):
        """Return the input half of the pair."""
        return self.input

    def get_output(self):
        """Return the output half of the pair."""
        return self.output

    def get_id(self):
        """Return the identifier assigned when the example was created."""
        return self.id

    def as_dict(self):
        """Return the example as a dict with keys ``input``, ``output``, ``id``."""
        fields = (
            ("input", self.get_input()),
            ("output", self.get_output()),
            ("id", self.get_id()),
        )
        return dict(fields)
class Yuan:
    """The main class for a user to interface with the Inspur Yuan API.

    A user can set account info and add examples of the API request.
    """

    def __init__(self,
                 engine='base_10B',
                 temperature=0.9,
                 max_tokens=100,
                 input_prefix='',
                 input_suffix='\n',
                 output_prefix='答:',
                 output_suffix='\n\n',
                 append_output_prefix_to_query=False,
                 topK=1,
                 topP=0.9,
                 frequencyPenalty=1.2,
                 responsePenalty=1.2,
                 noRepeatNgramSize=2):
        """Store the sampling settings and the affixes used to format prompts.

        The sampling arguments are forwarded unchanged to the API on every
        ``submit_API`` call; the prefix/suffix arguments control how examples
        and the user prompt are concatenated into the query text.
        """
        self.examples = {}
        self.engine = engine
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.topK = topK
        self.topP = topP
        self.frequencyPenalty = frequencyPenalty
        self.responsePenalty = responsePenalty
        self.noRepeatNgramSize = noRepeatNgramSize
        self.input_prefix = input_prefix
        self.input_suffix = input_suffix
        self.output_prefix = output_prefix
        self.output_suffix = output_suffix
        self.append_output_prefix_to_query = append_output_prefix_to_query
        self.stop = (output_suffix + input_prefix).strip()
        # No client exists until set_account() is called.
        self.api = None
        # NOTE: the engine name is deliberately not validated here; it is
        # passed through to the API as given.

    def set_account(self, api_key):
        """Create the API client from a key of the form ``"<account>||<phone>"``.

        Raises:
            IndexError: if ``api_key`` does not contain the ``||`` separator.
        """
        account = api_key.split('||')
        self.api = YuanAPI(user=account[0], phone=account[1])

    def add_example(self, ex):
        """Add an example to the object.

        Example must be an instance of the Example class.
        """
        assert isinstance(ex, Example), "Please create an Example object."
        self.examples[ex.get_id()] = ex

    def delete_example(self, id):
        """Delete the example with the given id; unknown ids are ignored."""
        # ``id`` shadows the builtin but is kept: it is part of the public signature.
        self.examples.pop(id, None)

    def get_example(self, id):
        """Return the example stored under ``id``, or ``None`` if absent."""
        return self.examples.get(id, None)

    def get_all_examples(self):
        """Return all examples as a dict mapping example id to its dict form."""
        return {k: v.as_dict() for k, v in self.examples.items()}

    def get_prime_text(self):
        """Format all examples, in insertion order, into one priming string."""
        return "".join(self.format_example(ex) for ex in self.examples.values())

    def get_engine(self):
        """Return the engine specified for the API."""
        return self.engine

    def get_temperature(self):
        """Return the temperature specified for the API."""
        return self.temperature

    def get_max_tokens(self):
        """Return the max tokens specified for the API."""
        return self.max_tokens

    def craft_query(self, prompt):
        """Build the query text: priming examples, then the affixed prompt.

        When ``append_output_prefix_to_query`` is set, the output prefix is
        appended as well.
        """
        q = self.get_prime_text() + self.input_prefix + prompt + self.input_suffix
        if self.append_output_prefix_to_query:
            q = q + self.output_prefix
        return q

    def format_example(self, ex):
        """Format one input/output pair with the configured prefixes and suffixes."""
        return (self.input_prefix + ex.get_input() + self.input_suffix
                + self.output_prefix + ex.get_output() + self.output_suffix)

    def response(self,
                 query,
                 engine='base_10B',
                 max_tokens=20,
                 temperature=0.9,
                 topP=0.1,
                 topK=1,
                 frequencyPenalty=1.0,
                 responsePenalty=1.0,
                 noRepeatNgramSize=0):
        """Obtain the original result returned by the API.

        Returns:
            ``NO_APIKEY_MSG`` (a plain message, not a result mapping) when no
            account has been configured; otherwise whatever the client's
            ``reply_request`` returns. Exceptions raised by the client
            propagate unchanged.
        """
        if self.api is None:
            return NO_APIKEY_MSG
        requestId = self.api.submit_request(query, temperature, topP, topK, max_tokens, engine,
                                            frequencyPenalty, responsePenalty, noRepeatNgramSize)
        return self.api.reply_request(requestId)

    def del_special_chars(self, msg):
        """Return ``msg`` with every listed special token/character removed."""
        special_chars = ['<unk>', '<eod>', '#', '▃', '▁', '▂', ' ']
        for char in special_chars:
            msg = msg.replace(char, '')
        return msg

    def submit_API(self, prompt, trun=None):
        """Submit prompt to the Yuan API and obtain a pure-text reply.

        :prompt: Question or any content a user may input.
        :trun: a terminator string, or a list/tuple of terminator strings; the
               reply is cut before each terminator found. Empty or non-string
               entries are ignored.
        :return: pure text response.
        """
        query = self.craft_query(prompt)
        res = self.response(query, engine=self.engine,
                            max_tokens=self.max_tokens,
                            temperature=self.temperature,
                            topP=self.topP,
                            topK=self.topK,
                            frequencyPenalty=self.frequencyPenalty,
                            responsePenalty=self.responsePenalty,
                            noRepeatNgramSize=self.noRepeatNgramSize)
        if isinstance(res, str):
            # response() hands back a plain message when no account is set;
            # surface it as-is instead of reporting an empty model reply.
            return res
        if isinstance(res, dict) and res.get('resData') is not None:
            txt = res['resData']
        else:
            txt = '模型返回为空,请尝试修改输入'
        # Post-processing specific to the translation model.
        if self.engine == 'translate':
            txt = txt.replace(' ##', '').replace(' "', '"').replace(": ", ":").replace(" ,", ",") \
                .replace('英文:', '').replace('文:', '').replace("( ", "(").replace(" )", ")")
        else:
            txt = txt.replace(' ', '')
        txt = self.del_special_chars(txt)
        # Truncate the model output at each of the terminators in ``trun``.
        if isinstance(trun, str):
            trun = [trun]
        if isinstance(trun, (list, tuple)):
            for tr in trun:
                # partition() rejects an empty separator, so skip those too.
                if isinstance(tr, str) and tr:
                    txt = txt.partition(tr)[0]
        return txt
class YuanAPI:
    """Small HTTP client for the Yuan inference endpoints.

    A request is submitted with ``submit_request`` (which yields a request
    id) and the result is then polled with ``reply_request``.
    """

    ACCOUNT = ''
    PHONE = ''
    SUBMIT_URL = "http://api.airyuan.cn:32102/v1/interface/api/infer/getRequestId?"
    REPLY_URL = "http://api.airyuan.cn:32102/v1/interface/api/result?"

    def __init__(self, user, phone):
        """Remember the account name and phone number used to sign requests."""
        self.ACCOUNT = user
        self.PHONE = phone

    @staticmethod
    def code_md5(str):
        """Return the hex MD5 digest of the UTF-8 encoding of ``str``."""
        # The parameter name shadows the builtin; it is kept so that existing
        # keyword callers keep working.
        return hashlib.md5(str.encode("utf-8")).hexdigest()

    @staticmethod
    def rest_get(url, header, timeout, show_error=False):
        """Issue a GET request and return the response, or ``None`` on any error.

        When ``show_error`` is true the caught exception is printed.
        """
        try:
            # NOTE(review): certificate verification is disabled here; confirm
            # this is still wanted if the endpoints ever move to HTTPS.
            return requests.get(url, headers=header, timeout=timeout, verify=False)
        except Exception as exception:
            if show_error:
                print(exception)
            return None

    def header_generation(self):
        """Generate the header for an API request.

        The token is the MD5 of account + phone + the current date
        (``YYYY-MM-DD``) in the Asia/Shanghai time zone.
        """
        t = datetime.now(pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d")
        token = self.code_md5(self.ACCOUNT + self.PHONE + t)
        return {'token': token}

    def submit_request(self, query, temperature, topP, topK, max_tokens, engine, frequencyPenalty,
                       responsePenalty, noRepeatNgramSize):
        """Submit query to the backend server and get the request id.

        Raises:
            ConnectionError: if the HTTP request itself failed.
            RuntimeWarning: if the server answered with a falsy ``flag``.
        """
        from urllib.parse import quote, urlencode

        headers = self.header_generation()
        # Percent-encode every value so that characters such as '&', '#' or
        # '=' inside the prompt cannot corrupt the query string.
        params = [
            ("engine", engine),
            ("account", self.ACCOUNT),
            ("data", query),
            ("temperature", temperature),
            ("topP", topP),
            ("topK", topK),
            ("tokensToGenerate", max_tokens),
            ("type", "api"),
            ("frequencyPenalty", frequencyPenalty),
            ("responsePenalty", responsePenalty),
            ("noRepeatNgramSize", noRepeatNgramSize),
        ]
        url = self.SUBMIT_URL + urlencode(params, quote_via=quote)
        response = self.rest_get(url, headers, 30)
        if response is None:
            raise ConnectionError("Yuan API submit request failed: no response received")
        response_text = json.loads(response.text)
        if response_text["flag"]:
            return response_text["resData"]
        raise RuntimeWarning(response_text)

    def reply_request(self, requestId, cycle_count=5, interval=3):
        """Poll the reply endpoint until the inference result is available.

        Args:
            requestId: id returned by ``submit_request``.
            cycle_count: maximum number of polling attempts.
            interval: seconds to wait between two attempts.

        Returns:
            The decoded JSON reply. Its ``resData`` is still ``None`` when the
            result was not ready after ``cycle_count`` attempts.

        Raises:
            ConnectionError: if the last attempt got no HTTP response.
            RuntimeWarning: if the last attempt reported ``flag`` False.
        """
        from urllib.parse import quote, urlencode

        url = self.REPLY_URL + urlencode(
            [("account", self.ACCOUNT), ("requestId", requestId)], quote_via=quote)
        headers = self.header_generation()
        response_text = {"flag": True, "resData": None}
        for i in range(cycle_count):
            last_attempt = i == cycle_count - 1
            response = self.rest_get(url, headers, 30, show_error=True)
            if response is None:
                # A failed request counts as one attempt; retry unless exhausted.
                if last_attempt:
                    raise ConnectionError("Yuan API reply request failed: no response received")
            else:
                response_text = json.loads(response.text)
                if response_text["resData"] is not None:
                    return response_text
                if response_text["flag"] is False and last_attempt:
                    raise RuntimeWarning(response_text)
            if not last_attempt:
                # Waiting after the final attempt would only delay the return.
                time.sleep(interval)
        return response_text
class Yuan_Client(BaseLLMModel):
    """Adapter that answers a chat turn through the Yuan API.

    Sampling settings (``temperature``, ``top_p``, ``n_choices``,
    ``max_generation_token``, ``stop_sequence``) and ``model_name`` are read
    from ``self``; presumably they are provided by ``BaseLLMModel`` — confirm
    against that class.
    """

    def __init__(self, model_name, api_key, user_name="", system_prompt=None):
        """Initialise the base model and remember the key and system prompt."""
        super().__init__(model_name=model_name, user=user_name)
        self.history = []
        self.api_key = api_key
        self.system_prompt = system_prompt
        self.input_prefix = ""
        self.output_prefix = ""

    def set_text_prefix(self, option, value):
        """Set ``input_prefix`` or ``output_prefix``; other options are ignored."""
        if option == 'input_prefix':
            self.input_prefix = value
        elif option == 'output_prefix':
            self.output_prefix = value

    def get_answer_at_once(self):
        """Send the latest history entry to Yuan and return ``(answer, len(answer))``.

        Returns ``(NO_APIKEY_MSG, 0)`` when no API key is configured. The
        system prompt, if any, is read as alternating input/output lines that
        become few-shot examples.
        """
        if not self.api_key:
            # Nothing can be sent without credentials, so skip all preparation.
            return NO_APIKEY_MSG, 0

        # yuan temperature is (0,1] and base model temperature is [0,2], and
        # yuan 0.9 == base 1 so need to convert
        temperature = self.temperature if self.temperature <= 1 else 0.9 + (self.temperature - 1) / 10
        topP = self.top_p
        topK = self.n_choices
        # max_tokens should be in [1,200]; enforce both ends of the range.
        max_tokens = self.max_generation_token if self.max_generation_token is not None else 50
        max_tokens = min(max(max_tokens, 1), 200)
        stop = self.stop_sequence if self.stop_sequence is not None else []

        examples = []
        system_prompt = self.system_prompt
        if system_prompt is not None:
            lines = system_prompt.splitlines()
            # TODO: support prefixes in system prompt or settings.
            # Disabled draft kept for reference:
            #   if lines[0].startswith('-'):
            #       prefixes = lines.pop()[1:].split('|')
            #       self.input_prefix = prefixes[0]
            #       if len(prefixes) > 1:
            #           self.output_prefix = prefixes[1]
            #       if len(prefixes) > 2:
            #           stop = prefixes[2].split(',')
            # Pair consecutive lines as (input, output); a trailing unpaired
            # input line gets an empty output.
            examples = [
                (lines[i], lines[i + 1] if i + 1 < len(lines) else "")
                for i in range(0, len(lines), 2)
            ]

        yuan = Yuan(engine=self.model_name.replace('yuanai-1.0-', ''),
                    temperature=temperature,
                    max_tokens=max_tokens,
                    topK=topK,
                    topP=topP,
                    input_prefix=self.input_prefix,
                    input_suffix="",
                    output_prefix=self.output_prefix,
                    output_suffix="".join(stop),
                    )
        yuan.set_account(self.api_key)
        for in_line, out_line in examples:
            yuan.add_example(Example(inp=in_line, out=out_line))

        prompt = self.history[-1]["content"]
        answer = yuan.submit_API(prompt, trun=stop)
        return answer, len(answer)
|