dofbi committed on
Commit e285918 · verified · 1 Parent(s): 04ec52a

Add inference.py

Files changed (1)
  1. inference.py +71 -0
inference.py ADDED
@@ -0,0 +1,71 @@
+ import torch
+ import os
+ from TTS.tts.configs.xtts_config import XttsConfig
+ from TTS.tts.models.xtts import Xtts
+ import soundfile as sf
+ from removesilence import detect_silence, remove_silence
+
+ # Load the model (note: repo_id is currently unused; files are read from local paths)
+ def load_model(repo_id):
+     # Build the paths to the model files
+     root_path = "./"
+     checkpoint_path = os.path.join(root_path, "Anta_GPT_XTTS_Wo")
+     model_path = "best_model_89250.pth"
+     xtts_checkpoint = os.path.join(checkpoint_path, model_path)
+     xtts_config = os.path.join(checkpoint_path, "config.json")
+     xtts_vocab = os.path.join(root_path, "XTTS_v2.0_original_model_files/vocab.json")
+     device = "cuda:0" if torch.cuda.is_available() else "cpu"
+
+     # Load model
+     config = XttsConfig()
+     config.load_json(xtts_config)
+     XTTS_MODEL = Xtts.init_from_config(config)
+     XTTS_MODEL.load_checkpoint(config,
+                                checkpoint_path=xtts_checkpoint,
+                                vocab_path=xtts_vocab,
+                                use_deepspeed=False)
+     XTTS_MODEL.to(device)
+     print("Model loaded successfully!")
+     return XTTS_MODEL, device
+
+ # Main inference function
+ def inference(text, reference_audio, model, device):
+     # Preprocessing: extract conditioning latents from the reference audio
+     reference = reference_audio
+     gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
+         audio_path=[reference],
+         gpt_cond_len=model.config.gpt_cond_len,
+         max_ref_length=model.config.max_ref_len,
+         sound_norm_refs=model.config.sound_norm_refs
+     )
+
+     # Inference
+     result = model.inference(
+         text=text.lower(),
+         gpt_cond_latent=gpt_cond_latent,
+         speaker_embedding=speaker_embedding,
+         do_sample=False,
+         speed=1.06,
+         language="wo",
+         enable_text_splitting=True
+     )
+
+     # Return the generated waveform and its sample rate
+     sample_rate = model.config.audio.sample_rate
+     return result["wav"], sample_rate
+
+
+ # Generate audio from text and a reference audio clip
+ def generate_audio(text, reference_audio_path):
+     model, device = load_model("dofbi/galsenai-xtts-v2-wolof-inference")
+     audio_output, sample_rate = inference(text, reference_audio_path, model, device)
+     # Write the raw output to a temporary file for silence processing
+     temp_audio_path = "temp_audio.wav"
+     sf.write(temp_audio_path, audio_output, sample_rate)
+     # Post-processing: trim silence from the generated audio
+     lst = detect_silence(temp_audio_path)
+     output_audio = "audio_without_silence.wav"
+     remove_silence(temp_audio_path, lst, output_audio)
+     # Read the processed file back for the return value
+     audio, _ = sf.read(output_audio)
+     return audio, sample_rate
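
For context, a minimal usage sketch (not part of the commit), assuming the script runs next to the Anta_GPT_XTTS_Wo and XTTS_v2.0_original_model_files directories; the sample text and "reference.wav" are placeholders:

import soundfile as sf
from inference import generate_audio

# "reference.wav" is a hypothetical clip of the target speaker; supply your own recording.
audio, sample_rate = generate_audio("Salaam aleekum", "reference.wav")
sf.write("output.wav", audio, sample_rate)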