Aboubacar OUATTARA - kaira committed
Commit 1b0b842
1 Parent(s): 05fb637

use custom tts

Files changed (4)
  1. .gitattributes +1 -0
  2. app.py +49 -13
  3. requirements.txt +8 -5
  4. tts.py +395 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.wav filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,4 +1,8 @@
 import concurrent
+import os
+import tempfile
+from typing import Optional, Tuple
+
 import spaces
 from transformers import pipeline
 import gradio as gr
@@ -7,6 +11,7 @@ import torchaudio
 from resemble_enhance.enhancer.inference import denoise, enhance
 
 from flore200_codes import flores_codes
+from tts import BambaraTTS
 
 # Check if CUDA is available
 device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -16,8 +21,8 @@ translation_model = "oza75/nllb-600M-mt-french-bambara"
 translator = pipeline("translation", model=translation_model, max_length=512)
 
 # Text-to-Speech pipeline
-tts_model = "oza75/bambara-tts-male-001"
-tts = pipeline("text-to-speech", model=tts_model, device=device)
+tts_model = "oza75/bambara-tts"
+tts = BambaraTTS(tts_model)
 
 
 # Function to translate text to Bambara
@@ -29,11 +34,30 @@ def translate_to_bambara(text, src_lang):
 
 # Function to convert text to speech
 @spaces.GPU
-def text_to_speech(bambara_text):
-    speech = tts(bambara_text)
-    audio, sr = speech['audio'], speech['sampling_rate']
-    audio = torch.from_numpy(audio).mean(dim=0)
-
+def text_to_speech(bambara_text, reference_audio: Optional[Tuple] = None):
+    if reference_audio is not None:
+        ref_sr, ref_audio = reference_audio
+        ref_audio = torch.from_numpy(ref_audio)
+
+        # Add a channel dimension if the audio is 1D
+        if ref_audio.ndim == 1:
+            ref_audio = ref_audio.unsqueeze(0)
+
+        # Save the reference audio to a temporary file
+        with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp:
+            torchaudio.save(tmp.name, ref_audio, ref_sr)
+            tmp_path = tmp.name
+
+        # Use the temporary file as the speaker reference
+        sr, audio = tts.text_to_speech(bambara_text, speaker_reference_wav_path=tmp_path)
+
+        # Clean up the temporary file
+        os.unlink(tmp_path)
+    else:
+        # If no reference audio is provided, proceed with the default
+        sr, audio = tts.text_to_speech(bambara_text)
+
+    audio = audio.mean(dim=0)
     return audio, sr
 
 
@@ -64,14 +88,25 @@ def enhance_speech(audio_array, sampling_rate, solver, nfe, tau, denoise_before_
 
 
 # Define the Gradio interface
-def _fn(src_lang, text, solver="Midpoint", nfe=64, prior_temp=0.5, denoise_before_enhancement=False):
+def _fn(
+        src_lang,
+        text,
+        reference_audio=None,
+        solver="Midpoint",
+        nfe=64,
+        prior_temp=0.5,
+        denoise_before_enhancement=False
+):
     source_lang = flores_codes[src_lang]
 
     # Step 1: Translate the text to Bambara
    bambara_text = translate_to_bambara(text, source_lang)
 
-    # Step 2: Convert the translated text to speech
-    audio_array, sampling_rate = text_to_speech(bambara_text)
+    # Step 2: Convert the translated text to speech with reference audio
+    if reference_audio is not None:
+        audio_array, sampling_rate = text_to_speech(bambara_text, reference_audio)
+    else:
+        audio_array, sampling_rate = text_to_speech(bambara_text)
 
     # Step 3: Enhance the audio
     denoised_audio, enhanced_audio = enhance_speech(
@@ -95,13 +130,14 @@ def main():
         fn=_fn,
         inputs=[
             gr.Dropdown(label="Source Language", choices=lang_codes, value='French'),
-            gr.Textbox(label="Text to Translate"),
+            gr.Textbox(label="Text to Translate", lines=3),
+            gr.Audio(label="Clone your voice (optional)", type="numpy", format="wav"),
             gr.Dropdown(
                 choices=["Midpoint", "RK4", "Euler"], value="Midpoint",
                 label="ODE Solver (Midpoint is recommended)"
             ),
             gr.Slider(minimum=1, maximum=128, value=64, step=1, label="Number of Function Evaluations"),
-            gr.Slider(minimum=0, maximum=1, value=0.5, step=0.01, label="Prior Temperature"),
+            gr.Slider(minimum=0.1, maximum=1, value=0.5, step=0.01, label="Prior Temperature"),
             gr.Checkbox(value=False, label="Denoise Before Enhancement")
         ],
         outputs=[
@@ -114,7 +150,7 @@ def main():
         description="Translate text to Bambara and convert it to speech with options to enhance audio quality."
     )
 
-    app.launch()
+    app.launch(share=False)
 
 
 if __name__ == "__main__":
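Note on the new voice-cloning path: the Gradio Audio input hands app.py a (sample_rate, numpy_array) tuple, while BambaraTTS.text_to_speech expects a file path as its speaker reference, hence the temporary-WAV round trip in the hunk above. A minimal standalone sketch of that conversion; the save_reference_to_wav helper and the synthetic sine reference are illustrative, not part of the commit:

    import os
    import tempfile

    import numpy as np
    import torch
    import torchaudio

    def save_reference_to_wav(reference_audio):
        """Persist a Gradio (sample_rate, ndarray) tuple as a temporary WAV file."""
        ref_sr, ref_audio = reference_audio
        ref_audio = torch.from_numpy(ref_audio).float()
        if ref_audio.ndim == 1:  # mono: add the channel dimension torchaudio.save expects
            ref_audio = ref_audio.unsqueeze(0)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
            torchaudio.save(tmp.name, ref_audio, ref_sr)
            return tmp.name

    # One second of a 440 Hz sine standing in for a recorded voice sample:
    sr = 16000
    t = np.linspace(0, 1, sr, endpoint=False)
    wav_path = save_reference_to_wav((sr, np.sin(2 * np.pi * 440 * t).astype(np.float32)))
    # ...pass wav_path as speaker_reference_wav_path, then clean up as app.py does:
    os.unlink(wav_path)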
requirements.txt CHANGED
@@ -1,6 +1,9 @@
-transformers
-gradio
-torch
-torchaudio
-spaces
+transformers>=4.33.0
+gradio~=4.8.0
+torch~=2.1.1
+torchaudio~=2.1.1
+spaces~=0.26.1
+deepspeed~=0.12.1
+requests~=2.31.0
 resemble-enhance==0.0.2.dev240104122303
+git+https://github.com/oza75/coqui-TTS.git@prod
tts.py ADDED
@@ -0,0 +1,395 @@
+import os
+import re
+import time
+
+import numpy as np
+import requests
+import torch
+from typing import Optional, Tuple
+
+from TTS.tts.configs.xtts_config import XttsConfig
+from TTS.tts.models.xtts import Xtts
+from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer, basic_cleaners
+from coqpit import Coqpit
+from huggingface_hub import hf_hub_download, hf_hub_url
+from tqdm import tqdm
+
+
+def download_file_with_progress(url: str, destination: str):
+    """
+    Downloads a file from a web URL with a progress bar.
+    """
+    # Streaming GET request
+    response = requests.get(url, stream=True)
+
+    # Total size in bytes, set to zero if missing
+    total_size = int(response.headers.get('content-length', 0))
+
+    # Using tqdm to display progress
+    with open(destination, 'wb') as file, tqdm(desc=destination, total=total_size, unit='B', unit_scale=True,
+                                               unit_divisor=1024) as bar:
+        for data in response.iter_content(chunk_size=1024):
+            size = file.write(data)
+            bar.update(size)
+
+
+class VoiceBambaraTextPreprocessor:
+    def preprocess_batch(self, texts):
+        return [self.preprocess(text) for text in texts]
+
+    def preprocess(self, text: str) -> str:
+        text = text.lower()
+        text = self.expand_number(text)
+        text = self.transliterate_bambara(text)
+
+        return text
+
+    def transliterate_bambara(self, text):
+        """
+        Transliterate Bambara text using a specified mapping of special characters.
+
+        Parameters:
+        - text (str): The original Bambara text.
+
+        Returns:
+        - str: The transliterated text.
+        """
+        bambara_transliteration = {
+            'ɲ': 'ny',
+            'ɛ': 'è',
+            'ɔ': 'o',
+            'ŋ': 'ng',
+            'ɟ': 'j',
+            'ʔ': "'",
+            'ɣ': 'gh',
+            'ʃ': 'sh',
+            'ߒ': 'n',
+            'ߎ': "u",
+        }
+
+        # Perform the transliteration
+        transliterated_text = "".join(bambara_transliteration.get(char, char) for char in text)
+
+        return transliterated_text
+
+    def expand_number(self, text):
+        """
+        Normalize Bambara text for TTS by replacing numerical figures with their word equivalents.
+
+        Args:
+            text (str): The text to be normalized.
+
+        Returns:
+            str: The normalized Bambara text.
+        """
+
+        # A regex pattern to match all numbers
+        number_pattern = re.compile(r'\b\d+\b')
+
+        # Function to replace each number with its Bambara text
+        def replace_number_with_text(match):
+            number = int(match.group())
+            return self.number_to_bambara(number)
+
+        # Replace each number in the text with its Bambara word equivalent
+        normalized_text = number_pattern.sub(replace_number_with_text, text)
+
+        return normalized_text
+
+    def number_to_bambara(self, n):
+        """
+        Convert a number into its textual representation in Bambara using recursion.
+
+        Args:
+            n (int): The number to be converted.
+
+        Returns:
+            str: The number expressed in Bambara text.
+
+        Examples:
+            >>> number_to_bambara(123)
+            'kɛmɛ ni mugan ni saba'
+
+        Notes:
+            This function assumes that 'n' is a non-negative integer.
+        """
+
+        # Bambara numbering rules
+        units = ["", "kɛlɛn", "fila", "saba", "naani", "duuru", "wɔrɔ", "wòlonwula", "sɛɛgin", "kɔnɔntɔn"]
+        tens = ["", "tan", "mugan", "bisaba", "binaani", "biduuru", "biwɔrɔ", "biwòlonfila", "bisɛɛgin", "bikɔnɔntɔn"]
+        hundreds = ["", "kɛmɛ"]
+        thousands = ["", "waga"]
+        millions = ["", "milyɔn"]
+
+        # Handle zero explicitly
+        if n == 0:
+            return ""  # Bambara does not support zero
+
+        if n < 10:
+            return units[n]
+        elif n < 100:
+            return tens[n // 10] + (" ni " + self.number_to_bambara(n % 10) if n % 10 > 0 else "")
+        elif n < 1000:
+            return hundreds[1] + (" " + self.number_to_bambara(n // 100) if n >= 200 else "") + (
+                " ni " + self.number_to_bambara(n % 100) if n % 100 > 0 else "")
+        elif n < 1_000_000:
+            return thousands[1] + " " + self.number_to_bambara(n // 1000) + (
+                " ni " + self.number_to_bambara(n % 1000) if n % 1000 > 0 else "")
+        else:
+            return millions[1] + " " + self.number_to_bambara(n // 1_000_000) + (
+                " ni " + self.number_to_bambara(n % 1_000_000) if n % 1_000_000 > 0 else "")
+
+
+class BambaraTokenizer(VoiceBpeTokenizer):
+    """
+    A tokenizer for the Bambara language that extends the VoiceBpeTokenizer.
+
+    Attributes:
+        preprocessor: An instance of VoiceBambaraTextPreprocessor for text preprocessing.
+        char_limits: A dictionary to hold character limits for languages.
+    """
+
+    def __init__(self, vocab_file: Optional[str] = None):
+        """
+        Initializes the BambaraTokenizer with a given vocabulary file.
+
+        Args:
+            vocab_file: The path to the vocabulary file, defaults to None.
+        """
+        super().__init__(vocab_file)
+        self.preprocessor = VoiceBambaraTextPreprocessor()
+        self.char_limits['bm'] = 200  # Set character limit for the Bambara language
+
+    def preprocess_text(self, txt: str, lang: str) -> str:
+        """
+        Preprocesses the input text based on the language.
+
+        Args:
+            txt: The text to preprocess.
+            lang: The language code of the text.
+
+        Returns:
+            The preprocessed text.
+        """
+        # Delegate preprocessing to the parent class for non-Bambara languages
+        if lang != "bm":
+            return super().preprocess_text(txt, lang)
+
+        # Apply Bambara-specific preprocessing
+        txt = self.preprocessor.preprocess(txt)
+        txt = basic_cleaners(txt)
+        return txt
+
+
+class BambaraXtts(Xtts):
+    """
+    A class for the Bambara language that extends the Xtts class.
+
+    Attributes:
+        tokenizer: An instance of BambaraTokenizer.
+    """
+
+    def __init__(self, config: Coqpit):
+        """
+        Initializes the BambaraXtts with the provided configuration.
+
+        Args:
+            config: An instance of Coqpit containing configuration settings.
+        """
+        super().__init__(config)
+        self.tokenizer = BambaraTokenizer()  # Initialize tokenizer for Bambara
+        self.init_models()
+
+    @classmethod
+    def init_from_config(cls, config: "XttsConfig", **kwargs) -> "BambaraXtts":
+        """
+        Class method to create an instance of BambaraXtts from a configuration object.
+
+        Args:
+            config: An instance of XttsConfig containing configuration settings.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            An instance of BambaraXtts.
+        """
+        return cls(config)
+
+
+class BambaraTTS:
+    """
+    Bambara Text-to-Speech (TTS) class that initializes and uses a TTS model for the Bambara language.
+
+    Attributes:
+        language_code (str): The ISO language code for Bambara.
+        checkpoint_repo_or_dir (str): URL or local path to the model checkpoint directory.
+        local_dir (str): The directory to store downloaded checkpoints.
+        paths (dict): A dictionary of paths to model components.
+        config (XttsConfig): Configuration object for the TTS model.
+        model (BambaraXtts): The TTS model instance.
+    """
+
+    def __init__(self, checkpoint_repo_or_dir: str, local_dir: Optional[str] = None):
+        """
+        Initialize the BambaraTTS instance.
+
+        Args:
+            checkpoint_repo_or_dir: A string that represents either a Hugging Face hub repository
+                                    or a local directory where the TTS model checkpoint is located.
+            local_dir: An optional string representing a local directory path where model checkpoints
+                       will be downloaded. If not specified, a default local directory is used based
+                       on `checkpoint_repo_or_dir`.
+
+        The initialization process involves setting up local directories for model components,
+        ensuring the model checkpoint is available, and loading the model configuration and tokenizer.
+        """
+
+        # Set the language code for Bambara
+        self.language_code = 'bm'
+
+        # Store the checkpoint location and local directory path
+        self.checkpoint_repo_or_dir = checkpoint_repo_or_dir
+        # If no local directory is provided, use the default based on the checkpoint
+        self.local_dir = local_dir if local_dir else self.default_local_dir(checkpoint_repo_or_dir)
+
+        # Initialize the paths for model components
+        self.paths = self.init_paths(self.local_dir)
+
+        # Ensure the model checkpoint is available locally
+        self.ensure_checkpoint_is_downloaded()
+
+        # Load the model configuration from a JSON file
+        self.config = XttsConfig()
+        self.config.load_json(self.paths['config.json'])
+
+        # Initialize the TTS model with the loaded configuration
+        self.model = BambaraXtts(self.config)
+
+        # Set up the tokenizer for the model, using the vocabulary file path
+        self.model.tokenizer = BambaraTokenizer(vocab_file=self.paths['vocab.json'])
+
+        # Load the model checkpoint into the initialized model
+        self.model.load_checkpoint(
+            self.config,
+            vocab_path="fake_vocab.json",
+            # The 'fake_vocab.json' is specified because the base model class might
+            # attempt to override our tokenizer if a vocab file is present
+            checkpoint_dir=self.local_dir,
+            use_deepspeed=torch.cuda.is_available()  # Utilize DeepSpeed if CUDA is available
+        )
+
+        # Move the model to GPU if CUDA is available
+        if torch.cuda.is_available():
+            self.model.cuda()
+
+        self.log_tokenizer()
+
+    def ensure_checkpoint_is_downloaded(self):
+        """
+        Ensures that the model checkpoint is downloaded and available locally.
+        """
+        if os.path.exists(self.checkpoint_repo_or_dir):
+            return
+
+        os.makedirs(self.local_dir, exist_ok=True)
+        self.log("Downloading checkpoint from the hub...")
+
+        for filename, filepath in self.paths.items():
+            if os.path.exists(filepath):
+                self.log(f"File {filepath} already exists. Skipping...")
+                continue
+
+            file_url = hf_hub_url(repo_id=self.checkpoint_repo_or_dir, filename=filename)
+            self.log(f"Downloading {filename} from {file_url}")
+            download_file_with_progress(file_url, filepath)
+
+        self.log("Checkpoint downloaded successfully!")
+
+    def default_local_dir(self, checkpoint_repo_or_dir: str) -> str:
+        """
+        Generates a default local directory path for storing the model checkpoint.
+
+        Args:
+            checkpoint_repo_or_dir: The original checkpoint repository or directory path.
+
+        Returns:
+            The default local directory path.
+        """
+        if os.path.exists(checkpoint_repo_or_dir):
+            return checkpoint_repo_or_dir
+
+        model_path = f"models--{checkpoint_repo_or_dir.replace('/', '--')}"
+        local_dir = os.path.join(os.path.expanduser('~'), ".cache", "huggingface", "hub", model_path)
+        return local_dir.lower()
+
+    @staticmethod
+    def init_paths(local_dir: str) -> dict:
+        """
+        Initializes paths to various model components based on the local directory.
+
+        Args:
+            local_dir: The local directory where model components are stored.
+
+        Returns:
+            A dictionary with keys as component names and values as file paths.
+        """
+        components = ['model.pth', 'config.json', 'vocab.json', 'dvae.pth', 'mel_stats.pth']
+        return {name: os.path.join(local_dir, name) for name in components}
+
+    def text_to_speech(
+            self,
+            text: str,
+            speaker_reference_wav_path: Optional[str] = None,
+            temperature: Optional[float] = 0.1,
+            enable_text_splitting: bool = False
+    ) -> Tuple[int, torch.Tensor]:
+        """
+        Converts text into speech audio.
+
+        Args:
+            text: The input text to be converted into speech.
+            speaker_reference_wav_path: A path to a reference WAV file for the speaker.
+            temperature: The temperature parameter for sampling.
+            enable_text_splitting: Flag to enable or disable text splitting.
+
+        Returns:
+            A tuple containing the sampling rate and the generated audio tensor.
+        """
+        if speaker_reference_wav_path is None:
+            speaker_reference_wav_path = "reference_audios/male_2.wav"
+            self.log("Using default speaker reference reference_audios/male_2.wav.")
+
+        self.log("Computing speaker latents...")
+        gpt_cond_latent, speaker_embedding = self.model.get_conditioning_latents(
+            audio_path=[speaker_reference_wav_path]
+        )
+
+        self.log("Starting inference...")
+        start_time = time.time()
+        out = self.model.inference(
+            text,
+            self.language_code,
+            gpt_cond_latent,
+            speaker_embedding,
+            temperature=temperature,
+            enable_text_splitting=enable_text_splitting
+        )
+        end_time = time.time()
+
+        audio = torch.tensor(out["wav"]).unsqueeze(0)
+        sampling_rate = self.config.model_args.output_sample_rate
+
+        self.log(f"Speech generated in {end_time - start_time:.2f} seconds.")
+
+        return sampling_rate, audio
+
+    def log(self, message: str):
+        """
+        Logs a message to the console with a uniform format.
+
+        Args:
+            message: The message to be logged.
+        """
+        print(f"[BambaraTTS] {message}")
+
+    def log_tokenizer(self):
+        """
+        Logs the tokenizer information.
+        """
+        self.log(f"Tokenizer: {self.model.tokenizer}")
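End-to-end, tts.py exposes BambaraTTS as the single entry point. A hedged usage sketch, assuming the oza75/bambara-tts checkpoint is reachable on the Hub and a reference_audios/male_2.wav file exists for the default voice:

    import torchaudio

    from tts import BambaraTTS

    # First use downloads model.pth, config.json, vocab.json, dvae.pth and
    # mel_stats.pth into the local Hugging Face cache.
    tts = BambaraTTS("oza75/bambara-tts")

    # Numbers are expanded to Bambara words and special characters are
    # transliterated by VoiceBambaraTextPreprocessor before tokenization.
    sampling_rate, audio = tts.text_to_speech(
        "Aw ni ce",      # a short Bambara greeting
        temperature=0.1,
    )

    # audio has shape (1, n_samples), ready for torchaudio.
    torchaudio.save("output.wav", audio, sampling_rate)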