liuhua liuhua commited on
Commit
e6abe77
·
1 Parent(s): 7362294

SparkTTS (#2535)

Browse files

### What problem does this PR solve?

SparkTTS

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: liuhua <10215101452@stu.ecun.edu.cn>

api/apps/llm_app.py CHANGED
@@ -161,7 +161,10 @@ def add_llm():
161
 
162
  elif factory =="XunFei Spark":
163
  llm_name = req["llm_name"]
164
- api_key = req.get("spark_api_password","xxxxxxxxxxxxxxx")
 
 
 
165
 
166
  elif factory == "BaiduYiyan":
167
  llm_name = req["llm_name"]
 
161
 
162
  elif factory =="XunFei Spark":
163
  llm_name = req["llm_name"]
164
+ if req["model_type"] == "chat":
165
+ api_key = req.get("spark_api_password", "xxxxxxxxxxxxxxx")
166
+ elif req["model_type"] == "tts":
167
+ api_key = apikey_json(["spark_app_id", "spark_api_secret","spark_api_key"])
168
 
169
  elif factory == "BaiduYiyan":
170
  llm_name = req["llm_name"]
rag/llm/__init__.py CHANGED
@@ -139,5 +139,6 @@ Seq2txtModel = {
139
  TTSModel = {
140
  "Fish Audio": FishAudioTTS,
141
  "Tongyi-Qianwen": QwenTTS,
142
- "OpenAI":OpenAITTS
 
143
  }
 
139
  TTSModel = {
140
  "Fish Audio": FishAudioTTS,
141
  "Tongyi-Qianwen": QwenTTS,
142
+ "OpenAI":OpenAITTS,
143
+ "XunFei Spark":SparkTTS
144
  }
rag/llm/tts_model.py CHANGED
@@ -14,16 +14,30 @@
14
  # limitations under the License.
15
  #
16
 
17
- import requests
18
- from typing import Annotated, Literal
 
 
 
 
 
 
 
 
19
  from abc import ABC
 
 
 
 
 
 
20
  import httpx
21
  import ormsgpack
 
 
22
  from pydantic import BaseModel, conint
 
23
  from rag.utils import num_tokens_from_string
24
- import json
25
- import re
26
- import time
27
 
28
 
29
  class ServeReferenceAudio(BaseModel):
@@ -161,7 +175,7 @@ class QwenTTS(Base):
161
 
162
  class OpenAITTS(Base):
163
  def __init__(self, key, model_name="tts-1", base_url="https://api.openai.com/v1"):
164
- if not base_url: base_url="https://api.openai.com/v1"
165
  self.api_key = key
166
  self.model_name = model_name
167
  self.base_url = base_url
@@ -185,3 +199,101 @@ class OpenAITTS(Base):
185
  for chunk in response.iter_content():
186
  if chunk:
187
  yield chunk
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  # limitations under the License.
15
  #
16
 
17
+ import _thread as thread
18
+ import base64
19
+ import datetime
20
+ import hashlib
21
+ import hmac
22
+ import json
23
+ import queue
24
+ import re
25
+ import ssl
26
+ import time
27
  from abc import ABC
28
+ from datetime import datetime
29
+ from time import mktime
30
+ from typing import Annotated, Literal
31
+ from urllib.parse import urlencode
32
+ from wsgiref.handlers import format_date_time
33
+
34
  import httpx
35
  import ormsgpack
36
+ import requests
37
+ import websocket
38
  from pydantic import BaseModel, conint
39
+
40
  from rag.utils import num_tokens_from_string
 
 
 
41
 
42
 
43
  class ServeReferenceAudio(BaseModel):
 
175
 
176
  class OpenAITTS(Base):
177
  def __init__(self, key, model_name="tts-1", base_url="https://api.openai.com/v1"):
178
+ if not base_url: base_url = "https://api.openai.com/v1"
179
  self.api_key = key
180
  self.model_name = model_name
181
  self.base_url = base_url
 
199
  for chunk in response.iter_content():
200
  if chunk:
201
  yield chunk
202
+
203
+
204
+ class SparkTTS:
205
+ STATUS_FIRST_FRAME = 0
206
+ STATUS_CONTINUE_FRAME = 1
207
+ STATUS_LAST_FRAME = 2
208
+
209
+ def __init__(self, key, model_name, base_url=""):
210
+ key = json.loads(key)
211
+ self.APPID = key.get("spark_app_id", "xxxxxxx")
212
+ self.APISecret = key.get("spark_api_secret", "xxxxxxx")
213
+ self.APIKey = key.get("spark_api_key", "xxxxxx")
214
+ self.model_name = model_name
215
+ self.CommonArgs = {"app_id": self.APPID}
216
+ self.audio_queue = queue.Queue()
217
+
218
+ # 用来存储音频数据
219
+
220
+ # 生成url
221
+ def create_url(self):
222
+ url = 'wss://tts-api.xfyun.cn/v2/tts'
223
+ now = datetime.now()
224
+ date = format_date_time(mktime(now.timetuple()))
225
+ signature_origin = "host: " + "ws-api.xfyun.cn" + "\n"
226
+ signature_origin += "date: " + date + "\n"
227
+ signature_origin += "GET " + "/v2/tts " + "HTTP/1.1"
228
+ signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
229
+ digestmod=hashlib.sha256).digest()
230
+ signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8')
231
+ authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % (
232
+ self.APIKey, "hmac-sha256", "host date request-line", signature_sha)
233
+ authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
234
+ v = {
235
+ "authorization": authorization,
236
+ "date": date,
237
+ "host": "ws-api.xfyun.cn"
238
+ }
239
+ url = url + '?' + urlencode(v)
240
+ return url
241
+
242
+ def tts(self, text):
243
+ BusinessArgs = {"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "vcn": self.model_name, "tte": "utf8"}
244
+ Data = {"status": 2, "text": base64.b64encode(text.encode('utf-8')).decode('utf-8')}
245
+ CommonArgs = {"app_id": self.APPID}
246
+ audio_queue = self.audio_queue
247
+ model_name = self.model_name
248
+
249
+ class Callback:
250
+ def __init__(self):
251
+ self.audio_queue = audio_queue
252
+
253
+ def on_message(self, ws, message):
254
+ message = json.loads(message)
255
+ code = message["code"]
256
+ sid = message["sid"]
257
+ audio = message["data"]["audio"]
258
+ audio = base64.b64decode(audio)
259
+ status = message["data"]["status"]
260
+ if status == 2:
261
+ ws.close()
262
+ if code != 0:
263
+ errMsg = message["message"]
264
+ raise Exception(f"sid:{sid} call error:{errMsg} code:{code}")
265
+ else:
266
+ self.audio_queue.put(audio)
267
+
268
+ def on_error(self, ws, error):
269
+ raise Exception(error)
270
+
271
+ def on_close(self, ws, close_status_code, close_msg):
272
+ self.audio_queue.put(None) # 放入 None 作为结束标志
273
+
274
+ def on_open(self, ws):
275
+ def run(*args):
276
+ d = {"common": CommonArgs,
277
+ "business": BusinessArgs,
278
+ "data": Data}
279
+ ws.send(json.dumps(d))
280
+
281
+ thread.start_new_thread(run, ())
282
+
283
+ wsUrl = self.create_url()
284
+ websocket.enableTrace(False)
285
+ a = Callback()
286
+ ws = websocket.WebSocketApp(wsUrl, on_open=a.on_open, on_error=a.on_error, on_close=a.on_close,
287
+ on_message=a.on_message)
288
+ status_code = 0
289
+ ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
290
+ while True:
291
+ audio_chunk = self.audio_queue.get()
292
+ if audio_chunk is None:
293
+ if status_code == 0:
294
+ raise Exception(
295
+ f"Fail to access model({model_name}) using the provided credentials. **ERROR**: Invalid APPID, API Secret, or API Key.")
296
+ else:
297
+ break
298
+ status_code = 1
299
+ yield audio_chunk
requirements.txt CHANGED
@@ -94,6 +94,8 @@ vertexai==1.64.0
94
  volcengine==1.0.146
95
  voyageai==0.2.3
96
  webdriver_manager==4.0.1
 
 
97
  Werkzeug==3.0.3
98
  wikipedia==1.4.0
99
  word2number==1.1
 
94
  volcengine==1.0.146
95
  voyageai==0.2.3
96
  webdriver_manager==4.0.1
97
+ websocket==0.2.1
98
+ websocket-client==1.8.0
99
  Werkzeug==3.0.3
100
  wikipedia==1.4.0
101
  word2number==1.1
web/src/locales/en.ts CHANGED
@@ -551,6 +551,12 @@ The above is the content you need to summarize.`,
551
  SparkModelNameMessage: 'Please select Spark model',
552
  addSparkAPIPassword: 'Spark APIPassword',
553
  SparkAPIPasswordMessage: 'please input your APIPassword',
 
 
 
 
 
 
554
  yiyanModelNameMessage: 'Please input model name',
555
  addyiyanAK: 'yiyan API KEY',
556
  yiyanAKMessage: 'Please input your API KEY',
 
551
  SparkModelNameMessage: 'Please select Spark model',
552
  addSparkAPIPassword: 'Spark APIPassword',
553
  SparkAPIPasswordMessage: 'please input your APIPassword',
554
+ addSparkAPPID: 'Spark APPID',
555
+ SparkAPPIDMessage: 'please input your APPID',
556
+ addSparkAPISecret: 'Spark APISecret',
557
+ SparkAPISecretMessage: 'please input your APISecret',
558
+ addSparkAPIKey: 'Spark APIKey',
559
+ SparkAPIKeyMessage: 'please input your APIKey',
560
  yiyanModelNameMessage: 'Please input model name',
561
  addyiyanAK: 'yiyan API KEY',
562
  yiyanAKMessage: 'Please input your API KEY',
web/src/locales/zh-traditional.ts CHANGED
@@ -512,6 +512,12 @@ export default {
512
  SparkModelNameMessage: '請選擇星火模型!',
513
  addSparkAPIPassword: '星火 APIPassword',
514
  SparkAPIPasswordMessage: '請輸入 APIPassword',
 
 
 
 
 
 
515
  yiyanModelNameMessage: '輸入模型名稱',
516
  addyiyanAK: '一言 API KEY',
517
  yiyanAKMessage: '請輸入 API KEY',
 
512
  SparkModelNameMessage: '請選擇星火模型!',
513
  addSparkAPIPassword: '星火 APIPassword',
514
  SparkAPIPasswordMessage: '請輸入 APIPassword',
515
+ addSparkAPPID: '星火 APPID',
516
+ SparkAPPIDMessage: '請輸入 APPID',
517
+ addSparkAPISecret: '星火 APISecret',
518
+ SparkAPISecretMessage: '請輸入 APISecret',
519
+ addSparkAPIKey: '星火 APIKey',
520
+ SparkAPIKeyMessage: '請輸入 APIKey',
521
  yiyanModelNameMessage: '輸入模型名稱',
522
  addyiyanAK: '一言 API KEY',
523
  yiyanAKMessage: '請輸入 API KEY',
web/src/locales/zh.ts CHANGED
@@ -529,6 +529,12 @@ export default {
529
  SparkModelNameMessage: '请选择星火模型!',
530
  addSparkAPIPassword: '星火 APIPassword',
531
  SparkAPIPasswordMessage: '请输入 APIPassword',
 
 
 
 
 
 
532
  yiyanModelNameMessage: '请输入模型名称',
533
  addyiyanAK: '一言 API KEY',
534
  yiyanAKMessage: '请输入 API KEY',
 
529
  SparkModelNameMessage: '请选择星火模型!',
530
  addSparkAPIPassword: '星火 APIPassword',
531
  SparkAPIPasswordMessage: '请输入 APIPassword',
532
+ addSparkAPPID: '星火 APPID',
533
+ SparkAPPIDMessage: '请输入 APPID',
534
+ addSparkAPISecret: '星火 APISecret',
535
+ SparkAPISecretMessage: '请输入 APISecret',
536
+ addSparkAPIKey: '星火 APIKey',
537
+ SparkAPIKeyMessage: '请输入 APIKey',
538
  yiyanModelNameMessage: '请输入模型名称',
539
  addyiyanAK: '一言 API KEY',
540
  yiyanAKMessage: '请输入 API KEY',
web/src/pages/user-setting/setting-model/spark-modal/index.tsx CHANGED
@@ -7,6 +7,9 @@ import omit from 'lodash/omit';
7
  type FieldType = IAddLlmRequestBody & {
8
  vision: boolean;
9
  spark_api_password: string;
 
 
 
10
  };
11
 
12
  const { Option } = Select;
@@ -63,28 +66,67 @@ const SparkModal = ({
63
  >
64
  <Select placeholder={t('modelTypeMessage')}>
65
  <Option value="chat">chat</Option>
 
66
  </Select>
67
  </Form.Item>
68
  <Form.Item<FieldType>
69
  label={t('modelName')}
70
  name="llm_name"
71
- initialValue={'Spark-Max'}
72
  rules={[{ required: true, message: t('SparkModelNameMessage') }]}
73
  >
74
- <Select placeholder={t('modelTypeMessage')}>
75
- <Option value="Spark-Max">Spark-Max</Option>
76
- <Option value="Spark-Lite">Spark-Lite</Option>
77
- <Option value="Spark-Pro">Spark-Pro</Option>
78
- <Option value="Spark-Pro-128K">Spark-Pro-128K</Option>
79
- <Option value="Spark-4.0-Ultra">Spark-4.0-Ultra</Option>
80
- </Select>
81
  </Form.Item>
82
- <Form.Item<FieldType>
83
- label={t('addSparkAPIPassword')}
84
- name="spark_api_password"
85
- rules={[{ required: true, message: t('SparkAPIPasswordMessage') }]}
86
- >
87
- <Input placeholder={t('SparkAPIPasswordMessage')} />
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  </Form.Item>
89
  </Form>
90
  </Modal>
 
7
  type FieldType = IAddLlmRequestBody & {
8
  vision: boolean;
9
  spark_api_password: string;
10
+ spark_app_id: string;
11
+ spark_api_secret: string;
12
+ spark_api_key: string;
13
  };
14
 
15
  const { Option } = Select;
 
66
  >
67
  <Select placeholder={t('modelTypeMessage')}>
68
  <Option value="chat">chat</Option>
69
+ <Option value="tts">tts</Option>
70
  </Select>
71
  </Form.Item>
72
  <Form.Item<FieldType>
73
  label={t('modelName')}
74
  name="llm_name"
 
75
  rules={[{ required: true, message: t('SparkModelNameMessage') }]}
76
  >
77
+ <Input placeholder={t('modelNameMessage')} />
 
 
 
 
 
 
78
  </Form.Item>
79
+ <Form.Item noStyle dependencies={['model_type']}>
80
+ {({ getFieldValue }) =>
81
+ getFieldValue('model_type') === 'chat' && (
82
+ <Form.Item<FieldType>
83
+ label={t('addSparkAPIPassword')}
84
+ name="spark_api_password"
85
+ rules={[{ required: true, message: t('SparkAPIPasswordMessage') }]}
86
+ >
87
+ <Input placeholder={t('SparkAPIPasswordMessage')} />
88
+ </Form.Item>
89
+ )
90
+ }
91
+ </Form.Item>
92
+ <Form.Item noStyle dependencies={['model_type']}>
93
+ {({ getFieldValue }) =>
94
+ getFieldValue('model_type') === 'tts' && (
95
+ <Form.Item<FieldType>
96
+ label={t('addSparkAPPID')}
97
+ name="spark_app_id"
98
+ rules={[{ required: true, message: t('SparkAPPIDMessage') }]}
99
+ >
100
+ <Input placeholder={t('SparkAPPIDMessage')} />
101
+ </Form.Item>
102
+ )
103
+ }
104
+ </Form.Item>
105
+ <Form.Item noStyle dependencies={['model_type']}>
106
+ {({ getFieldValue }) =>
107
+ getFieldValue('model_type') === 'tts' && (
108
+ <Form.Item<FieldType>
109
+ label={t('addSparkAPISecret')}
110
+ name="spark_api_secret"
111
+ rules={[{ required: true, message: t('SparkAPISecretMessage') }]}
112
+ >
113
+ <Input placeholder={t('SparkAPISecretMessage')} />
114
+ </Form.Item>
115
+ )
116
+ }
117
+ </Form.Item>
118
+ <Form.Item noStyle dependencies={['model_type']}>
119
+ {({ getFieldValue }) =>
120
+ getFieldValue('model_type') === 'tts' && (
121
+ <Form.Item<FieldType>
122
+ label={t('addSparkAPIKey')}
123
+ name="spark_api_key"
124
+ rules={[{ required: true, message: t('SparkAPIKeyMessage') }]}
125
+ >
126
+ <Input placeholder={t('SparkAPIKeyMessage')} />
127
+ </Form.Item>
128
+ )
129
+ }
130
  </Form.Item>
131
  </Form>
132
  </Modal>