|
|
|
|
|
''' |
|
@File : spark_llm.py |
|
@Time : 2023/10/16 18:53:26 |
|
@Author : Logan Zou |
|
@Version : 1.0 |
|
@Contact : loganzou0421@163.com |
|
@License : (C)Copyright 2017-2018, Liugroup-NLPR-CASIA |
|
@Desc : 基于讯飞星火大模型自定义 LLM 类 |
|
''' |
|
|
|
from langchain.llms.base import LLM |
|
from typing import Any, List, Mapping, Optional, Dict, Union, Tuple |
|
from pydantic import Field |
|
from llm.self_llm import Self_LLM |
|
import json |
|
import requests |
|
from langchain.callbacks.manager import CallbackManagerForLLMRun |
|
import _thread as thread |
|
import base64 |
|
import datetime |
|
import hashlib |
|
import hmac |
|
import json |
|
from urllib.parse import urlparse |
|
import ssl |
|
from datetime import datetime |
|
from time import mktime |
|
from urllib.parse import urlencode |
|
from wsgiref.handlers import format_date_time |
|
import websocket |
|
import queue |
|
|
|
class Spark_LLM(Self_LLM):
    """LLM wrapper that forwards a prompt to the Spark chat API via ``spark_main``.

    ``api_key`` and ``temperature`` are read from ``self`` but are not declared
    in this class; they are expected to be provided by ``Self_LLM``.
    """

    # WebSocket endpoint the request is sent to.
    url: str = "ws://spark-api.xf-yun.com/v1.1/chat"

    # Application id; must be set before the model can be called.
    appid: Optional[str] = None

    # API secret; must be set before the model can be called.
    api_secret: Optional[str] = None

    # Value forwarded as the ``domain`` argument of ``spark_main``.
    domain: str = "general"

    # Value forwarded as the ``max_tokens`` argument of ``spark_main``.
    max_tokens: int = 4096

    def getText(self, role, content, text=None):
        """Append a ``{"role": ..., "content": ...}`` message to ``text``.

        Args:
            role: Role of the message author, e.g. ``"user"``.
            content: Message body.
            text: Existing message list to extend. When omitted, a new list
                is created for this call.

        Returns:
            The message list with the new entry appended (the same object as
            ``text`` when one was supplied).
        """
        # A fresh list per call: the previous ``text=[]`` default was a single
        # shared object, so every prompt ever sent leaked into later requests.
        if text is None:
            text = []
        text.append({"role": role, "content": content})
        return text

    def _call(self, prompt: str, stop: Optional[List[str]] = None,
              run_manager: Optional[CallbackManagerForLLMRun] = None,
              **kwargs: Any):
        """Send ``prompt`` as a single user message and return the reply text.

        Args:
            prompt: The user prompt.
            stop: Accepted for interface compatibility; not used here.
            run_manager: Accepted for interface compatibility; not used here.
            **kwargs: Ignored.

        Returns:
            The string returned by ``spark_main``, or ``"请求失败"`` when the
            request raises any exception.

        Raises:
            ValueError: If ``api_key``, ``appid`` or ``api_secret`` is ``None``.
        """
        if self.api_key is None or self.appid is None or self.api_secret is None:
            print("请填入 Key")
            raise ValueError("Key 不存在")

        question = self.getText("user", prompt)

        try:
            return spark_main(
                self.appid,
                self.api_key,
                self.api_secret,
                self.url,
                self.domain,
                question,
                self.temperature,
                self.max_tokens,
            )
        except Exception as e:
            # Deliberate best-effort: report the failure and hand back a
            # fixed marker string instead of propagating the exception.
            print(e)
            print("请求失败")
            return "请求失败"

    @property
    def _llm_type(self) -> str:
        """Identifier string for this LLM implementation."""
        return "Spark"
|
|
|
# Module-level accumulator: the module-level ``on_message`` handler appends
# each received content fragment to this string via ``global answer``.
answer = ""
|
|
|
class Ws_Param(object):
    """Holds the Spark credentials and builds the signed connection URL."""

    def __init__(self, APPID, APIKey, APISecret, Spark_url):
        """Store the credentials and split ``Spark_url`` into host and path."""
        parsed = urlparse(Spark_url)
        self.APPID, self.APIKey, self.APISecret = APPID, APIKey, APISecret
        self.host, self.path = parsed.netloc, parsed.path
        self.Spark_url = Spark_url

        # Stored on the instance only; create_url does not read them.
        self.temperature, self.max_tokens = 0, 2048

    def create_url(self):
        """Return ``Spark_url`` with signed authentication query parameters.

        The query string carries ``authorization``, ``date`` and ``host``.
        ``authorization`` is the base64 form of a descriptor that embeds an
        HMAC-SHA256 signature (keyed with ``APISecret``) over the host, the
        date and the ``GET <path> HTTP/1.1`` request line.
        """
        # Date string for the current time, as produced by format_date_time.
        timestamp = mktime(datetime.now().timetuple())
        date = format_date_time(timestamp)

        # The three newline-separated lines that get signed.
        to_sign = "\n".join((
            "host: " + self.host,
            "date: " + date,
            "GET " + self.path + " HTTP/1.1",
        ))

        digest = hmac.new(
            self.APISecret.encode('utf-8'),
            to_sign.encode('utf-8'),
            digestmod=hashlib.sha256,
        ).digest()
        signature = base64.b64encode(digest).decode(encoding='utf-8')

        auth_plain = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature}"'
        authorization = base64.b64encode(auth_plain.encode('utf-8')).decode(encoding='utf-8')

        # Key order is kept so the resulting query string is unchanged.
        query = urlencode({
            "authorization": authorization,
            "date": date,
            "host": self.host,
        })
        return self.Spark_url + '?' + query
|
|
|
|
|
|
|
def on_error(ws, error):
    """WebSocket error callback: print ``error`` to stdout; ``ws`` is unused."""
    label = "### error:"
    print(label, error)
|
|
|
|
|
|
|
def on_close(ws, one, two):
    """WebSocket close callback: print a single space; all arguments are ignored."""
    spacer = " "
    print(spacer)
|
|
|
|
|
|
|
def on_open(ws):
    """WebSocket open callback: start ``run(ws)`` on a new thread."""
    worker_args = (ws,)
    thread.start_new_thread(run, worker_args)
|
|
|
|
|
def run(ws, *args):
    """Build the request from attributes stored on ``ws`` and send it as JSON.

    Reads ``appid``, ``domain``, ``question``, ``temperature`` and
    ``max_tokens`` from ``ws``; extra positional arguments are ignored.
    """
    request = gen_params(
        appid=ws.appid,
        domain=ws.domain,
        question=ws.question,
        temperature=ws.temperature,
        max_tokens=ws.max_tokens,
    )
    ws.send(json.dumps(request))
|
|
|
|
|
|
|
def on_message(ws, message):
    """WebSocket message callback that echoes and accumulates reply fragments.

    ``message`` is parsed as JSON. A non-zero ``header.code`` is printed and
    the socket is closed. Otherwise the fragment at
    ``payload.choices.text[0].content`` is printed without a newline and
    appended to the module-level ``answer``; the socket is closed when
    ``payload.choices.status`` equals 2.
    """
    global answer

    frame = json.loads(message)
    code = frame['header']['code']

    # Error frame: report it, close the connection and stop here.
    if code != 0:
        print(f'请求错误: {code}, {frame}')
        ws.close()
        return

    choices = frame["payload"]["choices"]
    status = choices["status"]
    content = choices["text"][0]["content"]

    print(content, end="")
    answer += content

    if status == 2:
        ws.close()
|
|
|
|
|
def gen_params(appid, domain, question, temperature, max_tokens):
    """Build the request parameters from the app id and the user's question.

    Returns a dict with three sections: ``header`` (app id and a fixed
    ``uid``), ``parameter.chat`` (domain, token limit, temperature and fixed
    ``random_threshold`` / ``auditing`` values) and ``payload.message.text``
    (``question``, passed through unchanged).
    """
    header = {"app_id": appid, "uid": "1234"}
    chat = {
        "domain": domain,
        "random_threshold": 0.5,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "auditing": "default",
    }
    message = {"text": question}
    return {
        "header": header,
        "parameter": {"chat": chat},
        "payload": {"message": message},
    }
|
|
|
|
|
def spark_main(appid, api_key, api_secret, Spark_url, domain, question, temperature, max_tokens):
    """Send ``question`` over a WebSocket and return the concatenated reply.

    A signed URL is built with ``Ws_Param``, a ``WebSocketApp`` is opened on
    it, and the request settings are attached to the socket object as
    attributes so the ``on_open`` / ``run`` callbacks can read them. Reply
    fragments are collected by a local message handler and joined once
    ``run_forever`` returns.

    Returns:
        The received content fragments joined into one string (empty when
        nothing was collected).
    """
    output_queue = queue.Queue()

    def collect_chunk(ws, message):
        """Queue the content fragment of one frame; close on error or final frame."""
        frame = json.loads(message)
        code = frame['header']['code']
        if code != 0:
            print(f'请求错误: {code}, {frame}')
            ws.close()
            return

        choices = frame["payload"]["choices"]
        status = choices["status"]
        content = choices["text"][0]["content"]
        output_queue.put(content)
        if status == 2:
            ws.close()

    wsParam = Ws_Param(appid, api_key, api_secret, Spark_url)
    websocket.enableTrace(False)
    wsUrl = wsParam.create_url()
    ws = websocket.WebSocketApp(
        wsUrl,
        on_message=collect_chunk,
        on_error=on_error,
        on_close=on_close,
        on_open=on_open,
    )

    # Request settings travel on the socket object for the callbacks to read.
    for attr_name, attr_value in (
        ("appid", appid),
        ("question", question),
        ("domain", domain),
        ("temperature", temperature),
        ("max_tokens", max_tokens),
    ):
        setattr(ws, attr_name, attr_value)

    # NOTE(review): certificate verification is disabled via ssl.CERT_NONE.
    ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})

    # Drain everything the handler collected, in arrival order.
    pieces = []
    while not output_queue.empty():
        pieces.append(output_queue.get())
    return ''.join(pieces)
|
|
|
|
|
|
|
|
|
|
|
|