Crawford.Zhou commited on
Commit
4d5e4b2
1 Parent(s): f54871c

添加gpt问答接口

Browse files
Files changed (2) hide show
  1. app.py +16 -43
  2. requirements.txt +18 -0
app.py CHANGED
@@ -3,47 +3,6 @@ if sys.platform == "darwin":
3
  os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
4
 
5
  import logging
6
- import openai
7
-
8
- # openai.log = "debug"
9
- openai.api_key = "sk-"
10
- openai.api_base = "https://api.chatanywhere.com.cn/v1"
11
-
12
-
13
-
14
- # 非流式响应
15
- # completion = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hello world!"}])
16
- # print(completion.choices[0].message.content)
17
-
18
- def gpt_35_api_stream(key, messages: list):
19
- openai.api_key = "sk-" + key
20
- """为提供的对话消息创建新的回答 (流式传输)
21
-
22
- Args:
23
- messages (list): 完整的对话消息
24
- api_key (str): OpenAI API 密钥
25
-
26
- Returns:
27
- tuple: (results, error_desc)
28
- """
29
- try:
30
- response = openai.ChatCompletion.create(
31
- model='gpt-3.5-turbo',
32
- messages=messages,
33
- stream=True,
34
- )
35
- completion = {'role': '', 'content': ''}
36
- for event in response:
37
- if event['choices'][0]['finish_reason'] == 'stop':
38
- print(f'收到的完成数据: {completion}')
39
- break
40
- for delta_k, delta_v in event['choices'][0]['delta'].items():
41
- print(f'流响应数据: {delta_k} = {delta_v}')
42
- completion[delta_k] += delta_v
43
- messages.append(completion) # 直接在传入参数 messages 中追加消息
44
- return True, ''
45
- except Exception as err:
46
- return False, f'OpenAI API 异常: {err}'
47
 
48
  logging.getLogger("numba").setLevel(logging.WARNING)
49
  logging.getLogger("markdown_it").setLevel(logging.WARNING)
@@ -65,6 +24,19 @@ from text.cleaner import clean_text
65
  import gradio as gr
66
  import webbrowser
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
  net_g = None
70
 
@@ -93,7 +65,8 @@ def get_text(text, language_str, hps):
93
 
94
  def infer(text, key, sdp_ratio, noise_scale, noise_scale_w, length_scale, sid):
95
  global net_g
96
- message = gpt_35_api_stream(key, text)
 
97
  bert, phones, tones, lang_ids = get_text(message, "ZH", hps)
98
  with torch.no_grad():
99
  x_tst=phones.to(device).unsqueeze(0)
@@ -166,7 +139,7 @@ if __name__ == "__main__":
166
  """)
167
  text = gr.TextArea(label="Text", placeholder="Input Text Here",
168
  value="虚拟主播是什么?")
169
- text = gr.TextArea(label="Key", placeholder="请输入上面提示中获取的gpt key",
170
  value="key")
171
  speaker = gr.Dropdown(choices=speakers, value=speakers[0], label='Speaker')
172
  sdp_ratio = gr.Slider(minimum=0.1, maximum=1, value=0.2, step=0.01, label='SDP/DP混合比')
 
3
  os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
4
 
5
  import logging
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  logging.getLogger("numba").setLevel(logging.WARNING)
8
  logging.getLogger("markdown_it").setLevel(logging.WARNING)
 
24
  import gradio as gr
25
  import webbrowser
26
 
27
+ import openai
28
+
29
+ # openai.log = "debug"
30
+ openai.api_base = "https://api.chatanywhere.com.cn/v1"
31
+
32
+
33
+ # 非流式响应
34
+
35
+ def gpt_35_api(gptkey, message):
36
+ openai.api_key = "sk-" + gptkey
37
+ completion = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=[{"role": "user", "content": message}])
38
+ return completion.choices[0].message.content
39
+
40
 
41
  net_g = None
42
 
 
65
 
66
  def infer(text, key, sdp_ratio, noise_scale, noise_scale_w, length_scale, sid):
67
  global net_g
68
+ message = gpt_35_api(key, text)
69
+ print(message)
70
  bert, phones, tones, lang_ids = get_text(message, "ZH", hps)
71
  with torch.no_grad():
72
  x_tst=phones.to(device).unsqueeze(0)
 
139
  """)
140
  text = gr.TextArea(label="Text", placeholder="Input Text Here",
141
  value="虚拟主播是什么?")
142
+ key = gr.TextArea(label="Key", placeholder="请输入上面提示中获取的gpt key",
143
  value="key")
144
  speaker = gr.Dropdown(choices=speakers, value=speakers[0], label='Speaker')
145
  sdp_ratio = gr.Slider(minimum=0.1, maximum=1, value=0.2, step=0.01, label='SDP/DP混合比')
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ librosa==0.9.1
2
+ matplotlib
3
+ numpy
4
+ numba
5
+ phonemizer
6
+ scipy
7
+ tensorboard
8
+ torch
9
+ torchvision
10
+ Unidecode
11
+ amfm_decompy
12
+ jieba
13
+ transformers
14
+ pypinyin
15
+ cn2an
16
+ gradio
17
+ av
18
+ openai==0.28