Spaces:
Sleeping
Sleeping
# 代码主要来源于 https://github.com/Shawn-Inspur/Yuan-1.0/blob/main/yuan_api/inspurai.py | |
import hashlib | |
import json | |
import os | |
import time | |
import uuid | |
from datetime import datetime | |
import pytz | |
import requests | |
from modules.presets import NO_APIKEY_MSG | |
from modules.models.base_model import BaseLLMModel | |
class Example: | |
""" store some examples(input, output pairs and formats) for few-shots to prime the model.""" | |
def __init__(self, inp, out): | |
self.input = inp | |
self.output = out | |
self.id = uuid.uuid4().hex | |
def get_input(self): | |
"""return the input of the example.""" | |
return self.input | |
def get_output(self): | |
"""Return the output of the example.""" | |
return self.output | |
def get_id(self): | |
"""Returns the unique ID of the example.""" | |
return self.id | |
def as_dict(self): | |
return { | |
"input": self.get_input(), | |
"output": self.get_output(), | |
"id": self.get_id(), | |
} | |
class Yuan: | |
"""The main class for a user to interface with the Inspur Yuan API. | |
A user can set account info and add examples of the API request. | |
""" | |
def __init__(self, | |
engine='base_10B', | |
temperature=0.9, | |
max_tokens=100, | |
input_prefix='', | |
input_suffix='\n', | |
output_prefix='答:', | |
output_suffix='\n\n', | |
append_output_prefix_to_query=False, | |
topK=1, | |
topP=0.9, | |
frequencyPenalty=1.2, | |
responsePenalty=1.2, | |
noRepeatNgramSize=2): | |
self.examples = {} | |
self.engine = engine | |
self.temperature = temperature | |
self.max_tokens = max_tokens | |
self.topK = topK | |
self.topP = topP | |
self.frequencyPenalty = frequencyPenalty | |
self.responsePenalty = responsePenalty | |
self.noRepeatNgramSize = noRepeatNgramSize | |
self.input_prefix = input_prefix | |
self.input_suffix = input_suffix | |
self.output_prefix = output_prefix | |
self.output_suffix = output_suffix | |
self.append_output_prefix_to_query = append_output_prefix_to_query | |
self.stop = (output_suffix + input_prefix).strip() | |
self.api = None | |
# if self.engine not in ['base_10B','translate','dialog']: | |
# raise Exception('engine must be one of [\'base_10B\',\'translate\',\'dialog\'] ') | |
def set_account(self, api_key): | |
account = api_key.split('||') | |
self.api = YuanAPI(user=account[0], phone=account[1]) | |
def add_example(self, ex): | |
"""Add an example to the object. | |
Example must be an instance of the Example class.""" | |
assert isinstance(ex, Example), "Please create an Example object." | |
self.examples[ex.get_id()] = ex | |
def delete_example(self, id): | |
"""Delete example with the specific id.""" | |
if id in self.examples: | |
del self.examples[id] | |
def get_example(self, id): | |
"""Get a single example.""" | |
return self.examples.get(id, None) | |
def get_all_examples(self): | |
"""Returns all examples as a list of dicts.""" | |
return {k: v.as_dict() for k, v in self.examples.items()} | |
def get_prime_text(self): | |
"""Formats all examples to prime the model.""" | |
return "".join( | |
[self.format_example(ex) for ex in self.examples.values()]) | |
def get_engine(self): | |
"""Returns the engine specified for the API.""" | |
return self.engine | |
def get_temperature(self): | |
"""Returns the temperature specified for the API.""" | |
return self.temperature | |
def get_max_tokens(self): | |
"""Returns the max tokens specified for the API.""" | |
return self.max_tokens | |
def craft_query(self, prompt): | |
"""Creates the query for the API request.""" | |
q = self.get_prime_text( | |
) + self.input_prefix + prompt + self.input_suffix | |
if self.append_output_prefix_to_query: | |
q = q + self.output_prefix | |
return q | |
def format_example(self, ex): | |
"""Formats the input, output pair.""" | |
return self.input_prefix + ex.get_input( | |
) + self.input_suffix + self.output_prefix + ex.get_output( | |
) + self.output_suffix | |
def response(self, | |
query, | |
engine='base_10B', | |
max_tokens=20, | |
temperature=0.9, | |
topP=0.1, | |
topK=1, | |
frequencyPenalty=1.0, | |
responsePenalty=1.0, | |
noRepeatNgramSize=0): | |
"""Obtains the original result returned by the API.""" | |
if self.api is None: | |
return NO_APIKEY_MSG | |
try: | |
# requestId = submit_request(query,temperature,topP,topK,max_tokens, engine) | |
requestId = self.api.submit_request(query, temperature, topP, topK, max_tokens, engine, frequencyPenalty, | |
responsePenalty, noRepeatNgramSize) | |
response_text = self.api.reply_request(requestId) | |
except Exception as e: | |
raise e | |
return response_text | |
def del_special_chars(self, msg): | |
special_chars = ['<unk>', '<eod>', '#', '▃', '▁', '▂', ' '] | |
for char in special_chars: | |
msg = msg.replace(char, '') | |
return msg | |
def submit_API(self, prompt, trun=[]): | |
"""Submit prompt to yuan API interface and obtain an pure text reply. | |
:prompt: Question or any content a user may input. | |
:return: pure text response.""" | |
query = self.craft_query(prompt) | |
res = self.response(query, engine=self.engine, | |
max_tokens=self.max_tokens, | |
temperature=self.temperature, | |
topP=self.topP, | |
topK=self.topK, | |
frequencyPenalty=self.frequencyPenalty, | |
responsePenalty=self.responsePenalty, | |
noRepeatNgramSize=self.noRepeatNgramSize) | |
if 'resData' in res and res['resData'] != None: | |
txt = res['resData'] | |
else: | |
txt = '模型返回为空,请尝试修改输入' | |
# 单独针对翻译模型的后处理 | |
if self.engine == 'translate': | |
txt = txt.replace(' ##', '').replace(' "', '"').replace(": ", ":").replace(" ,", ",") \ | |
.replace('英文:', '').replace('文:', '').replace("( ", "(").replace(" )", ")") | |
else: | |
txt = txt.replace(' ', '') | |
txt = self.del_special_chars(txt) | |
# trun多结束符截断模型输出 | |
if isinstance(trun, str): | |
trun = [trun] | |
try: | |
if trun != None and isinstance(trun, list) and trun != []: | |
for tr in trun: | |
if tr in txt and tr != "": | |
txt = txt[:txt.index(tr)] | |
else: | |
continue | |
except: | |
return txt | |
return txt | |
class YuanAPI: | |
ACCOUNT = '' | |
PHONE = '' | |
SUBMIT_URL = "http://api.airyuan.cn:32102/v1/interface/api/infer/getRequestId?" | |
REPLY_URL = "http://api.airyuan.cn:32102/v1/interface/api/result?" | |
def __init__(self, user, phone): | |
self.ACCOUNT = user | |
self.PHONE = phone | |
def code_md5(str): | |
code = str.encode("utf-8") | |
m = hashlib.md5() | |
m.update(code) | |
result = m.hexdigest() | |
return result | |
def rest_get(url, header, timeout, show_error=False): | |
'''Call rest get method''' | |
try: | |
response = requests.get(url, headers=header, timeout=timeout, verify=False) | |
return response | |
except Exception as exception: | |
if show_error: | |
print(exception) | |
return None | |
def header_generation(self): | |
"""Generate header for API request.""" | |
t = datetime.now(pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d") | |
token = self.code_md5(self.ACCOUNT + self.PHONE + t) | |
headers = {'token': token} | |
return headers | |
def submit_request(self, query, temperature, topP, topK, max_tokens, engine, frequencyPenalty, responsePenalty, | |
noRepeatNgramSize): | |
"""Submit query to the backend server and get requestID.""" | |
headers = self.header_generation() | |
# url=SUBMIT_URL + "account={0}&data={1}&temperature={2}&topP={3}&topK={4}&tokensToGenerate={5}&type={6}".format(ACCOUNT,query,temperature,topP,topK,max_tokens,"api") | |
# url=SUBMIT_URL + "engine={0}&account={1}&data={2}&temperature={3}&topP={4}&topK={5}&tokensToGenerate={6}" \ | |
# "&type={7}".format(engine,ACCOUNT,query,temperature,topP,topK, max_tokens,"api") | |
url = self.SUBMIT_URL + "engine={0}&account={1}&data={2}&temperature={3}&topP={4}&topK={5}&tokensToGenerate={6}" \ | |
"&type={7}&frequencyPenalty={8}&responsePenalty={9}&noRepeatNgramSize={10}". \ | |
format(engine, self.ACCOUNT, query, temperature, topP, topK, max_tokens, "api", frequencyPenalty, | |
responsePenalty, noRepeatNgramSize) | |
response = self.rest_get(url, headers, 30) | |
response_text = json.loads(response.text) | |
if response_text["flag"]: | |
requestId = response_text["resData"] | |
return requestId | |
else: | |
raise RuntimeWarning(response_text) | |
def reply_request(self, requestId, cycle_count=5): | |
"""Check reply API to get the inference response.""" | |
url = self.REPLY_URL + "account={0}&requestId={1}".format(self.ACCOUNT, requestId) | |
headers = self.header_generation() | |
response_text = {"flag": True, "resData": None} | |
for i in range(cycle_count): | |
response = self.rest_get(url, headers, 30, show_error=True) | |
response_text = json.loads(response.text) | |
if response_text["resData"] is not None: | |
return response_text | |
if response_text["flag"] is False and i == cycle_count - 1: | |
raise RuntimeWarning(response_text) | |
time.sleep(3) | |
return response_text | |
class Yuan_Client(BaseLLMModel): | |
def __init__(self, model_name, api_key, user_name="", system_prompt=None): | |
super().__init__(model_name=model_name, user=user_name) | |
self.history = [] | |
self.api_key = api_key | |
self.system_prompt = system_prompt | |
self.input_prefix = "" | |
self.output_prefix = "" | |
def set_text_prefix(self, option, value): | |
if option == 'input_prefix': | |
self.input_prefix = value | |
elif option == 'output_prefix': | |
self.output_prefix = value | |
def get_answer_at_once(self): | |
# yuan temperature is (0,1] and base model temperature is [0,2], and yuan 0.9 == base 1 so need to convert | |
temperature = self.temperature if self.temperature <= 1 else 0.9 + (self.temperature - 1) / 10 | |
topP = self.top_p | |
topK = self.n_choices | |
# max_tokens should be in [1,200] | |
max_tokens = self.max_generation_token if self.max_generation_token is not None else 50 | |
if max_tokens > 200: | |
max_tokens = 200 | |
stop = self.stop_sequence if self.stop_sequence is not None else [] | |
examples = [] | |
system_prompt = self.system_prompt | |
if system_prompt is not None: | |
lines = system_prompt.splitlines() | |
# TODO: support prefixes in system prompt or settings | |
""" | |
if lines[0].startswith('-'): | |
prefixes = lines.pop()[1:].split('|') | |
self.input_prefix = prefixes[0] | |
if len(prefixes) > 1: | |
self.output_prefix = prefixes[1] | |
if len(prefixes) > 2: | |
stop = prefixes[2].split(',') | |
""" | |
for i in range(0, len(lines), 2): | |
in_line = lines[i] | |
out_line = lines[i + 1] if i + 1 < len(lines) else "" | |
examples.append((in_line, out_line)) | |
yuan = Yuan(engine=self.model_name.replace('yuanai-1.0-', ''), | |
temperature=temperature, | |
max_tokens=max_tokens, | |
topK=topK, | |
topP=topP, | |
input_prefix=self.input_prefix, | |
input_suffix="", | |
output_prefix=self.output_prefix, | |
output_suffix="".join(stop), | |
) | |
if not self.api_key: | |
return NO_APIKEY_MSG, 0 | |
yuan.set_account(self.api_key) | |
for in_line, out_line in examples: | |
yuan.add_example(Example(inp=in_line, out=out_line)) | |
prompt = self.history[-1]["content"] | |
answer = yuan.submit_API(prompt, trun=stop) | |
return answer, len(answer) | |