Komorebi / modules /models /inspurai.py
meteor-2023's picture
Duplicate from JohnSmith9982/ChuanhuChatGPT
3678cf8
raw
history blame
13 kB
# 代码主要来源于 https://github.com/Shawn-Inspur/Yuan-1.0/blob/main/yuan_api/inspurai.py
import hashlib
import json
import os
import time
import uuid
from datetime import datetime
from urllib.parse import quote, urlencode

import pytz
import requests

from modules.presets import NO_APIKEY_MSG
from modules.models.base_model import BaseLLMModel
class Example:
    """One few-shot example (an input/output pair) used to prime the model.

    Every instance receives a random hex identifier at construction time,
    so two examples with identical text are still distinct entries.
    """

    def __init__(self, inp, out):
        self.id = uuid.uuid4().hex
        self.input, self.output = inp, out

    def get_input(self):
        """Return the example's input text."""
        return self.input

    def get_output(self):
        """Return the example's expected output text."""
        return self.output

    def get_id(self):
        """Return the example's unique hex identifier."""
        return self.id

    def as_dict(self):
        """Return the example as a dict with the keys 'input', 'output' and 'id'."""
        return dict(
            input=self.get_input(),
            output=self.get_output(),
            id=self.get_id(),
        )
class Yuan:
    """Main user-facing interface to the Inspur Yuan API.

    A ``Yuan`` object holds generation settings and a set of few-shot
    ``Example`` objects. ``submit_API`` formats those examples together with
    the user's prompt into one query, sends it through ``self.api`` (a
    ``YuanAPI`` created by ``set_account``) and post-processes the reply text.
    """

    def __init__(self,
                 engine='base_10B',
                 temperature=0.9,
                 max_tokens=100,
                 input_prefix='',
                 input_suffix='\n',
                 output_prefix='答:',
                 output_suffix='\n\n',
                 append_output_prefix_to_query=False,
                 topK=1,
                 topP=0.9,
                 frequencyPenalty=1.2,
                 responsePenalty=1.2,
                 noRepeatNgramSize=2):
        # Few-shot examples keyed by Example id.
        self.examples = {}
        self.engine = engine
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.topK = topK
        self.topP = topP
        self.frequencyPenalty = frequencyPenalty
        self.responsePenalty = responsePenalty
        self.noRepeatNgramSize = noRepeatNgramSize
        # Text wrapped around each input / output when building the query.
        self.input_prefix = input_prefix
        self.input_suffix = input_suffix
        self.output_prefix = output_prefix
        self.output_suffix = output_suffix
        self.append_output_prefix_to_query = append_output_prefix_to_query
        # Public attribute kept for callers; not read anywhere inside this class.
        self.stop = (output_suffix + input_prefix).strip()
        # YuanAPI client; stays None until set_account() is called.
        self.api = None

    def set_account(self, api_key):
        """Create the API client from a ``"<user>||<phone>"`` credential string.

        :raises IndexError: if ``api_key`` does not contain the ``||`` separator.
        """
        account = api_key.split('||')
        self.api = YuanAPI(user=account[0], phone=account[1])

    def add_example(self, ex):
        """Add an ``Example`` to the few-shot set (keyed by its id)."""
        assert isinstance(ex, Example), "Please create an Example object."
        self.examples[ex.get_id()] = ex

    def delete_example(self, id):
        """Delete the example with the given id; unknown ids are ignored."""
        self.examples.pop(id, None)

    def get_example(self, id):
        """Return the example with the given id, or None if there is none."""
        return self.examples.get(id, None)

    def get_all_examples(self):
        """Return all examples as a dict mapping example id to ``Example.as_dict()``."""
        return {k: v.as_dict() for k, v in self.examples.items()}

    def get_prime_text(self):
        """Concatenate all formatted examples (insertion order) to prime the model."""
        return "".join(self.format_example(ex) for ex in self.examples.values())

    def get_engine(self):
        """Return the engine name sent to the API."""
        return self.engine

    def get_temperature(self):
        """Return the sampling temperature sent to the API."""
        return self.temperature

    def get_max_tokens(self):
        """Return the maximum number of tokens to generate."""
        return self.max_tokens

    def craft_query(self, prompt):
        """Build the full query: priming examples, then the wrapped prompt.

        The output prefix is appended only when
        ``append_output_prefix_to_query`` is set.
        """
        q = self.get_prime_text() + self.input_prefix + prompt + self.input_suffix
        if self.append_output_prefix_to_query:
            q = q + self.output_prefix
        return q

    def format_example(self, ex):
        """Format one input/output pair with the configured prefixes and suffixes."""
        return (self.input_prefix + ex.get_input() + self.input_suffix
                + self.output_prefix + ex.get_output() + self.output_suffix)

    def response(self,
                 query,
                 engine='base_10B',
                 max_tokens=20,
                 temperature=0.9,
                 topP=0.1,
                 topK=1,
                 frequencyPenalty=1.0,
                 responsePenalty=1.0,
                 noRepeatNgramSize=0):
        """Submit ``query`` and return the raw reply from the API.

        :return: the reply object from ``YuanAPI.reply_request``, or the plain
            string ``NO_APIKEY_MSG`` when no account has been set.
        :raises: whatever ``submit_request`` / ``reply_request`` raise.
        """
        if self.api is None:
            return NO_APIKEY_MSG
        requestId = self.api.submit_request(query, temperature, topP, topK, max_tokens, engine,
                                            frequencyPenalty, responsePenalty, noRepeatNgramSize)
        return self.api.reply_request(requestId)

    def del_special_chars(self, msg):
        """Remove tokenizer artefacts, '#' and ASCII spaces from ``msg``."""
        special_chars = ['<unk>', '<eod>', '#', '▃', '▁', '▂', ' ']
        for char in special_chars:
            msg = msg.replace(char, '')
        return msg

    def submit_API(self, prompt, trun=None):
        """Submit a prompt to the Yuan API and return the cleaned reply text.

        :param prompt: question or any content a user may input.
        :param trun: a stop string, or a list/tuple of stop strings; the reply
            is cut at the first occurrence of each (applied in order). Empty
            or non-string entries are skipped. ``None`` disables truncation.
        :return: pure text response. If ``response()`` returned a plain string
            (e.g. the no-API-key message) it is passed through unchanged.
        """
        query = self.craft_query(prompt)
        res = self.response(query, engine=self.engine,
                            max_tokens=self.max_tokens,
                            temperature=self.temperature,
                            topP=self.topP,
                            topK=self.topK,
                            frequencyPenalty=self.frequencyPenalty,
                            responsePenalty=self.responsePenalty,
                            noRepeatNgramSize=self.noRepeatNgramSize)
        if isinstance(res, str):
            # A bare message instead of a reply dict; surface it instead of
            # reporting an "empty model output".
            return res
        txt = res.get('resData') if isinstance(res, dict) else None
        if txt is None:
            txt = '模型返回为空,请尝试修改输入'
        # Post-processing specific to the translation engine.
        if self.engine == 'translate':
            txt = txt.replace(' ##', '').replace(' "', '"').replace(": ", ":").replace(" ,", ",") \
                .replace('英文:', '').replace('文:', '').replace("( ", "(").replace(" )", ")")
        else:
            txt = txt.replace(' ', '')
        txt = self.del_special_chars(txt)
        # Cut the output at any of the stop strings.
        if isinstance(trun, str):
            trun = [trun]
        if isinstance(trun, (list, tuple)):
            for tr in trun:
                if isinstance(tr, str) and tr and tr in txt:
                    txt = txt[:txt.index(tr)]
        return txt
class YuanAPI:
    """Thin HTTP client for the Yuan inference service.

    A query is first submitted with ``submit_request`` to obtain a request
    id. The result is then polled with ``reply_request``. Every request
    carries a ``token`` header derived from the account, the phone number
    and the current date in Asia/Shanghai.
    """
    ACCOUNT = ''
    PHONE = ''
    SUBMIT_URL = "http://api.airyuan.cn:32102/v1/interface/api/infer/getRequestId?"
    REPLY_URL = "http://api.airyuan.cn:32102/v1/interface/api/result?"

    def __init__(self, user, phone):
        self.ACCOUNT = user
        self.PHONE = phone

    @staticmethod
    def code_md5(str):
        """Return the hex MD5 digest of ``str`` encoded as UTF-8."""
        return hashlib.md5(str.encode("utf-8")).hexdigest()

    @staticmethod
    def rest_get(url, header, timeout, show_error=False):
        """Perform an HTTP GET; return the response, or None if the request raised.

        When ``show_error`` is true, the exception is printed before returning None.
        """
        try:
            response = requests.get(url, headers=header, timeout=timeout, verify=False)
            return response
        except Exception as exception:
            if show_error:
                print(exception)
            return None

    @staticmethod
    def _build_url(base, params):
        """Append ``params`` to ``base`` as a percent-encoded query string.

        Encoding is required: the prompt text may contain '&', '=', '+' or
        '#', which would otherwise split or truncate the query string.
        """
        return base + urlencode(params, quote_via=quote)

    def header_generation(self):
        """Build the auth header: MD5 of account + phone + today's date (Asia/Shanghai)."""
        t = datetime.now(pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d")
        token = self.code_md5(self.ACCOUNT + self.PHONE + t)
        return {'token': token}

    def submit_request(self, query, temperature, topP, topK, max_tokens, engine, frequencyPenalty, responsePenalty,
                       noRepeatNgramSize):
        """Submit a query to the backend and return the request id.

        :raises RuntimeWarning: if no HTTP response was received, or the
            server replied with ``flag`` false (the decoded reply is the argument).
        """
        headers = self.header_generation()
        url = self._build_url(self.SUBMIT_URL, {
            "engine": engine,
            "account": self.ACCOUNT,
            "data": query,
            "temperature": temperature,
            "topP": topP,
            "topK": topK,
            "tokensToGenerate": max_tokens,
            "type": "api",
            "frequencyPenalty": frequencyPenalty,
            "responsePenalty": responsePenalty,
            "noRepeatNgramSize": noRepeatNgramSize,
        })
        response = self.rest_get(url, headers, 30)
        if response is None:
            raise RuntimeWarning("Yuan API submit request failed: no response from server")
        response_text = json.loads(response.text)
        if response_text["flag"]:
            return response_text["resData"]
        raise RuntimeWarning(response_text)

    def reply_request(self, requestId, cycle_count=5):
        """Poll the result endpoint until the inference result is available.

        Polls up to ``cycle_count`` times, sleeping 3 seconds between attempts
        (not after the last one). Attempts whose HTTP request fails are retried.

        :return: the decoded reply; its ``resData`` is None if no result
            arrived within ``cycle_count`` polls.
        :raises RuntimeWarning: if the final reply has ``flag`` False, or if
            none of the attempts received an HTTP response.
        """
        url = self._build_url(self.REPLY_URL, {"account": self.ACCOUNT, "requestId": requestId})
        headers = self.header_generation()
        response_text = {"flag": True, "resData": None}
        got_response = False
        for attempt in range(cycle_count):
            is_last = attempt == cycle_count - 1
            response = self.rest_get(url, headers, 30, show_error=True)
            if response is not None:
                got_response = True
                response_text = json.loads(response.text)
                if response_text["resData"] is not None:
                    return response_text
                if response_text["flag"] is False and is_last:
                    raise RuntimeWarning(response_text)
            if not is_last:
                time.sleep(3)
        if cycle_count > 0 and not got_response:
            raise RuntimeWarning("Yuan API reply request failed: no response from server")
        return response_text
class Yuan_Client(BaseLLMModel):
    """Chat-model adapter that answers through the Inspur Yuan API.

    The engine name is ``model_name`` with the ``'yuanai-1.0-'`` prefix
    removed. Non-None system prompts are read as alternating lines of
    few-shot input / output.
    """

    def __init__(self, model_name, api_key, user_name="", system_prompt=None):
        super().__init__(model_name=model_name, user=user_name)
        self.history = []
        # Expected format "<user>||<phone>" (see Yuan.set_account).
        self.api_key = api_key
        self.system_prompt = system_prompt
        self.input_prefix = ""
        self.output_prefix = ""

    def set_text_prefix(self, option, value):
        """Set ``input_prefix`` or ``output_prefix``; other option names are ignored."""
        if option == 'input_prefix':
            self.input_prefix = value
        elif option == 'output_prefix':
            self.output_prefix = value

    @staticmethod
    def _to_yuan_temperature(temperature):
        """Map a base-model temperature in [0, 2] onto Yuan's (0, 1] scale.

        Base 1 corresponds to Yuan 0.9. [0, 1] is scaled linearly onto
        [0, 0.9] and (1, 2] onto (0.9, 1.0], so the mapping is continuous
        and monotonic.
        NOTE(review): base 0 still maps to 0, outside Yuan's (0, 1] range.
        """
        if temperature <= 1:
            return temperature * 0.9
        return 0.9 + (temperature - 1) / 10

    @staticmethod
    def _parse_examples(system_prompt):
        """Split a system prompt into (input, output) pairs of consecutive lines.

        An odd trailing line gets an empty output. None yields no examples.
        """
        if system_prompt is None:
            return []
        lines = system_prompt.splitlines()
        return [(lines[i], lines[i + 1] if i + 1 < len(lines) else "")
                for i in range(0, len(lines), 2)]

    def get_answer_at_once(self):
        """Answer the last history message in one non-streaming call.

        :return: ``(answer, len(answer))``, or ``(NO_APIKEY_MSG, 0)`` when no
            API key is configured.
        """
        if not self.api_key:
            return NO_APIKEY_MSG, 0
        temperature = self._to_yuan_temperature(self.temperature)
        topP = self.top_p
        # NOTE(review): n_choices is reused as Yuan's topK — confirm intent.
        topK = self.n_choices
        # Yuan accepts max_tokens in [1, 200].
        max_tokens = self.max_generation_token if self.max_generation_token is not None else 50
        max_tokens = max(1, min(max_tokens, 200))
        stop = self.stop_sequence if self.stop_sequence is not None else []
        # TODO: support input/output prefixes and stop sequences configured
        # through the system prompt or settings.
        yuan = Yuan(engine=self.model_name.replace('yuanai-1.0-', ''),
                    temperature=temperature,
                    max_tokens=max_tokens,
                    topK=topK,
                    topP=topP,
                    input_prefix=self.input_prefix,
                    input_suffix="",
                    output_prefix=self.output_prefix,
                    output_suffix="".join(stop),
                    )
        yuan.set_account(self.api_key)
        for in_line, out_line in self._parse_examples(self.system_prompt):
            yuan.add_example(Example(inp=in_line, out=out_line))
        prompt = self.history[-1]["content"]
        answer = yuan.submit_API(prompt, trun=stop)
        return answer, len(answer)