import torch import os import logging import soundfile as sf import numpy as np from huggingface_hub import hf_hub_download from TTS.tts.configs.xtts_config import XttsConfig from TTS.tts.models.xtts import Xtts # --- CONSTANTES --- REPO_ID = "dofbi/galsenai-xtts-v2-wolof-inference" LOCAL_DIR = "./models" class WolofXTTSInference: def __init__(self, repo_id=REPO_ID, local_dir=LOCAL_DIR): # Configuration du logging logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s' ) self.logger = logging.getLogger(__name__) # Créer le dossier local s'il n'existe pas os.makedirs(local_dir, exist_ok=True) # Téléchargement des fichiers nécessaires try: # Créer les sous-dossiers nécessaires os.makedirs(os.path.join(local_dir, "Anta_GPT_XTTS_Wo"), exist_ok=True) os.makedirs(os.path.join(local_dir, "XTTS_v2.0_original_model_files"), exist_ok=True) # Télécharger le checkpoint self.model_path = hf_hub_download( repo_id=repo_id, filename="Anta_GPT_XTTS_Wo/best_model_89250.pth", local_dir=local_dir ) # Télécharger le fichier de configuration self.config_path = hf_hub_download( repo_id=repo_id, filename="Anta_GPT_XTTS_Wo/config.json", local_dir=local_dir ) # Télécharger le vocabulaire self.vocab_path = hf_hub_download( repo_id=repo_id, filename="XTTS_v2.0_original_model_files/vocab.json", local_dir=local_dir ) # Télécharger l'audio de référence self.reference_audio = hf_hub_download( repo_id=repo_id, filename="anta_sample.wav", local_dir=local_dir ) except Exception as e: self.logger.error(f"Erreur lors du téléchargement des fichiers : {e}") raise # Sélection du device self.device = "cuda:0" if torch.cuda.is_available() else "cpu" # Initialisation du modèle self.model = self._load_model() def _load_model(self): """Charge le modèle XTTS""" try: self.logger.info("Chargement du modèle XTTS...") # Initialisation du modèle config = XttsConfig() config.load_json(self.config_path) model = Xtts.init_from_config(config) # Chargement du checkpoint avec load_checkpoint model.load_checkpoint(config, checkpoint_path=self.model_path, vocab_path=self.vocab_path, use_deepspeed=False ) model.to(self.device) model.eval() # Mettre le modèle en mode évaluation self.logger.info("Modèle chargé avec succès!") return model except Exception as e: self.logger.error(f"Erreur lors du chargement du modèle : {e}") raise def generate_audio( self, text: str, reference_audio: str = None, speed: float = 1.06, language: str = "wo", output_path: str = None ) -> tuple[np.ndarray, int]: """ Génère de l'audio à partir du texte fourni Args: text (str): Texte à convertir en audio reference_audio (str, optional): Chemin vers l'audio de référence. Defaults to None. speed (float, optional): Vitesse de lecture. Defaults to 1.06. language (str, optional): Langue du texte. Defaults to "wo". output_path (str, optional): Chemin de sauvegarde de l'audio généré. Defaults to None. Returns: tuple[np.ndarray, int]: audio_array, sample_rate """ if not text: raise ValueError("Le texte ne peut pas être vide.") try: # Utiliser l'audio de référence fourni ou par défaut ref_audio = reference_audio or self.reference_audio # Obtenir les embeddings gpt_cond_latent, speaker_embedding = self.model.get_conditioning_latents( audio_path=[ref_audio], gpt_cond_len=self.model.config.gpt_cond_len, max_ref_length=self.model.config.max_ref_len, sound_norm_refs=self.model.config.sound_norm_refs ) # Génération de l'audio result = self.model.inference( text=text.lower(), gpt_cond_latent=gpt_cond_latent, speaker_embedding=speaker_embedding, do_sample=False, speed=speed, language=language, enable_text_splitting=True ) # Récupérer le taux d'échantillonnage sample_rate = self.model.config.audio.sample_rate # Sauvegarde optionnelle if output_path: sf.write(output_path, result["wav"], sample_rate) self.logger.info(f"Audio sauvegardé dans {output_path}") return result["wav"], sample_rate except Exception as e: self.logger.error(f"Erreur lors de la génération de l'audio : {e}") raise def generate_audio_from_config(self, text: str, config: dict, output_path: str = None) -> tuple[np.ndarray, int]: """ Génère de l'audio à partir du texte et d'un dictionnaire de configuration. Args: text (str): Texte à convertir en audio config (dict): Dictionnaire de configuration (speed, language, reference_audio) output_path (str, optional): Chemin de sauvegarde de l'audio généré. Defaults to None. Returns: tuple[np.ndarray, int]: audio_array, sample_rate """ speed = config.get('speed', 1.06) language = config.get('language', "wo") reference_audio = config.get('reference_audio', None) return self.generate_audio(text=text, reference_audio=reference_audio, speed=speed, language=language, output_path=output_path) # Exemple d'utilisation if __name__ == "__main__": tts = WolofXTTSInference() # Exemple de génération d'audio text = "Màngi tuddu Aadama, di baat bii waa Galsen A.I defar ngir wax ak yéen ci wolof!" # Simple audio, sr = tts.generate_audio( text, output_path="generated_audio.wav" ) # Avec une config config_gen_audio = { "speed": 1.2, "language": "wo", } audio, sr = tts.generate_audio_from_config( text=text, config=config_gen_audio, output_path="generated_audio_config.wav" )