bambara-mt / tts.py
oza75's picture
add token for downloading gated model.
1d280b8
import os
import re
import time
import numpy as np
import requests
import torch
from typing import Optional, Tuple
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer, basic_cleaners
from coqpit import Coqpit
from huggingface_hub import hf_hub_download, hf_hub_url
from tqdm import tqdm
def download_file_with_progress(
    url: str,
    destination: str,
    token: Optional[str] = None,
    chunk_size: int = 1024 * 1024,
    timeout: float = 60.0,
):
    """
    Download a file from a web URL with a progress bar. Supports a Hugging Face API token for gated models.

    The payload is streamed into ``<destination>.part`` and only moved to ``destination`` once the
    whole body has been received. An interrupted or failed download therefore never leaves a
    truncated file at ``destination``, which would otherwise be mistaken for a complete one.

    :param url: The URL to download from.
    :param destination: The destination file path to save the downloaded file.
    :param token: Hugging Face API token (optional). If not provided, the ``HF_TOKEN`` environment
        variable is used, falling back to ``HF_API_TOKEN``.
    :param chunk_size: Number of bytes read per streamed chunk.
    :param timeout: Connect/read timeout in seconds passed to ``requests``. This is not a cap on the
        total transfer time.
    :raises requests.HTTPError: If the server answers with an error status (e.g. 401/403 for a gated
        model without a valid token, 404 for a missing file).
    """
    # Use the token passed or fetch it from the environment.
    if token is None:
        token = os.getenv("HF_TOKEN") or os.getenv("HF_API_TOKEN")
    headers = {'Authorization': f'Bearer {token}'} if token else {}

    partial_path = f"{destination}.part"
    with requests.get(url, stream=True, headers=headers, timeout=timeout) as response:
        # Without this, an error page (HTML/JSON) would be saved as if it were the model file.
        response.raise_for_status()
        # Total size in bytes; 0 when the server omits Content-Length (tqdm then shows no total).
        total_size = int(response.headers.get('content-length', 0))
        try:
            with open(partial_path, 'wb') as file, tqdm(desc=destination, total=total_size, unit='B',
                                                        unit_scale=True, unit_divisor=1024) as bar:
                for data in response.iter_content(chunk_size=chunk_size):
                    bar.update(file.write(data))
        except BaseException:
            # Best-effort cleanup of the incomplete download (also on Ctrl+C), then propagate.
            try:
                os.remove(partial_path)
            except OSError:
                pass
            raise
    # Publish the finished file in one step so readers never see a partial file.
    os.replace(partial_path, destination)
class VoiceBambaraTextPreprocessor:
    """
    Normalize Bambara text before tokenization.

    ``preprocess`` lowercases the text, spells out standalone integers as Bambara words,
    and finally transliterates special Bambara / N'Ko characters.
    """

    # Number-word tables, built once at class level instead of on every recursive call.
    _UNITS = ("", "kɛlɛn", "fila", "saba", "naani", "duuru", "wɔrɔ", "wòlonwula", "sɛɛgin", "kɔnɔntɔn")
    _TENS = ("", "tan", "mugan", "bisaba", "binaani", "biduuru", "biwɔrɔ", "biwòlonfila", "bisɛɛgin", "bikɔnɔntɔn")
    _HUNDRED = "kɛmɛ"
    _THOUSAND = "waga"
    _MILLION = "milyɔn"

    # Standalone runs of digits (word boundary on both sides).
    _NUMBER_PATTERN = re.compile(r'\b\d+\b')

    # Character-level transliteration table for str.translate (every key is a single character).
    _TRANSLITERATION = str.maketrans({
        'ɲ': 'ny',
        'ɛ': 'è',
        'ɔ': 'o',
        'ŋ': 'ng',
        'ɟ': 'j',
        'ʔ': "'",
        'ɣ': 'gh',
        'ʃ': 'sh',
        'ߒ': 'n',
        'ߎ': "u",
    })

    def preprocess_batch(self, texts) -> list:
        """Apply ``preprocess`` to every text in ``texts`` and return the results as a list."""
        return [self.preprocess(text) for text in texts]

    def preprocess(self, text: str) -> str:
        """
        Normalize one Bambara text.

        Numbers are expanded before transliteration so that the special characters in
        the generated number words are transliterated too.
        """
        text = text.lower()
        text = self.expand_number(text)
        text = self.transliterate_bambara(text)
        return text

    def transliterate_bambara(self, text: str) -> str:
        """
        Transliterate Bambara text using a fixed mapping of special characters.

        Parameters:
        - text (str): The original Bambara text.
        Returns:
        - str: The transliterated text. Characters without a mapping are kept as-is.
        """
        return text.translate(self._TRANSLITERATION)

    def expand_number(self, text: str) -> str:
        """
        Normalize Bambara text for TTS by replacing numerical figures with their word equivalents.

        Args:
            text (str): The text to be normalized.
        Returns:
            str: The text with every standalone digit run replaced by its Bambara words.
        """
        return self._NUMBER_PATTERN.sub(lambda match: self.number_to_bambara(int(match.group())), text)

    def _join_remainder(self, head: str, remainder: int) -> str:
        """Append ``' ni <remainder words>'`` to ``head`` when ``remainder`` is non-zero."""
        return head + (" ni " + self.number_to_bambara(remainder) if remainder > 0 else "")

    def number_to_bambara(self, n: int) -> str:
        """
        Convert a non-negative integer into its textual representation in Bambara using recursion.

        Args:
            n (int): The number to be converted.
        Returns:
            str: The number expressed in Bambara text. Zero yields an empty string.
        Raises:
            ValueError: If ``n`` is negative.
        Examples:
            >>> VoiceBambaraTextPreprocessor().number_to_bambara(123)
            'kɛmɛ ni mugan ni saba'
        """
        if n < 0:
            # Negative values used to index the tables from the end and return a wrong word.
            raise ValueError(f"number_to_bambara expects a non-negative integer, got {n}")
        if n == 0:
            return ""  # bambara does not support zero
        if n < 10:
            return self._UNITS[n]
        if n < 100:
            return self._join_remainder(self._TENS[n // 10], n % 10)
        if n < 1000:
            # 100-199 is just "kɛmɛ"; from 200 on, the multiplier follows ("kɛmɛ fila" = 200).
            head = self._HUNDRED + (" " + self.number_to_bambara(n // 100) if n >= 200 else "")
            return self._join_remainder(head, n % 100)
        if n < 1_000_000:
            return self._join_remainder(self._THOUSAND + " " + self.number_to_bambara(n // 1000), n % 1000)
        return self._join_remainder(self._MILLION + " " + self.number_to_bambara(n // 1_000_000), n % 1_000_000)
class BambaraTokenizer(VoiceBpeTokenizer):
    """
    VoiceBpeTokenizer subclass that knows how to prepare Bambara ('bm') text.

    Attributes:
        preprocessor: VoiceBambaraTextPreprocessor applied to Bambara input.
        char_limits: Per-language character limits set up by the parent constructor;
            this class adds a 200-character limit for 'bm'.
    """

    def __init__(self, vocab_file: Optional[str] = None):
        """
        Build the tokenizer.

        Args:
            vocab_file: Optional path to the vocabulary file, forwarded to the parent.
        """
        super().__init__(vocab_file)
        self.preprocessor = VoiceBambaraTextPreprocessor()
        # Register the Bambara limit next to the limits the parent already defines.
        self.char_limits['bm'] = 200

    def preprocess_text(self, txt: str, lang: str) -> str:
        """
        Prepare ``txt`` for tokenization.

        Bambara text goes through the Bambara preprocessor followed by ``basic_cleaners``;
        every other language is handled by the parent implementation.

        Args:
            txt: Raw input text.
            lang: Language code of ``txt``.
        Returns:
            The cleaned text.
        """
        if lang == "bm":
            normalized = self.preprocessor.preprocess(txt)
            return basic_cleaners(normalized)
        return super().preprocess_text(txt, lang)
class BambaraXtts(Xtts):
    """
    Xtts model wired to the Bambara tokenizer.

    Attributes:
        tokenizer: BambaraTokenizer that replaces whatever tokenizer the parent installed.
    """

    def __init__(self, config: Coqpit):
        """
        Build the model from ``config`` and switch it to the Bambara tokenizer.

        Args:
            config: Coqpit configuration forwarded to ``Xtts``.
        """
        super().__init__(config)
        # Replace the tokenizer, then rebuild the sub-models via init_models(), presumably so
        # they are created with the Bambara tokenizer in place.
        # NOTE(review): if Xtts.__init__ already calls init_models(), the sub-models are built
        # twice here; confirm against the installed TTS version before removing either call.
        self.tokenizer = BambaraTokenizer()
        self.init_models()

    @classmethod
    def init_from_config(cls, config: "XttsConfig", **kwargs) -> "BambaraXtts":
        """
        Alternate constructor taking an XttsConfig.

        Args:
            config: Model configuration.
            **kwargs: Accepted for interface compatibility; not used.
        Returns:
            A new BambaraXtts instance.
        """
        model = cls(config)
        return model
class BambaraTTS:
    """
    Bambara Text-to-Speech (TTS) front end that loads a BambaraXtts model and synthesizes speech.

    Attributes:
        language_code (str): Language code passed to the model ('bm').
        checkpoint_repo_or_dir (str): Hugging Face Hub repo id, or local path to the checkpoint directory.
        local_dir (str): Directory that holds (or receives) the checkpoint files.
        paths (dict): Maps each checkpoint file name to its path inside ``local_dir``.
        config (XttsConfig): Configuration loaded from ``config.json``.
        model (BambaraXtts): The loaded TTS model.
    """

    def __init__(self, checkpoint_repo_or_dir: str, local_dir: Optional[str] = None):
        """
        Initialize the BambaraTTS instance.

        Args:
            checkpoint_repo_or_dir: Either a Hugging Face Hub repository id or a local directory
                containing the TTS model checkpoint.
            local_dir: Optional directory where checkpoint files are downloaded. Defaults to the
                value returned by ``default_local_dir(checkpoint_repo_or_dir)``.

        Sets up the component paths, downloads missing checkpoint files, loads the configuration,
        builds the model with the Bambara tokenizer and loads the checkpoint weights.
        """
        self.language_code = 'bm'
        self.checkpoint_repo_or_dir = checkpoint_repo_or_dir
        self.local_dir = local_dir if local_dir else self.default_local_dir(checkpoint_repo_or_dir)
        self.paths = self.init_paths(self.local_dir)

        self.ensure_checkpoint_is_downloaded()

        self.config = XttsConfig()
        self.config.load_json(self.paths['config.json'])

        self.model = BambaraXtts(self.config)
        self.model.tokenizer = BambaraTokenizer(vocab_file=self.paths['vocab.json'])
        self.model.load_checkpoint(
            self.config,
            # 'fake_vocab.json' is passed because the base model class might override our
            # tokenizer if a vocab file is present.
            vocab_path="fake_vocab.json",
            checkpoint_dir=self.local_dir,
            # DeepSpeed stays disabled: enabling it makes loading fail on Hugging Face Spaces.
            use_deepspeed=False,
        )

        if torch.cuda.is_available():
            self.model.cuda()

        self.log_tokenizer()

    def ensure_checkpoint_is_downloaded(self):
        """
        Make sure every checkpoint file listed in ``self.paths`` exists locally.

        Nothing is done when ``checkpoint_repo_or_dir`` is an existing local path. Otherwise each
        missing file is downloaded from the Hub repository; files already present are skipped.
        """
        if os.path.exists(self.checkpoint_repo_or_dir):
            return
        os.makedirs(self.local_dir, exist_ok=True)
        self.log("Downloading checkpoint from the hub...")
        for filename, filepath in self.paths.items():
            if os.path.exists(filepath):
                self.log(f"File {filepath} already exists. Skipping...")
                continue
            file_url = hf_hub_url(repo_id=self.checkpoint_repo_or_dir, filename=filename)
            self.log(f"Downloading {filename} from {file_url}")
            download_file_with_progress(file_url, filepath)
        self.log("Checkpoint downloaded successfully!")

    def default_local_dir(self, checkpoint_repo_or_dir: str) -> str:
        """
        Compute the default directory for storing the model checkpoint.

        Args:
            checkpoint_repo_or_dir: The checkpoint repository id or local directory path.
        Returns:
            ``checkpoint_repo_or_dir`` itself if it exists locally, otherwise
            ``~/bambara_tts/models--<repo id with '/' replaced by '--', lowercased>``.
        """
        if os.path.exists(checkpoint_repo_or_dir):
            return checkpoint_repo_or_dir
        # Lowercase only the repo-derived component. Lowercasing the whole joined path would also
        # lowercase the user's home directory (e.g. /home/Alice -> /home/alice), which points at a
        # different, possibly non-existent location on case-sensitive filesystems.
        model_path = f"models--{checkpoint_repo_or_dir.replace('/', '--')}".lower()
        return os.path.join(os.path.expanduser('~'), "bambara_tts", model_path)

    @staticmethod
    def init_paths(local_dir: str) -> dict:
        """
        Build the paths of the model components inside ``local_dir``.

        Args:
            local_dir: Directory where the model components are stored.
        Returns:
            A dict mapping each component file name to its path.
        """
        components = ['model.pth', 'config.json', 'vocab.json', 'dvae.pth', 'mel_stats.pth']
        return {name: os.path.join(local_dir, name) for name in components}

    def text_to_speech(
        self,
        text: str,
        speaker_reference_wav_path: Optional[str] = None,
        temperature: Optional[float] = 0.1,
        enable_text_splitting: bool = False
    ) -> Tuple[int, torch.Tensor]:
        """
        Convert text into speech audio.

        Args:
            text: The input text to be converted into speech.
            speaker_reference_wav_path: Path to a reference WAV file for the speaker. Defaults to
                "./audios/male_2.wav", resolved relative to the current working directory.
            temperature: Sampling temperature forwarded to the model.
            enable_text_splitting: Whether the model should split long text.
        Returns:
            A tuple ``(sampling_rate, audio)`` where ``audio`` is the generated waveform with a
            leading dimension of size 1 added, on the CPU.
        """
        if speaker_reference_wav_path is None:
            speaker_reference_wav_path = "./audios/male_2.wav"
            self.log("Using default speaker reference ./audios/male_2.wav.")

        self.log("Computing speaker latents...")
        gpt_cond_latent, speaker_embedding = self.model.get_conditioning_latents(
            audio_path=[speaker_reference_wav_path]
        )

        self.log("Starting inference...")
        start_time = time.time()
        out = self.model.inference(
            text,
            self.language_code,
            gpt_cond_latent,
            speaker_embedding,
            temperature=temperature,
            enable_text_splitting=enable_text_splitting
        )
        end_time = time.time()

        # as_tensor avoids copying the waveform (and warns less if it is already a tensor).
        audio = torch.as_tensor(out["wav"]).unsqueeze(0).cpu()
        # Plain int; no need to round-trip the configured rate through a tensor.
        sampling_rate = int(self.config.model_args.output_sample_rate)

        self.log(f"Speech generated in {end_time - start_time:.2f} seconds.")
        return sampling_rate, audio

    def log(self, message: str):
        """
        Print ``message`` to stdout with a "[BambaraTTS]" prefix.

        Args:
            message: The message to print.
        """
        print(f"[BambaraTTS] {message}")

    def log_tokenizer(self):
        """Log the model's tokenizer object."""
        self.log(f"Tokenizer: {self.model.tokenizer}")