import os
import re
import time
from typing import Optional, Tuple

import numpy as np
import requests
import torch
from coqpit import Coqpit
from huggingface_hub import hf_hub_download, hf_hub_url
from tqdm import tqdm
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer, basic_cleaners
from TTS.tts.models.xtts import Xtts


def download_file_with_progress(
    url: str,
    destination: str,
    chunk_size: int = 64 * 1024,
    timeout: float = 30.0,
):
    """
    Downloads a file from a web URL with a progress bar.

    The payload is streamed into ``<destination>.part`` and only renamed to
    ``destination`` once the transfer has finished, so an interrupted download
    never leaves a truncated file at the final path.

    Args:
        url: The URL to fetch.
        destination: The path the downloaded file is written to.
        chunk_size: Number of bytes requested per streamed chunk.
        timeout: Timeout in seconds passed to ``requests.get``.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
    """
    partial_path = f"{destination}.part"
    completed = False

    # Streaming GET request; the context manager releases the connection.
    with requests.get(url, stream=True, timeout=timeout) as response:
        # Fail loudly instead of saving an HTTP error page as a model file.
        response.raise_for_status()

        # Total size in bytes; zero when the header is missing.
        total_size = int(response.headers.get('content-length', 0))

        try:
            with open(partial_path, 'wb') as file, tqdm(
                desc=destination,
                # An unknown size is reported to tqdm as None rather than 0.
                total=total_size or None,
                unit='B',
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for data in response.iter_content(chunk_size=chunk_size):
                    bar.update(file.write(data))
            completed = True
        finally:
            # Drop the partial file on any failure, including Ctrl-C.
            if not completed and os.path.exists(partial_path):
                os.remove(partial_path)

    os.replace(partial_path, destination)


class VoiceBambaraTextPreprocessor:
    """
    Text normalizer for Bambara: lower-cases the text, spells out numbers and
    transliterates a fixed set of special characters.
    """

    # Matches every standalone run of digits.
    _NUMBER_PATTERN = re.compile(r'\b\d+\b')

    # Character-level transliteration table used by `transliterate_bambara`.
    _TRANSLITERATION = str.maketrans({
        'ɲ': 'ny',
        'ɛ': 'è',
        'ɔ': 'o',
        'ŋ': 'ng',
        'ɟ': 'j',
        'ʔ': "'",
        'ɣ': 'gh',
        'ʃ': 'sh',
        'ߒ': 'n',
        'ߎ': "u",
    })

    # Bambara numbering vocabulary; index 0 is a placeholder in both tables.
    _UNITS = ("", "kɛlɛn", "fila", "saba", "naani", "duuru", "wɔrɔ", "wòlonwula", "sɛɛgin", "kɔnɔntɔn")
    _TENS = ("", "tan", "mugan", "bisaba", "binaani", "biduuru", "biwɔrɔ", "biwòlonfila", "bisɛɛgin",
             "bikɔnɔntɔn")
    _HUNDRED = "kɛmɛ"
    _THOUSAND = "waga"
    _MILLION = "milyɔn"

    def preprocess_batch(self, texts):
        """Apply `preprocess` to every text in `texts` and return the list of results."""
        return [self.preprocess(text) for text in texts]

    def preprocess(self, text: str) -> str:
        """
        Lower-case `text`, expand its numbers into words, then transliterate it.

        Args:
            text: The raw input text.

        Returns:
            The normalized text.
        """
        text = text.lower()
        text = self.expand_number(text)
        text = self.transliterate_bambara(text)
        return text

    def transliterate_bambara(self, text):
        """
        Transliterate Bambara text using a fixed mapping of special characters.

        Characters that are not in the mapping are left unchanged.

        Parameters:
        - text (str): The original Bambara text.

        Returns:
        - str: The transliterated text.
        """
        return text.translate(self._TRANSLITERATION)

    def expand_number(self, text):
        """
        Normalize Bambara text for TTS by replacing numerical figures with
        their word equivalents.

        Args:
            text (str): The text to be normalized.

        Returns:
            str: The normalized Bambara text.
        """

        def replace_number_with_text(match):
            # Convert the matched digit run into its Bambara wording.
            return self.number_to_bambara(int(match.group()))

        return self._NUMBER_PATTERN.sub(replace_number_with_text, text)

    def number_to_bambara(self, n):
        """
        Convert a number into its textual representation in Bambara using recursion.

        Args:
            n (int): The non-negative number to be converted.

        Returns:
            str: The number expressed in Bambara text. Zero yields an empty string.

        Raises:
            ValueError: If `n` is negative.

        Examples:
            >>> VoiceBambaraTextPreprocessor().number_to_bambara(123)
            'kɛmɛ ni mugan ni saba'
        """
        if n < 0:
            # A negative index would silently pick a wrong word from the table.
            raise ValueError(f"number_to_bambara expects a non-negative integer, got {n}")

        # Zero is deliberately rendered as nothing.
        if n == 0:
            return ""
        if n < 10:
            return self._UNITS[n]

        # Split the number into a leading part (`head`) and a remainder that
        # is joined to it with " ni ".
        if n < 100:
            head = self._TENS[n // 10]
            remainder = n % 10
        elif n < 1000:
            head = self._HUNDRED
            # The multiplier is only spoken from two hundred upwards.
            if n >= 200:
                head += " " + self.number_to_bambara(n // 100)
            remainder = n % 100
        elif n < 1_000_000:
            head = self._THOUSAND + " " + self.number_to_bambara(n // 1000)
            remainder = n % 1000
        else:
            head = self._MILLION + " " + self.number_to_bambara(n // 1_000_000)
            remainder = n % 1_000_000

        if remainder > 0:
            return head + " ni " + self.number_to_bambara(remainder)
        return head


class BambaraTokenizer(VoiceBpeTokenizer):
    """
    A tokenizer for the Bambara language that extends the VoiceBpeTokenizer.

    Attributes:
        preprocessor: An instance of VoiceBambaraTextPreprocessor for text preprocessing.
        char_limits: A dictionary to hold character limits for languages.
    """

    def __init__(self, vocab_file: Optional[str] = None):
        """
        Initializes the BambaraTokenizer with a given vocabulary file.

        Args:
            vocab_file: The path to the vocabulary file, defaults to None.
        """
        super().__init__(vocab_file)
        self.preprocessor = VoiceBambaraTextPreprocessor()
        self.char_limits['bm'] = 200  # Set character limit for Bambara language

    def preprocess_text(self, txt: str, lang: str) -> str:
        """
        Preprocesses the input text based on the language.

        Args:
            txt: The text to preprocess.
            lang: The language code of the text.

        Returns:
            The preprocessed text.
        """
        # Delegate preprocessing to the parent class for non-Bambara languages
        if lang != "bm":
            return super().preprocess_text(txt, lang)

        # Apply Bambara-specific preprocessing
        txt = self.preprocessor.preprocess(txt)
        txt = basic_cleaners(txt)
        return txt


class BambaraXtts(Xtts):
    """
    A class for the Bambara language that extends the Xtts class.

    Attributes:
        tokenizer: An instance of BambaraTokenizer.
    """

    def __init__(self, config: Coqpit):
        """
        Initializes the BambaraXtts with the provided configuration.

        Args:
            config: An instance of Coqpit containing configuration settings.
        """
        super().__init__(config)
        self.tokenizer = BambaraTokenizer()  # Initialize tokenizer for Bambara
        self.init_models()

    @classmethod
    def init_from_config(cls, config: "XttsConfig", **kwargs) -> "BambaraXtts":
        """
        Class method to create an instance of BambaraXtts from a configuration object.

        Args:
            config: An instance of XttsConfig containing configuration settings.
            **kwargs: Additional keyword arguments (accepted but not used).

        Returns:
            An instance of BambaraXtts.
        """
        return cls(config)


class BambaraTTS:
    """
    Bambara Text-to-Speech (TTS) class that initializes and uses a TTS model for the Bambara language.

    Attributes:
        language_code (str): The ISO language code for Bambara.
        checkpoint_repo_or_dir (str): URL or local path to the model checkpoint directory.
        local_dir (str): The directory to store downloaded checkpoints.
        paths (dict): A dictionary of paths to model components.
        config (XttsConfig): Configuration object for the TTS model.
        model (BambaraXtts): The TTS model instance.
    """

    # Speaker reference used when the caller does not provide one
    # (relative to the current working directory).
    DEFAULT_SPEAKER_REFERENCE = "./audios/male_2.wav"

    def __init__(self, checkpoint_repo_or_dir: str, local_dir: Optional[str] = None):
        """
        Initialize the BambaraTTS instance.

        Args:
            checkpoint_repo_or_dir: A string that represents either a Hugging Face hub repository
                or a local directory where the TTS model checkpoint is located.
            local_dir: An optional string representing a local directory path where model
                checkpoints will be downloaded. If not specified, a default local directory
                is used based on `checkpoint_repo_or_dir`.

        The initialization process involves setting up local directories for model components,
        ensuring the model checkpoint is available, and loading the model configuration and
        tokenizer.
        """
        # Set the language code for Bambara
        self.language_code = 'bm'

        # Store the checkpoint location and local directory path
        self.checkpoint_repo_or_dir = checkpoint_repo_or_dir
        # If no local directory is provided, use the default based on the checkpoint
        self.local_dir = local_dir if local_dir else self.default_local_dir(checkpoint_repo_or_dir)

        # Initialize the paths for model components
        self.paths = self.init_paths(self.local_dir)

        # Ensure the model checkpoint is available locally
        self.ensure_checkpoint_is_downloaded()

        # Load the model configuration from a JSON file
        self.config = XttsConfig()
        self.config.load_json(self.paths['config.json'])

        # Initialize the TTS model with the loaded configuration
        self.model = BambaraXtts(self.config)

        # Set up the tokenizer for the model, using the vocabulary file path
        self.model.tokenizer = BambaraTokenizer(vocab_file=self.paths['vocab.json'])

        # Load the model checkpoint into the initialized model
        self.model.load_checkpoint(
            self.config,
            # 'fake_vocab.json' is specified because the base model class might
            # attempt to override our tokenizer if a vocab file is present
            vocab_path="fake_vocab.json",
            checkpoint_dir=self.local_dir,
            # DeepSpeed stays disabled: enabling it fails on Hugging Face Spaces
            use_deepspeed=False,
        )

        # Move the model to GPU if CUDA is available
        if torch.cuda.is_available():
            self.model.cuda()

        self.log_tokenizer()

    def ensure_checkpoint_is_downloaded(self):
        """
        Ensures that the model checkpoint is downloaded and available locally.

        Nothing is downloaded when `checkpoint_repo_or_dir` is an existing local
        path. Otherwise every component listed in `self.paths` that is not yet
        on disk is fetched from the hub repository.
        """
        if os.path.exists(self.checkpoint_repo_or_dir):
            return

        os.makedirs(self.local_dir, exist_ok=True)
        self.log("Downloading checkpoint from the hub...")

        for filename, filepath in self.paths.items():
            if os.path.exists(filepath):
                self.log(f"File {filepath} already exists. Skipping...")
                continue

            file_url = hf_hub_url(repo_id=self.checkpoint_repo_or_dir, filename=filename)
            self.log(f"Downloading {filename} from {file_url}")
            download_file_with_progress(file_url, filepath)

        self.log("Checkpoint downloaded successfully!")

    def default_local_dir(self, checkpoint_repo_or_dir: str) -> str:
        """
        Generates a default local directory path for storing the model checkpoint.

        Args:
            checkpoint_repo_or_dir: The original checkpoint repository or directory path.

        Returns:
            `checkpoint_repo_or_dir` itself when it exists locally, otherwise
            ``~/bambara_tts/models--<repo>`` with the repo part lower-cased.
        """
        if os.path.exists(checkpoint_repo_or_dir):
            return checkpoint_repo_or_dir

        # Only the repo-derived component is lower-cased. Lower-casing the whole
        # path would also rewrite the home directory, which points somewhere
        # else on case-sensitive filesystems.
        model_path = f"models--{checkpoint_repo_or_dir.replace('/', '--')}".lower()
        return os.path.join(os.path.expanduser('~'), "bambara_tts", model_path)

    @staticmethod
    def init_paths(local_dir: str) -> dict:
        """
        Initializes paths to various model components based on the local directory.

        Args:
            local_dir: The local directory where model components are stored.

        Returns:
            A dictionary with keys as component names and values as file paths.
        """
        components = ['model.pth', 'config.json', 'vocab.json', 'dvae.pth', 'mel_stats.pth']
        return {name: os.path.join(local_dir, name) for name in components}

    def text_to_speech(
            self,
            text: str,
            speaker_reference_wav_path: Optional[str] = None,
            temperature: Optional[float] = 0.1,
            enable_text_splitting: bool = False
    ) -> Tuple[int, torch.Tensor]:
        """
        Converts text into speech audio.

        Args:
            text: The input text to be converted into speech.
            speaker_reference_wav_path: A path to a reference WAV file for the speaker.
                Defaults to `DEFAULT_SPEAKER_REFERENCE` when None.
            temperature: The temperature parameter for sampling.
            enable_text_splitting: Flag to enable or disable text splitting.

        Returns:
            A tuple containing the sampling rate and the generated audio tensor.
        """
        if speaker_reference_wav_path is None:
            speaker_reference_wav_path = self.DEFAULT_SPEAKER_REFERENCE
            self.log(f"Using default speaker reference {self.DEFAULT_SPEAKER_REFERENCE}.")

        self.log("Computing speaker latents...")
        gpt_cond_latent, speaker_embedding = self.model.get_conditioning_latents(
            audio_path=[speaker_reference_wav_path]
        )

        self.log("Starting inference...")
        # perf_counter is monotonic, so the elapsed time cannot go negative.
        start_time = time.perf_counter()
        out = self.model.inference(
            text,
            self.language_code,
            gpt_cond_latent,
            speaker_embedding,
            temperature=temperature,
            enable_text_splitting=enable_text_splitting
        )
        elapsed = time.perf_counter() - start_time

        # Add a leading dimension to the waveform and keep it on the CPU.
        audio = torch.tensor(out["wav"]).unsqueeze(0).cpu()
        sampling_rate = int(self.config.model_args.output_sample_rate)

        self.log(f"Speech generated in {elapsed:.2f} seconds.")

        return sampling_rate, audio

    def log(self, message: str):
        """
        Logs a message to the console with a uniform format.

        Args:
            message: The message to be logged.
        """
        print(f"[BambaraTTS] {message}")

    def log_tokenizer(self):
        """
        Logs the tokenizer information.
        """
        self.log(f"Tokenizer: {self.model.tokenizer}")