import os
import re
import time
from typing import Optional, Tuple

import requests
import torch
from coqpit import Coqpit
from huggingface_hub import hf_hub_url
from tqdm import tqdm

from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer, basic_cleaners


def download_file_with_progress(url: str, destination: str, token: Optional[str] = None):
    """
    Downloads a file from a web URL with a progress bar. Supports a Hugging Face
    API token for gated models.

    :param url: The URL to download from.
    :param destination: The destination file path for the downloaded file.
    :param token: Hugging Face API token (optional). If not provided, the
        HF_TOKEN environment variable is used.
    """
    # Use the token passed in, or fall back to the environment variable
    if token is None:
        token = os.getenv("HF_TOKEN")

    # Attach the token as a bearer header when available
    headers = {}
    if token:
        headers['Authorization'] = f'Bearer {token}'

    # Streaming GET request so large checkpoints are not held in memory
    response = requests.get(url, stream=True, headers=headers)

    # Total size in bytes, zero if the server does not report it
    total_size = int(response.headers.get('content-length', 0))

    # Write to disk while updating a tqdm progress bar
    with open(destination, 'wb') as file, tqdm(
        desc=destination, total=total_size, unit='B', unit_scale=True, unit_divisor=1024
    ) as bar:
        for data in response.iter_content(chunk_size=1024):
            size = file.write(data)
            bar.update(size)
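
# Example usage (a sketch; the repo ID and filename below are placeholders,
# not a real checkpoint):
#
#     url = hf_hub_url(repo_id="user/some-model", filename="model.pth")
#     download_file_with_progress(url, "./model.pth")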


class VoiceBambaraTextPreprocessor:
    def preprocess_batch(self, texts):
        return [self.preprocess(text) for text in texts]

    def preprocess(self, text: str) -> str:
        text = text.lower()
        text = self.expand_number(text)
        text = self.transliterate_bambara(text)
        return text

    def transliterate_bambara(self, text):
        """
        Transliterate Bambara text using a specified mapping of special characters.

        Parameters:
        - text (str): The original Bambara text.

        Returns:
        - str: The transliterated text.
        """
        bambara_transliteration = {
            'ɲ': 'ny',
            'ɛ': 'è',
            'ɔ': 'o',
            'ŋ': 'ng',
            'ɟ': 'j',
            'ʔ': "'",
            'ɣ': 'gh',
            'ʃ': 'sh',
            'ߒ': 'n',
            'ߎ': 'u',
        }
        # Replace each mapped character; leave everything else unchanged
        return "".join(bambara_transliteration.get(char, char) for char in text)

    def expand_number(self, text):
        """
        Normalize Bambara text for TTS by replacing numerical figures with their
        word equivalents.

        Args:
            text (str): The text to be normalized.

        Returns:
            str: The normalized Bambara text.
        """
        # A regex pattern to match standalone integers
        number_pattern = re.compile(r'\b\d+\b')

        # Replace each matched number with its Bambara spelling
        def replace_number_with_text(match):
            number = int(match.group())
            return self.number_to_bambara(number)

        return number_pattern.sub(replace_number_with_text, text)

    def number_to_bambara(self, n):
        """
        Convert a number into its textual representation in Bambara using recursion.

        Args:
            n (int): The number to be converted.

        Returns:
            str: The number expressed in Bambara text.

        Examples:
            >>> number_to_bambara(123)
            'kɛmɛ ni mugan ni saba'

        Notes:
            This function assumes that 'n' is a non-negative integer.
        """
        # Bambara numbering vocabulary
        units = ["", "kɛlɛn", "fila", "saba", "naani", "duuru", "wɔrɔ", "wòlonwula", "sɛɛgin", "kɔnɔntɔn"]
        tens = ["", "tan", "mugan", "bisaba", "binaani", "biduuru", "biwɔrɔ", "biwòlonfila", "bisɛɛgin", "bikɔnɔntɔn"]
        hundreds = ["", "kɛmɛ"]
        thousands = ["", "waga"]
        millions = ["", "milyɔn"]

        # Handle zero explicitly: Bambara has no dedicated word for it here,
        # so return an empty string
        if n == 0:
            return ""
        if n < 10:
            return units[n]
        elif n < 100:
            return tens[n // 10] + (" ni " + self.number_to_bambara(n % 10) if n % 10 > 0 else "")
        elif n < 1000:
            return hundreds[1] + (" " + self.number_to_bambara(n // 100) if n >= 200 else "") + (
                " ni " + self.number_to_bambara(n % 100) if n % 100 > 0 else "")
        elif n < 1_000_000:
            return thousands[1] + " " + self.number_to_bambara(n // 1000) + (
                " ni " + self.number_to_bambara(n % 1000) if n % 1000 > 0 else "")
        else:
            return millions[1] + " " + self.number_to_bambara(n // 1_000_000) + (
                " ni " + self.number_to_bambara(n % 1_000_000) if n % 1_000_000 > 0 else "")


class BambaraTokenizer(VoiceBpeTokenizer):
    """
    A tokenizer for the Bambara language that extends the VoiceBpeTokenizer.

    Attributes:
        preprocessor: An instance of VoiceBambaraTextPreprocessor for text preprocessing.
        char_limits: A dictionary of per-language character limits.
    """

    def __init__(self, vocab_file: Optional[str] = None):
        """
        Initializes the BambaraTokenizer with a given vocabulary file.

        Args:
            vocab_file: The path to the vocabulary file, defaults to None.
        """
        super().__init__(vocab_file)
        self.preprocessor = VoiceBambaraTextPreprocessor()
        self.char_limits['bm'] = 200  # Character limit for Bambara

    def preprocess_text(self, txt: str, lang: str) -> str:
        """
        Preprocesses the input text based on the language.

        Args:
            txt: The text to preprocess.
            lang: The language code of the text.

        Returns:
            The preprocessed text.
        """
        # Delegate preprocessing to the parent class for non-Bambara languages
        if lang != "bm":
            return super().preprocess_text(txt, lang)

        # Apply Bambara-specific preprocessing
        txt = self.preprocessor.preprocess(txt)
        txt = basic_cleaners(txt)
        return txt
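
# Example (a sketch): the Bambara branch of preprocess_text needs no vocab file,
# so the tokenizer can be exercised standalone before any model is loaded:
#
#     tokenizer = BambaraTokenizer()
#     tokenizer.preprocess_text("I ye 123 sɔrɔ", "bm")
#     # numbers are expanded and special characters transliterated before
#     # basic_cleaners normalizes case and whitespace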


class BambaraXtts(Xtts):
    """
    A class for the Bambara language that extends the Xtts class.

    Attributes:
        tokenizer: An instance of BambaraTokenizer.
    """

    def __init__(self, config: Coqpit):
        """
        Initializes the BambaraXtts with the provided configuration.

        Args:
            config: An instance of Coqpit containing configuration settings.
        """
        super().__init__(config)
        self.tokenizer = BambaraTokenizer()  # Initialize tokenizer for Bambara
        self.init_models()

    @classmethod
    def init_from_config(cls, config: "XttsConfig", **kwargs) -> "BambaraXtts":
        """
        Creates an instance of BambaraXtts from a configuration object.

        Args:
            config: An instance of XttsConfig containing configuration settings.
            **kwargs: Additional keyword arguments.

        Returns:
            An instance of BambaraXtts.
        """
        return cls(config)


class BambaraTTS:
    """
    Bambara Text-to-Speech (TTS) class that initializes and uses a TTS model for the Bambara language.

    Attributes:
        language_code (str): The ISO language code for Bambara.
        checkpoint_repo_or_dir (str): Hugging Face repo ID or local path to the model checkpoint.
        local_dir (str): The directory in which downloaded checkpoints are stored.
        paths (dict): A dictionary of paths to model components.
        config (XttsConfig): Configuration object for the TTS model.
        model (BambaraXtts): The TTS model instance.
    """

    def __init__(self, checkpoint_repo_or_dir: str, local_dir: Optional[str] = None):
        """
        Initialize the BambaraTTS instance.

        Args:
            checkpoint_repo_or_dir: Either a Hugging Face Hub repository ID or a local
                directory where the TTS model checkpoint is located.
            local_dir: An optional local directory path into which model checkpoints
                are downloaded. If not specified, a default directory derived from
                `checkpoint_repo_or_dir` is used.

        Initialization sets up local directories for model components, ensures the
        model checkpoint is available, and loads the model configuration and tokenizer.
        """
        # Set the language code for Bambara
        self.language_code = 'bm'

        # Store the checkpoint location and local directory path
        self.checkpoint_repo_or_dir = checkpoint_repo_or_dir

        # If no local directory is provided, derive a default from the checkpoint
        self.local_dir = local_dir if local_dir else self.default_local_dir(checkpoint_repo_or_dir)

        # Initialize the paths for model components
        self.paths = self.init_paths(self.local_dir)

        # Ensure the model checkpoint is available locally
        self.ensure_checkpoint_is_downloaded()

        # Load the model configuration from its JSON file
        self.config = XttsConfig()
        self.config.load_json(self.paths['config.json'])

        # Initialize the TTS model with the loaded configuration
        self.model = BambaraXtts(self.config)

        # Set up the tokenizer for the model, using the vocabulary file path
        self.model.tokenizer = BambaraTokenizer(vocab_file=self.paths['vocab.json'])

        # Load the model checkpoint into the initialized model
        self.model.load_checkpoint(
            self.config,
            # 'fake_vocab.json' is passed because the base model class would
            # otherwise rebuild its own tokenizer from a real vocab file,
            # overriding ours
            vocab_path="fake_vocab.json",
            checkpoint_dir=self.local_dir,
            # use_deepspeed=torch.cuda.is_available()  # Enable DeepSpeed when CUDA is available
            use_deepspeed=False  # Disabled because DeepSpeed fails on Hugging Face Spaces
        )

        # Move the model to GPU if CUDA is available
        if torch.cuda.is_available():
            self.model.cuda()

        self.log_tokenizer()

    def ensure_checkpoint_is_downloaded(self):
        """
        Ensures that the model checkpoint is downloaded and available locally.
        """
        # An existing local path means the checkpoint is already on disk
        if os.path.exists(self.checkpoint_repo_or_dir):
            return

        os.makedirs(self.local_dir, exist_ok=True)
        self.log("Downloading checkpoint from the hub...")

        for filename, filepath in self.paths.items():
            if os.path.exists(filepath):
                self.log(f"File {filepath} already exists. Skipping...")
                continue
            file_url = hf_hub_url(repo_id=self.checkpoint_repo_or_dir, filename=filename)
            self.log(f"Downloading {filename} from {file_url}")
            download_file_with_progress(file_url, filepath)

        self.log("Checkpoint downloaded successfully!")

    def default_local_dir(self, checkpoint_repo_or_dir: str) -> str:
        """
        Generates a default local directory path for storing the model checkpoint.

        Args:
            checkpoint_repo_or_dir: The original checkpoint repository or directory path.

        Returns:
            The default local directory path.
        """
        if os.path.exists(checkpoint_repo_or_dir):
            return checkpoint_repo_or_dir
        model_path = f"models--{checkpoint_repo_or_dir.replace('/', '--')}"
        local_dir = os.path.join(os.path.expanduser('~'), "bambara_tts", model_path)
        return local_dir.lower()

    @staticmethod
    def init_paths(local_dir: str) -> dict:
        """
        Initializes paths to the model components inside the local directory.

        Args:
            local_dir: The local directory where model components are stored.

        Returns:
            A dictionary mapping component file names to their full paths.
        """
        components = ['model.pth', 'config.json', 'vocab.json', 'dvae.pth', 'mel_stats.pth']
        return {name: os.path.join(local_dir, name) for name in components}

    def text_to_speech(
        self,
        text: str,
        speaker_reference_wav_path: Optional[str] = None,
        temperature: float = 0.1,
        enable_text_splitting: bool = False
    ) -> Tuple[int, torch.Tensor]:
        """
        Converts text into speech audio.

        Args:
            text: The input text to be converted into speech.
            speaker_reference_wav_path: A path to a reference WAV file for the speaker.
            temperature: The sampling temperature for inference.
            enable_text_splitting: Flag to enable or disable text splitting.

        Returns:
            A tuple containing the sampling rate and the generated audio tensor.
        """
        if speaker_reference_wav_path is None:
            speaker_reference_wav_path = "./audios/male_2.wav"
            self.log("Using default speaker reference ./audios/male_2.wav.")

        self.log("Computing speaker latents...")
        gpt_cond_latent, speaker_embedding = self.model.get_conditioning_latents(
            audio_path=[speaker_reference_wav_path]
        )

        self.log("Starting inference...")
        start_time = time.time()
        out = self.model.inference(
            text,
            self.language_code,
            gpt_cond_latent,
            speaker_embedding,
            temperature=temperature,
            enable_text_splitting=enable_text_splitting
        )
        end_time = time.time()

        audio = torch.tensor(out["wav"]).unsqueeze(0).cpu()
        sampling_rate = int(self.config.model_args.output_sample_rate)

        self.log(f"Speech generated in {end_time - start_time:.2f} seconds.")
        return sampling_rate, audio

    def log(self, message: str):
        """
        Logs a message to the console with a uniform format.

        Args:
            message: The message to be logged.
        """
        print(f"[BambaraTTS] {message}")

    def log_tokenizer(self):
        """
        Logs the tokenizer information.
        """
        self.log(f"Tokenizer: {self.model.tokenizer}")