Gong Junmin committed
Commit 260d83d
1 Parent(s): f94ba49

add refer wav support

Files changed (2)
  1. app.py +30 -14
  2. emotion_extract.py +6 -4
app.py CHANGED
@@ -5,6 +5,7 @@ import utils
 from models import SynthesizerTrn
 from text.symbols import symbols
 from text import text_to_sequence
+from emotion_extract import extract_wav
 import numpy as np
 
 
@@ -32,13 +33,13 @@ emotion_dict = {
     "平静2": 3554
 }
 import random
-def tts(txt, emotion):
+def tts(txt, emotion, sid=0):
     stn_tst = get_text(txt, hps)
     randsample = None
     with torch.no_grad():
         x_tst = stn_tst.unsqueeze(0)
         x_tst_lengths = torch.LongTensor([stn_tst.size(0)])
-        sid = torch.LongTensor([0])
+        sid = torch.LongTensor([sid])
         if type(emotion) ==int:
             emo = torch.FloatTensor(all_emotions[emotion]).unsqueeze(0)
         elif emotion == "random":
@@ -57,54 +58,69 @@ def tts(txt, emotion):
     return audio, randsample
 
 
-def tts1(text, emotion):
+def tts1(text, emotion, sid=0):
     if len(text) > 150:
         return "Error: Text is too long", None
-    audio, _ = tts(text, emotion)
+    audio, _ = tts(text, emotion, sid)
     return "Success", (hps.data.sampling_rate, audio)
 
-def tts2(text):
+def tts2(text, sid=0):
     if len(text) > 150:
         return "Error: Text is too long", None
-    audio, randsample = tts(text, "random_sample")
+    audio, randsample = tts(text, "random_sample", sid)
 
     return str(randsample), (hps.data.sampling_rate, audio)
 
-def tts3(text, sample):
+def tts3(text, sample, sid=0):
     if len(text) > 150:
         return "Error: Text is too long", None
     try:
-        audio, _ = tts(text, int(sample))
+        audio, _ = tts(text, int(sample), sid)
         return "Success", (hps.data.sampling_rate, audio)
     except:
         return "输入参数不为整数或其他错误", None
+
+
+def tts4(refer_wav_path, text, sid=0):
+    audio, _ = tts(text, refer_wav_path, sid)
+    return "Success", (hps.data.sampling_rate, audio)
+
+
 app = gr.Blocks()
 with app:
     with gr.Tabs():
         with gr.TabItem("使用预制情感合成"):
+            tts_spk_id = gr.Dropdown(label="speaker", choices=list(range(hps.data.n_speakers)), value=0)
             tts_input1 = gr.TextArea(label="日语文本", value="こんにちは。私わあやちねねです。")
-            tts_input2 = gr.Dropdown(label="情感", choices=list(emotion_dict.keys()), value="平静1")
+            tts_input2 = gr.Dropdown(label="情感", choices=list(emotion_dict.keys()), value="平静1")
             tts_submit = gr.Button("合成音频", variant="primary")
             tts_output1 = gr.Textbox(label="Message")
             tts_output2 = gr.Audio(label="Output")
-            tts_submit.click(tts1, [tts_input1, tts_input2], [tts_output1, tts_output2])
+            tts_submit.click(tts1, [tts_input1, tts_input2, tts_spk_id], [tts_output1, tts_output2])
         with gr.TabItem("随机抽取训练集样本作为情感参数"):
+            tts_spk_id = gr.Dropdown(label="speaker", choices=list(range(hps.data.n_speakers)), value=0)
             tts_input1 = gr.TextArea(label="日语文本", value="こんにちは。私わあやちねねです。")
             tts_submit = gr.Button("合成音频", variant="primary")
             tts_output1 = gr.Textbox(label="随机样本id(可用于第三个tab中合成)")
             tts_output2 = gr.Audio(label="Output")
-            tts_submit.click(tts2, [tts_input1], [tts_output1, tts_output2])
+            tts_submit.click(tts2, [tts_input1, tts_spk_id], [tts_output1, tts_output2])
 
         with gr.TabItem("使用情感样本id作为情感参数"):
-
+            tts_spk_id = gr.Dropdown(label="speaker", choices=list(range(hps.data.n_speakers)), value=0)
             tts_input1 = gr.TextArea(label="日语文本", value="こんにちは。私わあやちねねです。")
             tts_input2 = gr.Number(label="情感样本id", value=2004)
             tts_submit = gr.Button("合成音频", variant="primary")
             tts_output1 = gr.Textbox(label="Message")
             tts_output2 = gr.Audio(label="Output")
-            tts_submit.click(tts3, [tts_input1, tts_input2], [tts_output1, tts_output2])
+            tts_submit.click(tts3, [tts_input1, tts_input2, tts_spk_id], [tts_output1, tts_output2])
 
         with gr.TabItem("使用参考音频作为情感参数"):
-            tts_input1 = gr.TextArea(label="text", value="暂未实现")
+            tts_spk_id = gr.Dropdown(label="speaker", choices=list(range(hps.data.n_speakers)), value=0)
+            tts_refer_wav = gr.File(label="参考音频")
+            tts_input1 = gr.TextArea(label="日语文本", value="こんにちは。私わあやちねねです。")
+            tts_submit = gr.Button("合成音频", variant="primary")
+            tts_output1 = gr.Textbox(label="Message")
+            tts_output2 = gr.Audio(label="Output")
+            tts_submit.click(tts4, [tts_refer_wav, tts_input1, tts_spk_id], [tts_output1, tts_output2])
 
 app.launch()
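
Note that none of the hunks above show the branch of tts that actually consumes a wav path: tts4 forwards refer_wav_path to tts as the emotion argument, and the only other app.py change is the new extract_wav import, so the wav handling presumably lives in an unshown part of tts. A minimal sketch of that presumed flow follows; emotion_from_reference is an illustrative name (not in the commit), and the .name fallback assumes this Gradio version's gr.File hands the callback a temp-file wrapper rather than a plain string path.

    # Hypothetical helper; the commit only shows the extract_wav import and
    # tts4's pass-through, so this reconstructs the presumed flow.
    import torch
    from emotion_extract import extract_wav

    def emotion_from_reference(refer_wav):
        # gr.File may pass a temp-file wrapper whose on-disk path is .name
        # (assumption about the Gradio version in use).
        path = refer_wav if isinstance(refer_wav, str) else refer_wav.name
        emb = extract_wav(path)          # embedding with a leading batch dim
        return torch.FloatTensor(emb)    # shape presumably (1, emb_dim)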
emotion_extract.py CHANGED
@@ -74,6 +74,7 @@ def process_func(
     y = processor(x, sampling_rate=sampling_rate)
     y = y['input_values'][0]
     y = torch.from_numpy(y).to(device)
+    y = y.unsqueeze(0)
 
     # run through model
     with torch.no_grad():
@@ -89,13 +90,13 @@
 #     wav, sr = librosa.load(f"{rootpath}/{wavname}", 16000)
 #     display(ipd.Audio(wav, rate=sr))
 
-rootpath = "dataset/nene"
+rootpath = "dataset"
 embs = []
 wavnames = []
 def extract_dir(path):
     rootpath = path
     for idx, wavname in enumerate(os.listdir(rootpath)):
-        wav, sr =librosa.load(f"{rootpath}/{wavname}", 16000)
+        wav, sr =librosa.load(f"{rootpath}/{wavname}", sr=16000)
         emb = process_func(np.expand_dims(wav, 0), sr, embeddings=True)
         embs.append(emb)
         wavnames.append(wavname)
@@ -103,10 +104,11 @@ def extract_dir(path):
         print(idx, wavname)
 
 def extract_wav(path):
-    wav, sr = librosa.load(path, 16000)
+    wav, sr = librosa.load(path, sr=16000)
     emb = process_func(np.expand_dims(wav, 0), sr, embeddings=True)
     return emb
 
 if __name__ == '__main__':
-    for spk in ["serena", "koni", "nyaru","shanoa", "mana"]:
+    # for spk in ["serena", "koni", "nyaru","shanoa", "mana"]:
+    for spk in ["dubbingx"]:
         extract_dir(f"dataset/{spk}")
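
In process_func, the new y = y.unsqueeze(0) restores the batch dimension that y['input_values'][0] strips, which the embedding model presumably expects even for a single clip; the sr=16000 changes pass the sample rate as a keyword, which newer librosa releases require (positional arguments after the path were deprecated and later removed). A minimal usage sketch of the extract_wav entry point this commit wires into app.py; the file name refer.wav is illustrative, not from the commit.

    from emotion_extract import extract_wav

    # librosa resamples the clip to 16 kHz before it reaches process_func.
    emb = extract_wav("refer.wav")   # placeholder path to a reference clip
    print(emb.shape)                 # expected: (1, emb_dim), batch dim first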