Core23 commited on
Commit
ad95864
1 Parent(s): aa3f10f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +244 -0
app.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import openai
3
+ import gradio as gr
4
+ from gtts import gTTS
5
+ from transformers import pipeline
6
+
7
+ openai.api_key = os.getenv("OPENAI_API_KEY")
8
+
9
+ pipe = pipeline(model="seeafricatz/kiaziboraasr")
10
+
11
+ def transcribe(audio):
12
+ text = pipe(audio)["text"]
13
+ return text
14
+
15
+ def generate_response(transcribed_text):
16
+ response = openai.ChatCompletion.create(
17
+ model="gpt-3.5-turbo",
18
+ messages=[
19
+ {
20
+ "role": "system",
21
+ "content": "All your answers should be in Swahili only, users understand Swahili only so here we start... Wewe ni mtaalamu wa haki za ardhi za wanawake nchini Kongo na utajibu maswali yote kwa Kiswahili tu!"
22
+ },
23
+ {
24
+ "role": "user",
25
+ "content": "Mambo vipi?"
26
+ },
27
+ {
28
+ "role": "assistant",
29
+ "content": "Salama, je una swali lolote kuhusu haki za ardhi za wanawake nchini Kongo?"
30
+ },
31
+ {
32
+ "role": "user",
33
+ "content": "nini maana ya haki za ardhi za wanawake?"
34
+ },
35
+ {
36
+ "role": "assistant",
37
+ "content": "Haki za ardhi za wanawake zinamaanisha haki za wanawake kumiliki, kutumia, na kudhibiti ardhi. Katika muktadha wa Kongo, haki hizi zinaweza kuathiriwa na mila, sheria, na mizozo ya ardhi."
38
+ },
39
+ {
40
+ "role": "user",
41
+ "content": "nini matumizi ya haki za ardhi za wanawake?"
42
+ },
43
+ {
44
+ "role": "assistant",
45
+ "content": "Haki za ardhi za wanawake zina umuhimu mkubwa kwa kuwawezesha wanawake kiuchumi, kuimarisha usalama wa chakula, na kuchangia katika maendeleo ya jamii na taifa kwa ujumla. Kwa mfano, wanawake wenye haki za ardhi wanaweza kupata mikopo, kuendeleza ardhi, na kutoa mchango muhimu katika uchumi wa familia na jamii."
46
+ },
47
+ {
48
+ "role": "user",
49
+ "content": transcribed_text
50
+ },
51
+ ]
52
+ )
53
+ return response['choices'][0]['message']['content']
54
+
55
+
56
+ import os
57
+ import subprocess
58
+ import locale
59
+ locale.getpreferredencoding = lambda: "UTF-8"
60
+
61
+ def download(lang, tgt_dir="./"):
62
+ lang_fn, lang_dir = os.path.join(tgt_dir, lang+'.tar.gz'), os.path.join(tgt_dir, lang)
63
+ cmd = ";".join([
64
+ f"wget https://dl.fbaipublicfiles.com/mms/tts/{lang}.tar.gz -O {lang_fn}",
65
+ f"tar zxvf {lang_fn}"
66
+ ])
67
+ print(f"Download model for language: {lang}")
68
+ subprocess.check_output(cmd, shell=True)
69
+ print(f"Model checkpoints in {lang_dir}: {os.listdir(lang_dir)}")
70
+ return lang_dir
71
+
72
+ LANG = "swh"
73
+ ckpt_dir = download(LANG)
74
+
75
+ from IPython.display import Audio
76
+ import os
77
+ import re
78
+ import glob
79
+ import json
80
+ import tempfile
81
+ import math
82
+ import torch
83
+ from torch import nn
84
+ from torch.nn import functional as F
85
+ from torch.utils.data import DataLoader
86
+ import numpy as np
87
+ import commons
88
+ import utils
89
+ import argparse
90
+ import subprocess
91
+ from data_utils import TextAudioLoader, TextAudioCollate, TextAudioSpeakerLoader, TextAudioSpeakerCollate
92
+ from models import SynthesizerTrn
93
+ from scipy.io.wavfile import write
94
+
95
+ def preprocess_char(text, lang=None):
96
+ """
97
+ Special treatement of characters in certain languages
98
+ """
99
+ print(lang)
100
+ if lang == 'ron':
101
+ text = text.replace("ț", "ţ")
102
+ return text
103
+
104
+ class TextMapper(object):
105
+ def __init__(self, vocab_file):
106
+ self.symbols = [x.replace("\n", "") for x in open(vocab_file, encoding="utf-8").readlines()]
107
+ self.SPACE_ID = self.symbols.index(" ")
108
+ self._symbol_to_id = {s: i for i, s in enumerate(self.symbols)}
109
+ self._id_to_symbol = {i: s for i, s in enumerate(self.symbols)}
110
+
111
+ def text_to_sequence(self, text, cleaner_names):
112
+ '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
113
+ Args:
114
+ text: string to convert to a sequence
115
+ cleaner_names: names of the cleaner functions to run the text through
116
+ Returns:
117
+ List of integers corresponding to the symbols in the text
118
+ '''
119
+ sequence = []
120
+ clean_text = text.strip()
121
+ for symbol in clean_text:
122
+ symbol_id = self._symbol_to_id[symbol]
123
+ sequence += [symbol_id]
124
+ return sequence
125
+
126
+ def uromanize(self, text, uroman_pl):
127
+ iso = "xxx"
128
+ with tempfile.NamedTemporaryFile() as tf, \
129
+ tempfile.NamedTemporaryFile() as tf2:
130
+ with open(tf.name, "w") as f:
131
+ f.write("\n".join([text]))
132
+ cmd = f"perl " + uroman_pl
133
+ cmd += f" -l {iso} "
134
+ cmd += f" < {tf.name} > {tf2.name}"
135
+ os.system(cmd)
136
+ outtexts = []
137
+ with open(tf2.name) as f:
138
+ for line in f:
139
+ line = re.sub(r"\s+", " ", line).strip()
140
+ outtexts.append(line)
141
+ outtext = outtexts[0]
142
+ return outtext
143
+
144
+ def get_text(self, text, hps):
145
+ text_norm = self.text_to_sequence(text, hps.data.text_cleaners)
146
+ if hps.data.add_blank:
147
+ text_norm = commons.intersperse(text_norm, 0)
148
+ text_norm = torch.LongTensor(text_norm)
149
+ return text_norm
150
+
151
+ def filter_oov(self, text):
152
+ val_chars = self._symbol_to_id
153
+ txt_filt = "".join(list(filter(lambda x: x in val_chars, text)))
154
+ print(f"text after filtering OOV: {txt_filt}")
155
+ return txt_filt
156
+
157
+ def preprocess_text(txt, text_mapper, hps, uroman_dir=None, lang=None):
158
+ txt = preprocess_char(txt, lang=lang)
159
+ is_uroman = hps.data.training_files.split('.')[-1] == 'uroman'
160
+ if is_uroman:
161
+ with tempfile.TemporaryDirectory() as tmp_dir:
162
+ if uroman_dir is None:
163
+ cmd = f"git clone git@github.com:isi-nlp/uroman.git {tmp_dir}"
164
+ print(cmd)
165
+ subprocess.check_output(cmd, shell=True)
166
+ uroman_dir = tmp_dir
167
+ uroman_pl = os.path.join(uroman_dir, "bin", "uroman.pl")
168
+ print(f"uromanize")
169
+ txt = text_mapper.uromanize(txt, uroman_pl)
170
+ print(f"uroman text: {txt}")
171
+ txt = txt.lower()
172
+ txt = text_mapper.filter_oov(txt)
173
+ return txt
174
+
175
+ if torch.cuda.is_available():
176
+ device = torch.device("cuda")
177
+ else:
178
+ device = torch.device("cpu")
179
+
180
+ print(f"Run inference with {device}")
181
+ vocab_file = f"{ckpt_dir}/vocab.txt"
182
+ config_file = f"{ckpt_dir}/config.json"
183
+ assert os.path.isfile(config_file), f"{config_file} doesn't exist"
184
+ hps = utils.get_hparams_from_file(config_file)
185
+ text_mapper = TextMapper(vocab_file)
186
+ net_g = SynthesizerTrn(
187
+ len(text_mapper.symbols),
188
+ hps.data.filter_length // 2 + 1,
189
+ hps.train.segment_size // hps.data.hop_length,
190
+ **hps.model)
191
+ net_g.to(device)
192
+ _ = net_g.eval()
193
+
194
+ g_pth = f"{ckpt_dir}/G_100000.pth"
195
+ print(f"load {g_pth}")
196
+
197
+ _ = utils.load_checkpoint(g_pth, net_g, None)
198
+
199
+
200
+ If you want to use the original text-to-speech code in place of the gTTS library within the inference function, you should move the original text-to-speech code into the inference function, and adjust the function to save the generated audio to a file and return the file path. Here's how you might do it:
201
+
202
+ python
203
+ Copy code
204
+ import torch
205
+ from scipy.io.wavfile import write
206
+
207
+ def inference(text):
208
+ # Preprocessing the text
209
+ text = preprocess_text(text, text_mapper, hps, lang=LANG)
210
+ stn_tst = text_mapper.get_text(text, hps)
211
+
212
+ with torch.no_grad():
213
+ x_tst = stn_tst.unsqueeze(0).to(device)
214
+ x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)
215
+ hyp = net_g.infer(
216
+ x_tst, x_tst_lengths, noise_scale=.667,
217
+ noise_scale_w=0.8, length_scale=1.0
218
+ )[0][0,0].cpu().float().numpy()
219
+
220
+ # Saving the generated audio to a file
221
+ output_file = "tts_output.wav"
222
+ write(output_file, hps.data.sampling_rate, hyp)
223
+
224
+ return output_file
225
+
226
+ def process_audio_and_respond(audio):
227
+ text = transcribe(audio)
228
+ response_text = generate_response(text)
229
+ output_file = inference(response_text)
230
+ return response_text, output_file
231
+
232
+ demo = gr.Interface(
233
+ process_audio_and_respond,
234
+ gr.inputs.Audio(source="microphone", type="filepath", label="Bonyeza kitufe cha kurekodi na uliza swali lako"),
235
+ [gr.outputs.Textbox(label="Jibu (kwa njia ya maandishi)"), gr.outputs.Audio(type="filepath", label="Jibu kwa njia ya sauti (Bofya kusikiliza Jibu)")],
236
+ title="Haki",
237
+ description="Uliza Swali kuhusu haki za ardhi",
238
+ theme="compact",
239
+ layout="vertical",
240
+ allow_flagging=False,
241
+ live=True,
242
+ )
243
+
244
+ demo.launch()