Mahiruoshi committed on
Commit
d220bdb
1 Parent(s): 4cc1f98

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -49
app.py CHANGED
@@ -1,4 +1,3 @@
1
-
2
  import time
3
  import matplotlib.pyplot as plt
4
  import IPython.display as ipd
@@ -19,8 +18,6 @@ from text.symbols import symbols
19
  from text import text_to_sequence
20
  import unicodedata
21
  from scipy.io.wavfile import write
22
- import openai
23
-
24
  def get_text(text, hps):
25
  text_norm = text_to_sequence(text, hps.data.text_cleaners)
26
  if hps.data.add_blank:
@@ -28,6 +25,7 @@ def get_text(text, hps):
28
  text_norm = torch.LongTensor(text_norm)
29
  return text_norm
30
 
 
31
  def get_label(text, label):
32
  if f'[{label}]' in text:
33
  return True, text.replace(f'[{label}]', '')
@@ -35,7 +33,7 @@ def get_label(text, label):
35
  return False, text
36
 
37
  def selection(speaker):
38
- if speaker == "高咲侑(误)":
39
  spk = 0
40
  return spk
41
 
@@ -86,54 +84,23 @@ def selection(speaker):
86
  return spk
87
  elif speaker == "三色绘恋2":
88
  spk = 15
89
- return spk
90
  elif speaker == "派蒙":
91
  spk = 16
92
  return spk
93
- def friend_chat(text,key,call_name,tts_input3):
94
- call_name = call_name
95
- openai.api_key = key
96
- identity = tts_input3
97
- start_sequence = '\n'+str(call_name)+':'
98
- restart_sequence = "\nYou: "
99
- all_text = identity + restart_sequence
100
- if 1 == 1:
101
- prompt0 = text #当期prompt
102
- if text == 'quit':
103
- return prompt0
104
- prompt = identity + prompt0 + start_sequence
105
-
106
- response = openai.Completion.create(
107
- model="text-davinci-003",
108
- prompt=prompt,
109
- temperature=0.5,
110
- max_tokens=1000,
111
- top_p=1.0,
112
- frequency_penalty=0.5,
113
- presence_penalty=0.0,
114
- stop=["\nYou:"]
115
- )
116
- return response['choices'][0]['text'].strip()
117
def is_japanese(string):
    """Return True if *string* contains at least one kana character.

    A character counts as Japanese when its code point lies strictly
    between U+3040 and U+30FF (the Hiragana and Katakana blocks,
    excluding both endpoints — same bounds as the original check).
    """
    return any(0x3040 < ord(ch) < 0x30FF for ch in string)
122
- def sle(language,text,tts_input2,call_name,tts_input3):
123
  if language == "中文":
124
- tts_input1 = "[ZH]" + text.replace('\n','。').replace(' ',',') + "[ZH]"
 
 
 
125
  return tts_input1
126
- if language == "对话":
127
- text = friend_chat(text,tts_input2,call_name,tts_input3).replace('\n','。').replace(' ',',')
128
- text = f"[JA]{text}[JA]" if is_japanese(text) else f"[ZH]{text}[ZH]"
129
- return text
130
  elif language == "日文":
131
- tts_input1 = "[JA]" + text.replace('\n','。').replace(' ',',') + "[JA]"
132
  return tts_input1
133
- def infer(language,text,tts_input2,tts_input3,speaker_id,n_scale= 0.667,n_scale_w = 0.8, l_scale = 1 ):
134
- speaker_name = speaker_id
135
  speaker_id = int(selection(speaker_id))
136
- stn_tst = get_text(sle(language,text,tts_input2,speaker_name,tts_input3), hps_ms)
137
  with torch.no_grad():
138
  x_tst = stn_tst.unsqueeze(0).to(dev)
139
  x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(dev)
@@ -144,12 +111,13 @@ def infer(language,text,tts_input2,tts_input3,speaker_id,n_scale= 0.667,n_scale_
144
  spending_time = "推理时间:"+str(t2-t1)+"s"
145
  print(spending_time)
146
  return (hps_ms.data.sampling_rate, audio)
147
- lan = ["中文","日文"]
148
- idols = ["高咲侑(误)","歩夢","かすみ","しずく","果林","愛","彼方","せつ菜","璃奈","栞子","エマ","ランジュ","ミア"]
149
 
150
 
151
- dev = torch.device("cpu")
152
- hps_ms = utils.get_hparams_from_file("2_config.json")
 
153
  net_g_ms = SynthesizerTrn(
154
  len(symbols),
155
  hps_ms.data.filter_length // 2 + 1,
@@ -159,9 +127,10 @@ net_g_ms = SynthesizerTrn(
159
  _ = net_g_ms.eval()
160
 
161
  _ = utils.load_checkpoint("G_842000.pth", net_g_ms, None)
162
-
163
  app = gr.Blocks()
164
 
 
 
165
  with app:
166
  with gr.Tabs():
167
 
 
 
1
  import time
2
  import matplotlib.pyplot as plt
3
  import IPython.display as ipd
 
18
  from text import text_to_sequence
19
  import unicodedata
20
  from scipy.io.wavfile import write
 
 
21
  def get_text(text, hps):
22
  text_norm = text_to_sequence(text, hps.data.text_cleaners)
23
  if hps.data.add_blank:
 
25
  text_norm = torch.LongTensor(text_norm)
26
  return text_norm
27
 
28
+
29
  def get_label(text, label):
30
  if f'[{label}]' in text:
31
  return True, text.replace(f'[{label}]', '')
 
33
  return False, text
34
 
35
  def selection(speaker):
36
+ if speaker == "高咲侑":
37
  spk = 0
38
  return spk
39
 
 
84
  return spk
85
  elif speaker == "三色绘恋2":
86
  spk = 15
 
87
  elif speaker == "派蒙":
88
  spk = 16
89
  return spk
90
+
91
def sle(language, tts_input0):
    """Wrap user text in the language tag the TTS front-end expects.

    Newlines become a sentence stop (。 for Chinese/Japanese, . for
    English) and spaces become commas, then the whole string is
    bracketed by the matching language tag.

    Returns the tagged string, or None (implicitly, as in the
    original) when *language* is not one of the known labels.
    """
    # UI label -> (tag, replacement for newline)
    rules = {
        "中文": ("[ZH]", "。"),
        "英文": ("[EN]", "."),
        "日文": ("[JA]", "。"),
    }
    rule = rules.get(language)
    if rule is None:
        return None  # unknown language: original fell through without a return
    tag, stop = rule
    normalized = tts_input0.replace('\n', stop).replace(' ', ',')
    return tag + normalized + tag
101
+ def infer(language,text,speaker_id, n_scale= 0.667,n_scale_w = 0.8, l_scale = 1 ):
 
102
  speaker_id = int(selection(speaker_id))
103
+ stn_tst = get_text(sle(language,text), hps_ms)
104
  with torch.no_grad():
105
  x_tst = stn_tst.unsqueeze(0).to(dev)
106
  x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(dev)
 
111
  spending_time = "推理时间:"+str(t2-t1)+"s"
112
  print(spending_time)
113
  return (hps_ms.data.sampling_rate, audio)
114
+ lan = ["中文","日文","英文"]
115
+ idols = ["高咲侑","歩夢","かすみ","しずく","果林","愛","彼方","せつ菜","璃奈","栞子","エマ","ランジュ","ミア","三色绘恋1","三色绘恋2","派蒙"]
116
 
117
 
118
+
119
+ dev = torch.device("cuda:0")
120
+ hps_ms = utils.get_hparams_from_file("config.json")
121
  net_g_ms = SynthesizerTrn(
122
  len(symbols),
123
  hps_ms.data.filter_length // 2 + 1,
 
127
  _ = net_g_ms.eval()
128
 
129
  _ = utils.load_checkpoint("G_842000.pth", net_g_ms, None)
 
130
  app = gr.Blocks()
131
 
132
+
133
+
134
  with app:
135
  with gr.Tabs():
136