|
|
|
|
|
import requests, json |
|
from bot.bot import Bot |
|
from bot.session_manager import SessionManager |
|
from bot.baidu.baidu_wenxin_session import BaiduWenxinSession |
|
from bridge.context import ContextType, Context |
|
from bridge.reply import Reply, ReplyType |
|
from common.log import logger |
|
from config import conf |
|
from common import const |
|
import time |
|
import _thread as thread |
|
import datetime |
|
from datetime import datetime |
|
from wsgiref.handlers import format_date_time |
|
from urllib.parse import urlencode |
|
import base64 |
|
import ssl |
|
import hashlib |
|
import hmac |
|
import json |
|
from time import mktime |
|
from urllib.parse import urlparse |
|
import websocket |
|
import queue |
|
import threading |
|
import random |
|
|
|
|
|
# Per-request queues, keyed by request id. XunFeiBot.create_web_socket registers
# a queue here, the websocket callbacks put streamed items into it, and
# XunFeiBot.reply reads from it.
queue_map = {}

# Per-request accumulated answer text, keyed by request id. Written and removed
# by XunFeiBot.reply.
reply_map = {}
|
|
|
|
|
class XunFeiBot(Bot):
    """Chat bot that talks to the XunFei Spark chat API over a websocket.

    Each call to reply() starts one worker thread that opens a websocket.
    The module-level websocket callbacks push the streamed chunks into a
    per-request queue stored in ``queue_map``; reply() drains that queue and
    assembles the final answer.
    """

    # Upper bound, in seconds, on how long reply() waits for the whole answer.
    REPLY_TIMEOUT_SECONDS = 30
    # How long a single poll of the per-request queue blocks, in seconds.
    POLL_INTERVAL_SECONDS = 0.1

    def __init__(self):
        super().__init__()
        self.app_id = conf().get("xunfei_app_id")
        self.api_key = conf().get("xunfei_api_key")
        self.api_secret = conf().get("xunfei_api_secret")
        # Model domain sent in the request parameters.
        # NOTE(review): presumably has to correspond to the API version in
        # spark_url ("generalv3" <-> "/v3.1/") - confirm against the API docs.
        self.domain = "generalv3"
        self.spark_url = "ws://spark-api.xf-yun.com/v3.1/chat"
        parsed_url = urlparse(self.spark_url)
        # host and path are part of the string that create_url() signs.
        self.host = parsed_url.netloc
        self.path = parsed_url.path
        self.sessions = SessionManager(BaiduWenxinSession, model=const.XUNFEI)

    def reply(self, query, context: Context = None) -> Reply:
        """Answer a text query by streaming a completion over a websocket.

        Args:
            query: The user's message text.
            context: Message context; only ``ContextType.TEXT`` is handled and
                ``context["session_id"]`` selects the conversation session.

        Returns:
            A TEXT Reply holding whatever text arrived before the stream ended
            or the timeout expired (possibly empty), or an ERROR Reply when the
            context is missing or not of type TEXT.
        """
        context_type = context.type if context is not None else None
        if context_type != ContextType.TEXT:
            return Reply(ReplyType.ERROR,
                         "Bot不支持处理{}类型的消息".format(context_type))

        logger.info("[XunFei] query={}".format(query))
        session_id = context["session_id"]
        request_id = self.gen_request_id(session_id)
        reply_map[request_id] = ""
        try:
            session = self.sessions.session_query(query, session_id)
            threading.Thread(target=self.create_web_socket,
                             args=(session.messages, request_id)).start()
            # Give the worker a moment to register its queue before polling.
            time.sleep(0.1)
            t1 = time.time()
            usage = self._collect_reply(request_id)
            t2 = time.time()
            content = reply_map[request_id]
        finally:
            # Drop the per-request state on every exit path (normal end,
            # timeout or exception) so the module-level maps cannot grow
            # without bound.
            queue_map.pop(request_id, None)
            reply_map.pop(request_id, None)

        logger.info(
            f"[XunFei-API] response={content}, time={t2 - t1}s, usage={usage}"
        )
        self.sessions.session_reply(content, session_id,
                                    usage.get("total_tokens"))
        return Reply(ReplyType.TEXT, content)

    def _collect_reply(self, request_id):
        """Drain the request's queue into ``reply_map[request_id]``.

        Stops at the item flagged ``is_end``, at a close marker, or when
        REPLY_TIMEOUT_SECONDS have elapsed. The limit is wall-clock time, so
        the number of chunks received does not shorten the wait.

        Returns:
            The usage dict attached to the final item, or ``{}`` if none was
            received.
        """
        usage = {}
        deadline = time.monotonic() + self.REPLY_TIMEOUT_SECONDS
        while time.monotonic() < deadline:
            data_queue = queue_map.get(request_id)
            if data_queue is None:
                # The worker thread has not registered the queue yet.
                time.sleep(self.POLL_INTERVAL_SECONDS)
                continue
            try:
                data_item = data_queue.get(block=True,
                                           timeout=self.POLL_INTERVAL_SECONDS)
            except queue.Empty:
                continue
            if not hasattr(data_item, "is_end"):
                # The close callback may enqueue a bare marker instead of a
                # reply item; either way the socket is gone and no further
                # chunks will arrive.
                break
            if data_item.reply:
                reply_map[request_id] += data_item.reply
            if data_item.is_end:
                # The final message may carry no usage block at all.
                usage = data_item.usage or {}
                break
        return usage

    def create_web_socket(self, prompt, session_id, temperature=0.5):
        """Open the websocket for one request and block until it closes.

        Intended to run on its own thread. Registers the queue that the
        websocket callbacks fill under ``queue_map[session_id]``.

        Args:
            prompt: Message list sent as the question payload.
            session_id: Key for the queue; reply() passes its request id here.
            temperature: Sampling temperature forwarded in the request.
        """
        logger.info(f"[XunFei] start connect, prompt={prompt}")
        websocket.enableTrace(False)
        wsUrl = self.create_url()
        ws = websocket.WebSocketApp(wsUrl,
                                    on_message=on_message,
                                    on_error=on_error,
                                    on_close=on_close,
                                    on_open=on_open)
        data_queue = queue.Queue(1000)
        queue_map[session_id] = data_queue
        # The callbacks only receive ``ws``, so the request data rides on it.
        ws.appid = self.app_id
        ws.question = prompt
        ws.domain = self.domain
        ws.session_id = session_id
        ws.temperature = temperature
        # NOTE(review): certificate verification is disabled here. spark_url
        # uses plain ws://, but this would weaken security if the URL were
        # switched to wss://.
        ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})

    def gen_request_id(self, session_id: str) -> str:
        """Return an id of the form ``<session_id>_<unix seconds><0..100>``."""
        return f"{session_id}_{int(time.time())}{random.randint(0, 100)}"

    def create_url(self):
        """Build the websocket URL carrying the signed authorization query.

        The string "host / date / request-line" is signed with HMAC-SHA256
        using the API secret. The resulting authorization value is
        base64-encoded and appended to ``spark_url`` together with the date
        and host.
        """
        # Date in RFC 1123 format, as produced by format_date_time.
        now = datetime.now()
        date = format_date_time(mktime(now.timetuple()))

        # The text that gets signed.
        signature_origin = "host: " + self.host + "\n"
        signature_origin += "date: " + date + "\n"
        signature_origin += "GET " + self.path + " HTTP/1.1"

        signature_sha = hmac.new(self.api_secret.encode('utf-8'),
                                 signature_origin.encode('utf-8'),
                                 digestmod=hashlib.sha256).digest()
        signature_sha_base64 = base64.b64encode(signature_sha).decode(
            encoding='utf-8')

        authorization_origin = f'api_key="{self.api_key}", algorithm="hmac-sha256", headers="host date request-line", ' \
                               f'signature="{signature_sha_base64}"'
        authorization = base64.b64encode(
            authorization_origin.encode('utf-8')).decode(encoding='utf-8')

        # Authentication parameters are passed in the query string.
        v = {"authorization": authorization, "date": date, "host": self.host}
        url = self.spark_url + '?' + urlencode(v)
        return url

    def gen_params(self, appid, domain, question):
        """Build the request parameters from the app id and the user's question.

        Unlike the module-level ``gen_params`` that ``run`` calls, this
        variant sends no "temperature" field.
        """
        data = {
            "header": {
                "app_id": appid,
                "uid": "1234"
            },
            "parameter": {
                "chat": {
                    "domain": domain,
                    "random_threshold": 0.5,
                    "max_tokens": 2048,
                    "auditing": "default"
                }
            },
            "payload": {
                "message": {
                    "text": question
                }
            }
        }
        return data
|
|
|
|
|
class ReplyItem:
    """One streamed unit passed from the websocket callbacks to the reply loop.

    Attributes are plain and mutable; on_message flips ``is_end`` after
    construction for the final chunk.
    """

    def __init__(self, reply, usage=None, is_end=False):
        # is_end: True for the last item of a stream.
        # reply:  text content of this chunk.
        # usage:  usage block from the final message payload, if any.
        self.is_end, self.reply, self.usage = is_end, reply, usage
|
|
|
|
|
|
|
def on_error(ws, error):
    """Websocket error callback: log the error text; ``ws`` is unused."""
    detail = str(error)
    logger.error(f"[XunFei] error: {detail}")
|
|
|
|
|
|
|
def on_close(ws, one, two):
    """Websocket close callback: tell the waiting consumer the stream is over.

    Args:
        ws: The websocket app; ``ws.session_id`` selects the request's queue.
        one: First extra argument passed by the websocket library on close
            (presumably the close status code - confirm); unused.
        two: Second extra argument passed on close (presumably the close
            message - confirm); unused.
    """
    data_queue = queue_map.get(ws.session_id)
    if data_queue is None:
        # The consumer removes the queue once it has seen the final chunk, so
        # after a normal completion there is nobody left to notify.
        return
    # Signal the end with a real item flagged is_end. The consumer reads
    # ``.is_end`` / ``.reply`` / ``.usage`` on everything it dequeues, which a
    # bare string marker does not provide. usage is an empty dict, not None,
    # because the consumer calls ``.get()`` on it.
    data_queue.put(ReplyItem("", {}, True))
|
|
|
|
|
|
|
def on_open(ws):
    """Websocket open callback: send the request from a separate thread."""
    logger.info(f"[XunFei] Start websocket, session_id={ws.session_id}")
    sender_args = (ws, )
    thread.start_new_thread(run, sender_args)
|
|
|
|
|
def run(ws, *args):
    """Serialize the request stored on ``ws`` to JSON and send it.

    Reads ``appid``, ``domain``, ``question`` and ``temperature`` from the
    websocket object; extra positional arguments are ignored.
    """
    request_body = gen_params(appid=ws.appid,
                              domain=ws.domain,
                              question=ws.question,
                              temperature=ws.temperature)
    ws.send(json.dumps(request_body))
|
|
|
|
|
|
|
|
|
def on_message(ws, message):
    """Websocket message callback: enqueue one streamed chunk for the consumer.

    A non-zero ``header.code`` is logged and closes the socket. Otherwise the
    chunk text is wrapped in a ReplyItem and put on the request's queue. When
    ``choices.status`` is 2 the item also carries the payload's usage block,
    is flagged as the end of the stream, and the socket is closed.
    """
    response = json.loads(message)
    code = response['header']['code']
    if code != 0:
        logger.error(f'请求错误: {code}, {response}')
        ws.close()
        return

    choices = response["payload"]["choices"]
    status = choices["status"]
    content = choices["text"][0]["content"]

    data_queue = queue_map.get(ws.session_id)
    if not data_queue:
        logger.error(
            f"[XunFei] can't find data queue, session_id={ws.session_id}")
        return

    if status == 2:
        # Final chunk: attach usage and mark the end before closing.
        reply_item = ReplyItem(content, response["payload"].get("usage"))
        reply_item.is_end = True
        ws.close()
    else:
        reply_item = ReplyItem(content)
    data_queue.put(reply_item)
|
|
|
|
|
def gen_params(appid, domain, question, temperature=0.5):
    """Build the request body from the app id and the user's question.

    Args:
        appid: Application id placed in the header.
        domain: Model domain placed in the chat parameters.
        question: Value sent as ``payload.message.text`` (passed through as is).
        temperature: Sampling temperature placed in the chat parameters.

    Returns:
        A dict with the "header", "parameter" and "payload" sections.
    """
    header = {"app_id": appid, "uid": "1234"}
    chat = {
        "domain": domain,
        "temperature": temperature,
        "random_threshold": 0.5,
        "max_tokens": 2048,
        "auditing": "default",
    }
    message = {"text": question}
    return {
        "header": header,
        "parameter": {"chat": chat},
        "payload": {"message": message},
    }
|
|