Mahiruoshi committed on
Commit
5fa1ed0
β€’
1 Parent(s): 37ea278

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +103 -45
app.py CHANGED
@@ -1,8 +1,6 @@
1
- import sys, os
2
-
3
- if sys.platform == "darwin":
4
- os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
5
 
 
6
  import logging
7
 
8
  logging.getLogger("numba").setLevel(logging.WARNING)
@@ -10,7 +8,9 @@ logging.getLogger("markdown_it").setLevel(logging.WARNING)
10
  logging.getLogger("urllib3").setLevel(logging.WARNING)
11
  logging.getLogger("matplotlib").setLevel(logging.WARNING)
12
 
13
- logging.basicConfig(level=logging.INFO, format="| %(name)s | %(levelname)s | %(message)s")
 
 
14
 
15
  logger = logging.getLogger(__name__)
16
 
@@ -25,9 +25,14 @@ from text.cleaner import clean_text
25
  import gradio as gr
26
  import webbrowser
27
 
28
-
29
  net_g = None
30
 
 
 
 
 
 
 
31
 
32
  def get_text(text, language_str, hps):
33
  norm_text, phone, tone, word2ph = clean_text(text, language_str)
@@ -63,9 +68,10 @@ def get_text(text, language_str, hps):
63
  language = torch.LongTensor(language)
64
  return bert, ja_bert, phone, tone, language
65
 
66
- def infer(text, sdp_ratio, noise_scale, noise_scale_w, length_scale, sid):
 
67
  global net_g
68
- bert, ja_bert, phones, tones, lang_ids = get_text(text, "JP", hps)
69
  with torch.no_grad():
70
  x_tst = phones.to(device).unsqueeze(0)
71
  tones = tones.to(device).unsqueeze(0)
@@ -96,26 +102,48 @@ def infer(text, sdp_ratio, noise_scale, noise_scale_w, length_scale, sid):
96
  del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers
97
  return audio
98
 
99
- def tts_fn(text, speaker, sdp_ratio, noise_scale, noise_scale_w, length_scale):
 
 
 
100
  with torch.no_grad():
101
- audio = infer(text, sdp_ratio=sdp_ratio, noise_scale=noise_scale, noise_scale_w=noise_scale_w, length_scale=length_scale, sid=speaker)
 
 
 
 
 
 
 
 
 
102
  return "Success", (hps.data.sampling_rate, audio)
103
 
104
 
105
  if __name__ == "__main__":
106
  parser = argparse.ArgumentParser()
107
- parser.add_argument("--model_dir", default="./logs/Mygo/G_44000.pth", help="path of your model")
108
- parser.add_argument("--config_dir", default="./configs/config.json", help="path of your config file")
109
- parser.add_argument("--share", default=False, help="make link public")
110
- parser.add_argument("-d", "--debug", action="store_true", help="enable DEBUG-LEVEL log")
 
 
 
 
 
 
 
 
 
 
 
111
 
112
  args = parser.parse_args()
113
  if args.debug:
114
  logger.info("Enable DEBUG-LEVEL log")
115
  logging.basicConfig(level=logging.DEBUG)
116
- hps = utils.get_hparams_from_file(args.config_dir)
117
- device = "cuda:0" if torch.cuda.is_available() else "cpu"
118
- '''
119
  device = (
120
  "cuda:0"
121
  if torch.cuda.is_available()
@@ -125,42 +153,72 @@ if __name__ == "__main__":
125
  else "cpu"
126
  )
127
  )
128
- '''
129
  net_g = SynthesizerTrn(
130
  len(symbols),
131
  hps.data.filter_length // 2 + 1,
132
  hps.train.segment_size // hps.data.hop_length,
133
  n_speakers=hps.data.n_speakers,
134
- **hps.model).to(device)
 
135
  _ = net_g.eval()
136
 
137
- _ = utils.load_checkpoint(args.model_dir, net_g, None, skip_optimizer=True)
138
 
139
  speaker_ids = hps.data.spk2id
140
  speakers = list(speaker_ids.keys())
 
141
  with gr.Blocks() as app:
142
- with gr.Row():
143
- with gr.Column():
144
- gr.Markdown(value="""
145
- Mygo Vits-bert
146
- """)
147
- text = gr.TextArea(label="Text", placeholder="Input Text Here",
148
- value="η§γŸγ‘γ―γ€δΈ€η·’γ«γ―γ„γ‚‰γ‚Œγͺい。")
149
- speaker = gr.Dropdown(choices=speakers, value=speakers[0], label='Speaker')
150
- sdp_ratio = gr.Slider(minimum=0, maximum=1, value=0.2, step=0.1, label='SDP/DPζ··εˆζ―”')
151
- noise_scale = gr.Slider(minimum=0.1, maximum=1.5, value=0.6, step=0.1, label='ζ„Ÿζƒ…θ°ƒθŠ‚')
152
- noise_scale_w = gr.Slider(minimum=0.1, maximum=1.4, value=0.8, step=0.1, label='ιŸ³η΄ ι•ΏεΊ¦')
153
- length_scale = gr.Slider(minimum=0.1, maximum=2, value=1, step=0.1, label='η”Ÿζˆι•ΏεΊ¦')
154
- btn = gr.Button("η”ŸζˆοΌ", variant="primary")
155
- with gr.Column():
156
- text_output = gr.Textbox(label="Message")
157
- audio_output = gr.Audio(label="Output Audio")
158
-
159
- btn.click(tts_fn,
160
- inputs=[text, speaker, sdp_ratio, noise_scale, noise_scale_w, length_scale],
161
- outputs=[text_output, audio_output])
162
-
163
- # webbrowser.open("http://127.0.0.1:6006")
164
- # app.launch(server_port=6006, show_error=True)
165
-
166
- app.launch(show_error=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa: E402
 
 
 
2
 
3
+ import sys, os
4
  import logging
5
 
6
  logging.getLogger("numba").setLevel(logging.WARNING)
 
8
  logging.getLogger("urllib3").setLevel(logging.WARNING)
9
  logging.getLogger("matplotlib").setLevel(logging.WARNING)
10
 
11
+ logging.basicConfig(
12
+ level=logging.INFO, format="| %(name)s | %(levelname)s | %(message)s"
13
+ )
14
 
15
  logger = logging.getLogger(__name__)
16
 
 
25
  import gradio as gr
26
  import webbrowser
27
 
 
28
  net_g = None
29
 
30
+ if sys.platform == "darwin" and torch.backends.mps.is_available():
31
+ device = "mps"
32
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
33
+ else:
34
+ device = "cuda"
35
+
36
 
37
  def get_text(text, language_str, hps):
38
  norm_text, phone, tone, word2ph = clean_text(text, language_str)
 
68
  language = torch.LongTensor(language)
69
  return bert, ja_bert, phone, tone, language
70
 
71
+
72
+ def infer(text, sdp_ratio, noise_scale, noise_scale_w, length_scale, sid, language):
73
  global net_g
74
+ bert, ja_bert, phones, tones, lang_ids = get_text(text, language, hps)
75
  with torch.no_grad():
76
  x_tst = phones.to(device).unsqueeze(0)
77
  tones = tones.to(device).unsqueeze(0)
 
102
  del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers
103
  return audio
104
 
105
+
106
+ def tts_fn(
107
+ text, speaker, sdp_ratio, noise_scale, noise_scale_w, length_scale, language
108
+ ):
109
  with torch.no_grad():
110
+ audio = infer(
111
+ text,
112
+ sdp_ratio=sdp_ratio,
113
+ noise_scale=noise_scale,
114
+ noise_scale_w=noise_scale_w,
115
+ length_scale=length_scale,
116
+ sid=speaker,
117
+ language=language,
118
+ )
119
+ torch.cuda.empty_cache()
120
  return "Success", (hps.data.sampling_rate, audio)
121
 
122
 
123
  if __name__ == "__main__":
124
  parser = argparse.ArgumentParser()
125
+ parser.add_argument(
126
+ "-m", "--model", default="./logs/Mygo/G_63000.pth", help="path of your model"
127
+ )
128
+ parser.add_argument(
129
+ "-c",
130
+ "--config",
131
+ default="./logs/Mygo/config.json",
132
+ help="path of your config file",
133
+ )
134
+ parser.add_argument(
135
+ "--share", default=True, help="make link public", action="store_true"
136
+ )
137
+ parser.add_argument(
138
+ "-d", "--debug", action="store_true", help="enable DEBUG-LEVEL log"
139
+ )
140
 
141
  args = parser.parse_args()
142
  if args.debug:
143
  logger.info("Enable DEBUG-LEVEL log")
144
  logging.basicConfig(level=logging.DEBUG)
145
+ hps = utils.get_hparams_from_file(args.config)
146
+
 
147
  device = (
148
  "cuda:0"
149
  if torch.cuda.is_available()
 
153
  else "cpu"
154
  )
155
  )
 
156
  net_g = SynthesizerTrn(
157
  len(symbols),
158
  hps.data.filter_length // 2 + 1,
159
  hps.train.segment_size // hps.data.hop_length,
160
  n_speakers=hps.data.n_speakers,
161
+ **hps.model,
162
+ ).to(device)
163
  _ = net_g.eval()
164
 
165
+ _ = utils.load_checkpoint(args.model, net_g, None, skip_optimizer=True)
166
 
167
  speaker_ids = hps.data.spk2id
168
  speakers = list(speaker_ids.keys())
169
+ languages = ["ZH", "JP"]
170
  with gr.Blocks() as app:
171
+ for name in speakers:
172
+ with gr.TabItem(name):
173
+ with gr.Row():
174
+ with gr.Column():
175
+ with gr.Row():
176
+ gr.Markdown(
177
+ '<div align="center">'
178
+ f'<img style="width:auto;height:400px;" src="file/image/{name}.png">'
179
+ '</div>'
180
+ )
181
+ text = gr.TextArea(
182
+ label="Text",
183
+ placeholder="Input Text Here",
184
+ value="η§γŸγ‘γ―γ€δΈ€η·’γ«γ―γ„γ‚‰γ‚Œγͺい。",
185
+ )
186
+ speaker = gr.Dropdown(
187
+ choices=speakers, value=name, label="Speaker"
188
+ )
189
+ with gr.Column():
190
+ text_output = gr.Textbox(label="Message")
191
+ audio_output = gr.Audio(label="Output Audio")
192
+ btn = gr.Button("Generate!", variant="primary")
193
+ sdp_ratio = gr.Slider(
194
+ minimum=0, maximum=1, value=0.2, step=0.01, label="SDP Ratio"
195
+ )
196
+ noise_scale = gr.Slider(
197
+ minimum=0.1, maximum=2, value=0.6, step=0.01, label="Noise Scale"
198
+ )
199
+ noise_scale_w = gr.Slider(
200
+ minimum=0.1, maximum=2, value=0.8, step=0.01, label="Noise Scale W"
201
+ )
202
+ length_scale = gr.Slider(
203
+ minimum=0.1, maximum=2, value=1, step=0.01, label="Length Scale"
204
+ )
205
+ language = gr.Dropdown(
206
+ choices=languages, value=languages[1], label="Language"
207
+ )
208
+
209
+ btn.click(
210
+ tts_fn,
211
+ inputs=[
212
+ text,
213
+ speaker,
214
+ sdp_ratio,
215
+ noise_scale,
216
+ noise_scale_w,
217
+ length_scale,
218
+ language,
219
+ ],
220
+ outputs=[text_output, audio_output],
221
+ )
222
+
223
+ webbrowser.open("http://127.0.0.1:7860")
224
+ app.launch(share=args.share)