|
import _thread as thread |
|
import base64 |
|
import datetime |
|
import hashlib |
|
import hmac |
|
import json |
|
from urllib.parse import urlparse |
|
import ssl |
|
from datetime import datetime |
|
from time import mktime |
|
from urllib.parse import urlencode |
|
from wsgiref.handlers import format_date_time |
|
|
|
import websocket |
|
# Reply text accumulated across streamed frames; on_message appends each chunk here.
answer = ""

# Credential placeholders, all initialised to None. They are not read by any
# function visible in this module (main() takes credentials as parameters);
# presumably callers assign them -- verify before relying on them.
appid = api_secret = api_key = None
|
|
|
class Ws_Param(object):
    """Build the signed websocket URL for the Spark chat endpoint.

    ``create_url`` appends three query parameters to ``Spark_url``:
    ``authorization`` (base64 of a header naming the API key and an
    HMAC-SHA256 signature), ``date`` and ``host``.
    """

    def __init__(self, APPID, APIKey, APISecret, Spark_url):
        """Store the credentials and split the endpoint URL.

        Args:
            APPID: application id; stored but not used by ``create_url``.
            APIKey: value written to the ``api_key`` field of the authorization.
            APISecret: key for the HMAC-SHA256 signature.
            Spark_url: websocket endpoint, e.g. ``wss://host/v1.1/chat``.
        """
        self.APPID = APPID
        self.APIKey = APIKey
        self.APISecret = APISecret
        parsed = urlparse(Spark_url)  # parse once; host and path share the result
        self.host = parsed.netloc
        # An HTTP request line cannot carry an empty target, so a URL without a
        # path is requested as "/" (websocket-client does exactly this). Sign the
        # same path the client will actually send, otherwise auth fails.
        self.path = parsed.path or "/"
        self.Spark_url = Spark_url

    def create_url(self):
        """Return ``Spark_url`` with freshly signed authentication parameters.

        The signature covers the current date, so build the URL right before
        connecting. Presumably the server rejects stale dates; confirm the
        allowed clock skew in the API docs.

        Returns:
            str: ``Spark_url + "?" + urlencode({authorization, date, host})``.
        """
        import time  # only time.time() is needed; kept local to this method

        # RFC 1123 date in GMT. Formatting the epoch time directly avoids the
        # naive-local-time -> mktime() round trip, which is ambiguous during the
        # repeated hour of a DST fall-back and could shift the date by an hour.
        date = format_date_time(time.time())

        # String to sign: host, date and request line separated by newlines,
        # matching headers="host date request-line" in the authorization below.
        signature_origin = "\n".join((
            "host: " + self.host,
            "date: " + date,
            "GET " + self.path + " HTTP/1.1",
        ))
        signature_sha = hmac.new(
            self.APISecret.encode('utf-8'),
            signature_origin.encode('utf-8'),
            digestmod=hashlib.sha256,
        ).digest()
        signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')

        authorization_origin = (
            f'api_key="{self.APIKey}", algorithm="hmac-sha256", '
            f'headers="host date request-line", signature="{signature_sha_base64}"'
        )
        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')

        v = {
            "authorization": authorization,
            "date": date,
            "host": self.host,
        }
        return self.Spark_url + '?' + urlencode(v)
|
|
|
|
|
|
|
def on_error(ws, error):
    """Websocket error callback: print the error to stdout.

    Args:
        ws: the websocket app (unused).
        error: the error object; rendered with ``str()``.
    """
    line = " ".join(("### error:", str(error)))
    print(line)
|
|
|
|
|
|
|
def on_close(ws, one, two):
    """Websocket close callback; intentionally does nothing.

    ``one`` and ``two`` receive the two extra values passed on close
    (presumably the close status code and reason -- confirm against the
    websocket-client version in use). They are ignored.
    """
    return None
|
|
|
|
|
|
|
|
|
def on_open(ws):
    """Websocket open callback: send the request from a background thread.

    The thread runs ``run(ws)``, which builds and sends the request body.
    """
    import threading  # documented high-level API, replaces the low-level _thread

    # daemon=True keeps the old _thread semantics: the sender never blocks
    # interpreter exit. Unlike _thread, a named thread makes debugging easier,
    # and an uncaught error in run() is reported via threading.excepthook.
    sender = threading.Thread(target=run, args=(ws,), name="spark-send", daemon=True)
    sender.start()
|
|
|
|
|
def run(ws, *args):
    """Serialize the chat request for ``ws`` and send it over the socket.

    Expects the caller to have attached ``ws.appid``, ``ws.domain`` and
    ``ws.question``. Extra positional ``args`` are accepted and ignored.
    """
    request_body = gen_params(
        appid=ws.appid,
        domain=ws.domain,
        question=ws.question,
    )
    encoded = json.dumps(request_body)
    ws.send(encoded)
|
|
|
|
|
|
|
def on_message(ws, message):
    """Handle one streamed JSON frame from the Spark chat API.

    A non-zero ``header.code`` is printed together with the decoded frame, and
    the socket is closed. Otherwise the first text item's ``content`` is
    appended to the module-level ``answer``, and the socket is closed once
    ``payload.choices.status`` equals 2.
    """
    global answer

    data = json.loads(message)
    code = data['header']['code']
    if code != 0:
        # '请求错误' means "request error"; the full frame is printed for diagnosis.
        print(f'请求错误: {code}, {data}')
        ws.close()
        return

    chunk_info = data["payload"]["choices"]
    stream_status = chunk_info["status"]
    piece = chunk_info["text"][0]["content"]
    answer += piece

    # Status 2 is treated as end of stream (presumably the final frame per
    # the API docs -- confirm).
    if stream_status == 2:
        ws.close()
|
|
|
|
|
def gen_params(appid, domain, question, temperature=0.5, max_tokens=2048, uid="1234"):
    """Build the request body for a Spark chat call from the app id and question.

    The former hard-coded sampling settings and user id are now keyword
    parameters. Their defaults equal the old constants, so existing callers
    get an identical body (same keys, same order).

    Args:
        appid: application id, written to ``header.app_id``.
        domain: model domain, written to ``parameter.chat.domain``.
        question: placed verbatim in ``payload.message.text`` (presumably the
            list of chat messages -- confirm the shape with the API docs).
        temperature: sampling temperature (default 0.5).
        max_tokens: maximum tokens to generate (default 2048).
        uid: user id sent in the header (default "1234").

    Returns:
        dict: the request body, ready for ``json.dumps``.
    """
    return {
        "header": {
            "app_id": appid,
            "uid": uid,
        },
        "parameter": {
            "chat": {
                "domain": domain,
                "temperature": temperature,
                "max_tokens": max_tokens,
            },
        },
        "payload": {
            "message": {
                "text": question,
            },
        },
    }
|
|
|
|
|
def main(appid, api_key, api_secret, Spark_url, domain, question, *, verify_ssl=True):
    """Send ``question`` to the Spark chat endpoint and block until the socket closes.

    Reply chunks are appended to the module-level ``answer`` by ``on_message``.
    That callback also closes the socket on an error code or when the stream
    status reaches 2. ``answer`` is not reset here; callers presumably clear it
    between calls.

    Args:
        appid: application id, forwarded to the request body.
        api_key: API key used when signing the connection URL.
        api_secret: API secret used as the HMAC key for the signature.
        Spark_url: websocket endpoint (``wss://...``).
        domain: model domain sent in the request parameters.
        question: message payload forwarded (via ``run``) to ``gen_params``.
        verify_ssl: verify the server's TLS certificate (default). Pass
            ``False`` only to restore the previous, insecure behaviour.
    """
    wsParam = Ws_Param(appid, api_key, api_secret, Spark_url)
    websocket.enableTrace(False)
    wsUrl = wsParam.create_url()
    ws = websocket.WebSocketApp(
        wsUrl,
        on_message=on_message,
        on_error=on_error,
        on_close=on_close,
        on_open=on_open,
    )
    # Read back by run() when it builds the request body.
    ws.appid = appid
    ws.question = question
    ws.domain = domain
    # The old hard-coded ssl.CERT_NONE skipped certificate checks. That let a
    # man-in-the-middle capture the signed URL and the whole conversation.
    cert_reqs = ssl.CERT_REQUIRED if verify_ssl else ssl.CERT_NONE
    ws.run_forever(sslopt={"cert_reqs": cert_reqs})
|
|
|
|
|
|