alessandro trinca tornidor commited on
Commit
5bae85d
1 Parent(s): 9ab32d7

feat: add text to speech (TTS)

Browse files
aip_trainer/lambdas/lambdaTTS.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import tempfile
3
+ from pathlib import Path
4
+
5
+ from aip_trainer import app_logger
6
+
7
+
8
+ def get_tts(text: str, language: str):
9
+ from aip_trainer.models import models
10
+
11
+ tmp_dir = Path(tempfile.gettempdir())
12
+ if language == "en" or language == "de":
13
+ try:
14
+ model, _, speaker, sample_rate = models.silero_tts(
15
+ language, output_folder=tmp_dir
16
+ )
17
+ except ValueError:
18
+ model, _, sample_rate, _, _, speaker = models.silero_tts(
19
+ language, output_folder=tmp_dir
20
+ )
21
+ else:
22
+ raise NotImplementedError(f"Not yet tested with {language} error...")
23
+ app_logger.info(f"model speaker #0: {speaker} ...")
24
+
25
+ with tempfile.NamedTemporaryFile(prefix="audio_", suffix=".wav", delete=False) as tmp_audio_file:
26
+ app_logger.info(f"tmp_audio_file output: {tmp_audio_file.name} ...")
27
+ audio_paths = model.save_wav(text=text, speaker=speaker, sample_rate=sample_rate, audio_path=str(tmp_audio_file.name))
28
+ app_logger.info(f"audio_paths output: {audio_paths} ...")
29
+ return audio_paths
30
+
31
+
32
+ """
33
+ Help on method save_wav:
34
+
35
+ save_wav(text=None, ssml_text=None, speaker: str = 'xenia', audio_path: str = '', sample_rate: int = 48000, put_accent=True, put_yo=True) method of <torch_package_0>.multi_acc_v3_package.TTSModelMultiAcc_v3 instance
36
+
37
+
38
+ tmp_dir = tempfile.gettempdir()
39
+ if language == "de":
40
+ model, decoder, _ = silero_stt(
41
+ language="de", version="v4", jit_model="jit_large", output_folder=tmp_dir
42
+ )
43
+ elif language == "en":
44
+ model, decoder, _ = silero_stt(language="en", output_folder=tmp_dir)
45
+ else:
46
+ raise NotImplementedError(
47
+ "currenty works only for 'de' and 'en' languages, not for '{}'.".format(
48
+ language
49
+ )
50
+ )
51
+
52
+ """
aip_trainer/models/models.py CHANGED
@@ -8,64 +8,65 @@ from silero.utils import Decoder
8
  from aip_trainer import app_logger
9
 
10
 
11
- def silero_tts(language='en',
12
- speaker='kseniya_16khz',
13
- **kwargs):
14
- """ Silero Text-To-Speech Models
 
 
 
 
 
 
15
  language (str): language of the model, now available are ['ru', 'en', 'de', 'es', 'fr']
16
  Returns a model and a set of utils
17
  Please see https://github.com/snakers4/silero-models for usage examples
18
  """
19
- from omegaconf import OmegaConf
20
- from silero.tts_utils import apply_tts
21
- from silero.tts_utils import init_jit_model as init_jit_model_tts
 
 
22
 
23
- models_list_file = os.path.join(os.path.dirname(__file__), "..", "..", "models.yml")
24
- if not os.path.exists(models_list_file):
25
- models_list_file = 'latest_silero_models.yml'
26
- if not os.path.exists(models_list_file):
27
- torch.hub.download_url_to_file('https://raw.githubusercontent.com/snakers4/silero-models/master/models.yml',
28
- 'latest_silero_models.yml',
29
- progress=False)
30
- assert os.path.exists(models_list_file)
31
- models = OmegaConf.load(models_list_file)
32
  available_languages = list(models.tts_models.keys())
33
- assert language in available_languages, f'Language not in the supported list {available_languages}'
34
- available_speakers = []
35
- speaker_language = {}
36
- for lang in available_languages:
37
- speakers = list(models.tts_models.get(lang).keys())
38
- available_speakers.extend(speakers)
39
- for _ in speakers:
40
- speaker_language[_] = lang
41
- assert speaker in available_speakers, f'Speaker not in the supported list {available_speakers}'
42
- assert language == speaker_language[speaker], f"Incorrect language '{language}' for this speaker, please specify '{speaker_language[speaker]}'"
43
-
44
- model_conf = models.tts_models[language][speaker].latest
45
- if '_v2' in speaker or '_v3' in speaker or 'v3_' in speaker or 'v4_' in speaker:
46
  from torch import package
47
- model_url = model_conf.package
48
- model_dir = os.path.join(os.path.dirname(__file__), "model")
 
49
  os.makedirs(model_dir, exist_ok=True)
50
- model_path = os.path.join(model_dir, os.path.basename(model_url))
51
  if not os.path.isfile(model_path):
52
- torch.hub.download_url_to_file(model_url,
53
- model_path,
54
- progress=True)
55
  imp = package.PackageImporter(model_path)
56
  model = imp.load_pickle("tts_models", "model")
57
- if speaker == 'multi_v2':
58
- avail_speakers = model_conf.speakers
59
- return model, avail_speakers
60
- else:
61
- example_text = model_conf.example
62
- return model, example_text
 
 
 
 
63
  else:
64
- model = init_jit_model_tts(model_conf.jit)
65
- symbols = model_conf.tokenset
66
- example_text = model_conf.example
67
- sample_rate = model_conf.sample_rate
68
- return model, symbols, sample_rate, example_text, apply_tts
 
 
69
 
70
 
71
  def silero_stt(
@@ -74,7 +75,7 @@ def silero_stt(
74
  jit_model="jit",
75
  output_folder: Path | str = None,
76
  **kwargs,
77
- ):
78
  """Modified Silero Speech-To-Text Model(s) function
79
  language (str): language of the model, now available are ['en', 'de', 'es']
80
  version:
@@ -83,8 +84,6 @@ def silero_stt(
83
  Returns a model, decoder object and a set of utils
84
  Please see https://github.com/snakers4/silero-models for usage examples
85
  """
86
- import torch
87
- from omegaconf import OmegaConf
88
  from silero.utils import (
89
  read_audio,
90
  read_batch,
@@ -92,31 +91,13 @@ def silero_stt(
92
  prepare_model_input,
93
  )
94
 
95
- output_folder = (
96
- Path(output_folder)
97
- if output_folder is not None
98
- else Path(os.path.dirname(__file__)) / ".." / ".."
99
- )
100
- models_list_file = output_folder / f"latest_silero_model_{language}.yml"
101
- if not os.path.exists(models_list_file):
102
- app_logger.info(
103
- f"model yml for '{language}' language, '{version}' version not found, download it in folder {output_folder}..."
104
- )
105
- torch.hub.download_url_to_file(
106
- "https://raw.githubusercontent.com/snakers4/silero-models/master/models.yml",
107
- models_list_file,
108
- progress=True,
109
- )
110
- app_logger.info(
111
- f"model yml for '{language}' language, '{version}' version in folder {output_folder}: OK!"
112
- )
113
- assert os.path.exists(models_list_file)
114
- models = OmegaConf.load(models_list_file)
115
- available_languages = list(models.stt_models.keys())
116
- assert language in available_languages
117
-
118
- model, decoder = init_jit_model(
119
- model_url=models.stt_models.get(language).get(version).get(jit_model), output_folder=output_folder, **kwargs
120
  )
121
  utils = (read_batch, split_into_batches, read_audio, prepare_model_input)
122
 
@@ -127,7 +108,7 @@ def init_jit_model(
127
  model_url: str,
128
  device: torch.device = torch.device("cpu"),
129
  output_folder: Path | str = None,
130
- ):
131
  torch.set_grad_enabled(False)
132
 
133
  app_logger.info(
@@ -145,13 +126,9 @@ def init_jit_model(
145
  )
146
 
147
  if not os.path.isfile(model_path):
148
- app_logger.info(
149
- f"downloading model_path: '{model_path}' ..."
150
- )
151
  torch.hub.download_url_to_file(model_url, model_path, progress=True)
152
- app_logger.info(
153
- f"model_path {model_path} downloaded!"
154
- )
155
  model = torch.jit.load(model_path, map_location=device)
156
  model.eval()
157
  return model, Decoder(model.labels)
@@ -174,3 +151,38 @@ def getASRModel(language: str) -> tuple[nn.Module, Decoder]:
174
  )
175
 
176
  return model, decoder
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  from aip_trainer import app_logger
9
 
10
 
11
+ default_speaker_dict = {
12
+ "de": {"speaker": "karlsson", "model_id": "v3_de", "sample_rate": 48000},
13
+ "en": {"speaker": "en_0", "model_id": "v3_en", "sample_rate": 48000},
14
+ }
15
+
16
+
17
+ def silero_tts(
18
+ language="en", version="latest", output_folder: Path | str = None, **kwargs
19
+ ):
20
+ """Silero Text-To-Speech Models
21
  language (str): language of the model, now available are ['ru', 'en', 'de', 'es', 'fr']
22
  Returns a model and a set of utils
23
  Please see https://github.com/snakers4/silero-models for usage examples
24
  """
25
+ output_folder = Path(output_folder)
26
+ current_model_lang = default_speaker_dict[language]
27
+ app_logger.info(f"model speaker current_model_lang: {current_model_lang} ...")
28
+ if language in default_speaker_dict:
29
+ model_id = current_model_lang["model_id"]
30
 
31
+ models = get_models(language, output_folder, version, model_type="tts_models")
 
 
 
 
 
 
 
 
32
  available_languages = list(models.tts_models.keys())
33
+ assert (
34
+ language in available_languages
35
+ ), f"Language not in the supported list {available_languages}"
36
+
37
+ tts_models_lang = models.tts_models[language]
38
+ model_conf = tts_models_lang[model_id]
39
+ model_conf_latest = model_conf[version]
40
+ app_logger.info(f"model_conf: {model_conf_latest} ...")
41
+ if "_v2" in model_id or "_v3" in model_id or "v3_" in model_id or "v4_" in model_id:
 
 
 
 
42
  from torch import package
43
+
44
+ model_url = model_conf_latest.package
45
+ model_dir = output_folder / "model"
46
  os.makedirs(model_dir, exist_ok=True)
47
+ model_path = output_folder / os.path.basename(model_url)
48
  if not os.path.isfile(model_path):
49
+ torch.hub.download_url_to_file(model_url, model_path, progress=True)
 
 
50
  imp = package.PackageImporter(model_path)
51
  model = imp.load_pickle("tts_models", "model")
52
+ app_logger.info(
53
+ f"current model_conf_latest.sample_rate:{model_conf_latest.sample_rate} ..."
54
+ )
55
+ sample_rate = current_model_lang["sample_rate"]
56
+ return (
57
+ model,
58
+ model_conf_latest.example,
59
+ current_model_lang["speaker"],
60
+ sample_rate,
61
+ )
62
  else:
63
+ from silero.tts_utils import apply_tts, init_jit_model as init_jit_model_tts
64
+
65
+ model = init_jit_model_tts(model_conf_latest.jit)
66
+ symbols = model_conf_latest.tokenset
67
+ example_text = model_conf_latest.example
68
+ sample_rate = model_conf_latest.sample_rate
69
+ return model, symbols, sample_rate, example_text, apply_tts, model_id
70
 
71
 
72
  def silero_stt(
 
75
  jit_model="jit",
76
  output_folder: Path | str = None,
77
  **kwargs,
78
+ ):
79
  """Modified Silero Speech-To-Text Model(s) function
80
  language (str): language of the model, now available are ['en', 'de', 'es']
81
  version:
 
84
  Returns a model, decoder object and a set of utils
85
  Please see https://github.com/snakers4/silero-models for usage examples
86
  """
 
 
87
  from silero.utils import (
88
  read_audio,
89
  read_batch,
 
91
  prepare_model_input,
92
  )
93
 
94
+ model, decoder = get_latest_model(
95
+ language,
96
+ output_folder,
97
+ version,
98
+ model_type="stt_models",
99
+ jit_model=jit_model,
100
+ **kwargs,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  )
102
  utils = (read_batch, split_into_batches, read_audio, prepare_model_input)
103
 
 
108
  model_url: str,
109
  device: torch.device = torch.device("cpu"),
110
  output_folder: Path | str = None,
111
+ ):
112
  torch.set_grad_enabled(False)
113
 
114
  app_logger.info(
 
126
  )
127
 
128
  if not os.path.isfile(model_path):
129
+ app_logger.info(f"downloading model_path: '{model_path}' ...")
 
 
130
  torch.hub.download_url_to_file(model_url, model_path, progress=True)
131
+ app_logger.info(f"model_path {model_path} downloaded!")
 
 
132
  model = torch.jit.load(model_path, map_location=device)
133
  model.eval()
134
  return model, Decoder(model.labels)
 
151
  )
152
 
153
  return model, decoder
154
+
155
+
156
+ def get_models(language, output_folder, version, model_type):
157
+ from omegaconf import OmegaConf
158
+
159
+ output_folder = (
160
+ Path(output_folder)
161
+ if output_folder is not None
162
+ else Path(os.path.dirname(__file__)) / ".." / ".."
163
+ )
164
+ models_list_file = output_folder / f"latest_silero_model_{language}.yml"
165
+ if not os.path.exists(models_list_file):
166
+ app_logger.info(
167
+ f"model {model_type} yml for '{language}' language, '{version}' version not found, download it in folder {output_folder}..."
168
+ )
169
+ torch.hub.download_url_to_file(
170
+ "https://raw.githubusercontent.com/snakers4/silero-models/master/models.yml",
171
+ models_list_file,
172
+ progress=False,
173
+ )
174
+ assert os.path.exists(models_list_file)
175
+ return OmegaConf.load(models_list_file)
176
+
177
+
178
+ def get_latest_model(language, output_folder, version, model_type, jit_model, **kwargs):
179
+ models = get_models(language, output_folder, version, model_type)
180
+ available_languages = list(models[model_type].keys())
181
+ assert language in available_languages
182
+
183
+ model, decoder = init_jit_model(
184
+ model_url=models[model_type].get(language).get(version).get(jit_model),
185
+ output_folder=output_folder,
186
+ **kwargs,
187
+ )
188
+ return model, decoder
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import gradio as gr
2
 
3
  from aip_trainer import app_logger
4
- from aip_trainer.lambdas import lambdaSpeechToScore
5
 
6
 
7
  js = """
@@ -56,6 +56,14 @@ with gr.Blocks() as gradio_app:
56
  sources=["microphone", "upload"],
57
  type="filepath",
58
  )
 
 
 
 
 
 
 
 
59
  with gr.Column(scale=3, min_width=300):
60
  transcripted_text = gr.Textbox(
61
  lines=2, placeholder=None, label="Transcripted text", visible=False
 
1
  import gradio as gr
2
 
3
  from aip_trainer import app_logger
4
+ from aip_trainer.lambdas import lambdaSpeechToScore, lambdaTTS
5
 
6
 
7
  js = """
 
56
  sources=["microphone", "upload"],
57
  type="filepath",
58
  )
59
+ with gr.Row():
60
+ tts = gr.Audio(label="tts")
61
+ btn = gr.Button(value="TTS")
62
+ btn.click(
63
+ fn=lambdaTTS.get_tts,
64
+ inputs=[learner_transcription, language],
65
+ outputs=tts,
66
+ )
67
  with gr.Column(scale=3, min_width=300):
68
  transcripted_text = gr.Textbox(
69
  lines=2, placeholder=None, label="Transcripted text", visible=False