vmoras commited on
Commit
a95b578
β€’
1 Parent(s): 22163e2

Initial commit

Browse files
Files changed (6) hide show
  1. .gitignore +8 -0
  2. README.md +1 -1
  3. audio.py +201 -0
  4. main.py +31 -0
  5. model.py +39 -0
  6. requirements.txt +12 -0
.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ .env
2
+ .idea/
3
+ __pycache__/
4
+
5
+ assets
6
+ tts_model
7
+
8
+ output.wav
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: πŸ”₯
4
  colorFrom: green
5
  colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 4.15.0
8
  app_file: app.py
9
  pinned: false
10
  ---
 
4
  colorFrom: green
5
  colorTo: yellow
6
  sdk: gradio
7
+ sdk_version: 4.2.0
8
  app_file: app.py
9
  pinned: false
10
  ---
audio.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import os
3
+ import nltk
4
+ import torch
5
+ import pickle
6
+ import torchaudio
7
+ import numpy as np
8
+ import gradio as gr
9
+ from google.cloud import storage
10
+ from TTS.tts.models.xtts import Xtts
11
+ from nltk.tokenize import sent_tokenize
12
+ from huggingface_hub import hf_hub_download
13
+ from TTS.tts.configs.xtts_config import XttsConfig
14
+
15
+
16
+ def _download_starting_files() -> None:
17
+ """
18
+ Downloads the embeddings from a bucket
19
+ """
20
+ os.makedirs('assets', exist_ok=True)
21
+
22
+ # Download credentials file
23
+ hf_hub_download(
24
+ repo_id=os.environ.get('DATA'), repo_type='dataset', filename="credentials.json",
25
+ token=os.environ.get('HUB_TOKEN'), local_dir="assets"
26
+ )
27
+
28
+ # Initialise a client
29
+ credentials = os.getenv('GOOGLE_APPLICATION_CREDENTIALS')
30
+ storage_client = storage.Client.from_service_account_json(credentials)
31
+ bucket = storage_client.get_bucket('embeddings-bella')
32
+
33
+ # Get both embeddings
34
+ blob = bucket.blob("gpt_cond_latent.npy")
35
+ blob.download_to_filename('assets/gpt_cond_latent.npy')
36
+ blob = bucket.blob("speaker_embedding.npy")
37
+ blob.download_to_filename('assets/speaker_embedding.npy')
38
+
39
+
40
+ def _load_array(filename):
41
+ """
42
+ Opens a file a returns it, used with numpy files
43
+ """
44
+ with open(filename, 'rb') as f:
45
+ return pickle.load(f)
46
+
47
+
48
+ # Get embeddings
49
+ _download_starting_files()
50
+ os.environ['COQUI_TOS_AGREED'] = '1'
51
+
52
+ # Used to generate audio based on a sample
53
+ nltk.download('punkt')
54
+ model_path = os.path.join("tts_model")
55
+
56
+ config = XttsConfig()
57
+ config.load_json(os.path.join(model_path, "config.json"))
58
+
59
+ model = Xtts.init_from_config(config)
60
+ model.load_checkpoint(
61
+ config,
62
+ checkpoint_path=os.path.join(model_path, "model.pth"),
63
+ vocab_path=os.path.join(model_path, "vocab.json"),
64
+ eval=True,
65
+ use_deepspeed=True,
66
+ )
67
+
68
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
69
+ model.to(device)
70
+
71
+ # Speaker latent
72
+ path_latents = 'assets/gpt_cond_latent.npy'
73
+ gpt_cond_latent = _load_array(path_latents)
74
+
75
+ # Speaker embedding
76
+ path_embedding = 'assets/speaker_embedding.npy'
77
+ speaker_embedding = _load_array(path_embedding)
78
+
79
+
80
+ def get_audio(text: str, language: str = 'es') -> gr.Audio:
81
+ """
82
+ Returns a link from a bucket in GCP that contains the generated audio given a text and language and the
83
+ name of such audio
84
+ :param text: used to generate the audio
85
+ :param language: 'es', 'en' or 'pt'
86
+ :return link_audio and name_audio
87
+ """
88
+ # Creates an audio with the answer and saves it as output.wav
89
+ _save_audio(text, language)
90
+
91
+ return gr.Audio(value='output.wav', interactive=False, visible=True)
92
+
93
+
94
+ def _save_audio(answer: str, language: str) -> None:
95
+ """
96
+ Splits the answer into sentences, clean and creates an audio for each one, then concatenates
97
+ all the audios and saves them into a file (output.wav)
98
+ """
99
+ # Split the answer into sentences and clean it
100
+ sentences = _get_clean_answer(answer, language)
101
+
102
+ # Get the voice of each sentence
103
+ audio_segments = []
104
+ for sentence in sentences:
105
+ audio_stream = _get_voice(sentence, language)
106
+ audio_stream = torch.tensor(audio_stream)
107
+ audio_segments.append(audio_stream)
108
+
109
+ # Concatenate and save all audio segments
110
+ concatenated_audio = torch.cat(audio_segments, dim=0)
111
+ torchaudio.save('output.wav', concatenated_audio.unsqueeze(0), 24000)
112
+
113
+
114
+ def _get_voice(sentence: str, language: str) -> np.ndarray:
115
+ """
116
+ Returns a numpy array with a wav of an audio with the given sentence and language
117
+ """
118
+ out = model.inference(
119
+ sentence,
120
+ language=language,
121
+ gpt_cond_latent=gpt_cond_latent,
122
+ speaker_embedding=speaker_embedding,
123
+ temperature=0.1
124
+ )
125
+ return out['wav']
126
+
127
+
128
+ def _get_clean_answer(answer: str, language: str) -> list[str]:
129
+ """
130
+ Returns a list of sentences of the answer. It also removes links
131
+ """
132
+ # Remove the links in the audio and add another sentence
133
+ if language == 'en':
134
+ clean_answer = re.sub(r'http[s]?://\S+', 'the following link', answer)
135
+ max_characters = 250
136
+ elif language == 'es':
137
+ clean_answer = re.sub(r'http[s]?://\S+', 'el siguiente link', answer)
138
+ max_characters = 239
139
+ else:
140
+ clean_answer = re.sub(r'http[s]?://\S+', 'o seguinte link', answer)
141
+ max_characters = 203
142
+
143
+ # Change the name from Bella to Bela
144
+ clean_answer = clean_answer.replace('Bella', 'Bela')
145
+
146
+ # Remove Florida and zipcode
147
+ clean_answer = re.sub(r', FL \d+', "", clean_answer)
148
+
149
+ # Split the answer into sentences with nltk and make sure they are shorter than the maximum possible
150
+ # characters
151
+ split_sentences = sent_tokenize(clean_answer)
152
+ sentences = []
153
+ for sentence in split_sentences:
154
+ if len(sentence) > max_characters:
155
+ sentences.extend(split_sentence(sentence, max_characters))
156
+ else:
157
+ sentences.append(sentence)
158
+
159
+ return sentences
160
+
161
+
162
+ def split_sentence(sentence: str, max_characters: int) -> list[str]:
163
+ """
164
+ Returns a split sentences. The split point is the nearest comma to the middle
165
+ of the sentence, if there is no comma then a space is used or just the middle. If the
166
+ remaining sentences are still too long, another iteration is run
167
+ """
168
+ # Get index of each comma
169
+ sentences = []
170
+ commas = [i for i, c in enumerate(sentence) if c == ',']
171
+
172
+ # No commas, search for spaces
173
+ if len(commas) == 0:
174
+ commas = [i for i, c in enumerate(sentence) if c == ' ']
175
+
176
+ # No commas or spaces, split it in the middle
177
+ if len(commas) == 0:
178
+ sentences.append(sentence[:len(sentence) // 2])
179
+ sentences.append(sentence[len(sentence) // 2:])
180
+ return sentences
181
+
182
+ # Nearest index to the middle
183
+ split_point = min(commas, key=lambda x: abs(x - (len(sentence) // 2)))
184
+
185
+ if sentence[split_point] == ',':
186
+ left = sentence[:split_point]
187
+ right = sentence[split_point + 2:]
188
+ else:
189
+ left = sentence[:split_point]
190
+ right = sentence[split_point + 1:]
191
+
192
+ if len(left) > max_characters:
193
+ sentences.extend(split_sentence(left, max_characters))
194
+ else:
195
+ sentences.append(left)
196
+ if len(right) > max_characters:
197
+ sentences.extend(split_sentence(right, max_characters))
198
+ else:
199
+ sentences.append(right)
200
+
201
+ return sentences
main.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dotenv import load_dotenv
3
+ load_dotenv()
4
+
5
+ import model
6
+ # Get TTS model
7
+ if not os.path.exists('tts_model'):
8
+ model.download_model()
9
+
10
+ import audio
11
+ import gradio as gr
12
+
13
+
14
+ def update_widget():
15
+ return gr.Button(value='Creating audio...', interactive=False)
16
+
17
+
18
+ with gr.Blocks() as app:
19
+ text = gr.Textbox(label="Text")
20
+ button = gr.Button(value='Create audio')
21
+
22
+ audio_file = gr.Audio(visible=False)
23
+
24
+ button.click(
25
+ update_widget, None, button
26
+ ).then(
27
+ audio.get_audio, text, audio_file
28
+ )
29
+
30
+ app.queue()
31
+ app.launch(debug=True, auth=(os.environ.get('SPACE_USERNAME'), os.environ.get('SPACE_PASSWORD')))
model.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+ from tqdm import tqdm
4
+
5
+
6
+ def _download_file(url, destination):
7
+ response = requests.get(url, stream=True)
8
+ total_size_in_bytes = int(response.headers.get('content-length', 0))
9
+ block_size = 1024
10
+
11
+ progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
12
+
13
+ with open(destination, 'wb') as file:
14
+ for data in response.iter_content(block_size):
15
+ progress_bar.update(len(data))
16
+ file.write(data)
17
+
18
+ progress_bar.close()
19
+
20
+
21
+ def download_model():
22
+ # Define files and their corresponding URLs
23
+ files_to_download = {
24
+ 'LICENSE.txt': 'https://huggingface.co/coqui/XTTS-v2/resolve/v2.0.2/LICENSE.txt?download=true',
25
+ 'README.md': 'https://huggingface.co/coqui/XTTS-v2/resolve/v2.0.2/README.md?download=true',
26
+ 'config.json': 'https://huggingface.co/coqui/XTTS-v2/resolve/v2.0.2/config.json?download=true',
27
+ 'model.pth': 'https://huggingface.co/coqui/XTTS-v2/resolve/v2.0.2/model.pth?download=true',
28
+ 'vocab.json': 'https://huggingface.co/coqui/XTTS-v2/resolve/v2.0.2/vocab.json?download=true',
29
+ }
30
+
31
+ if not os.path.exists("tts_model"):
32
+ os.makedirs("tts_model")
33
+
34
+ # Download files if they don't exist
35
+ print("[COQUI TTS] STARTUP: Checking Model is Downloaded.")
36
+ for filename, url in files_to_download.items():
37
+ destination = f'tts_model/{filename}'
38
+ print(f"[COQUI TTS] STARTUP: Downloading {filename}...")
39
+ _download_file(url, destination)
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ requests==2.31.0
2
+ tqdm==4.66.1
3
+ nltk==3.8.1
4
+ deepspeed==0.12.3
5
+ torch==2.1.1
6
+ torchaudio==2.1.1
7
+ TTS==0.21.2
8
+ google-cloud-storage==2.13.0
9
+ python-dotenv==1.0.1
10
+ gradio==4.15.0
11
+ numpy==1.22.0
12
+ transformers==4.36.0