Update app.py
app.py
CHANGED
@@ -1,44 +1,37 @@
-import gradio
-from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForTextToWaveform
+import gradio
 import torch
+import numpy as np
+from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForTextToWaveform

 # Load your pretrained models
 asr_model = Wav2Vec2ForCTC.from_pretrained("Baghdad99/saad-speech-recognition-hausa-audio-to-text")
 asr_processor = Wav2Vec2Processor.from_pretrained("Baghdad99/saad-speech-recognition-hausa-audio-to-text")
-
-# Load the Hausa translation model
 translation_tokenizer = AutoTokenizer.from_pretrained("Baghdad99/saad-hausa-text-to-english-text")
 translation_model = AutoModelForSeq2SeqLM.from_pretrained("Baghdad99/saad-hausa-text-to-english-text", from_tf=True)
-
-# Load the Text-to-Speech model
 tts_tokenizer = AutoTokenizer.from_pretrained("Baghdad99/english_voice_tts")
 tts_model = AutoModelForTextToWaveform.from_pretrained("Baghdad99/english_voice_tts")

-
-
-    audio_signal, sample_rate = speech
-
-    # Convert stereo to mono if necessary
-    if len(audio_signal.shape) > 1:
-        audio_signal = audio_signal.mean(axis=0)
-
-    # Transcribe the speech to text
+# Define the translation and synthesis functions
+def translate(audio_signal):
     inputs = asr_processor(audio_signal, return_tensors="pt", padding=True)
     logits = asr_model(inputs.input_values).logits
     predicted_ids = torch.argmax(logits, dim=-1)
     transcription = asr_processor.decode(predicted_ids[0])
-
-    # Translate the text
     translated = translation_model.generate(**translation_tokenizer(transcription, return_tensors="pt", padding=True))
     translated_text = [translation_tokenizer.decode(t, skip_special_tokens=True) for t in translated]
+    return translated_text

-
+def synthesise(translated_text):
     inputs = tts_tokenizer(translated_text, return_tensors='pt')
     audio = tts_model.generate(inputs['input_ids'])
-
     return audio

+def translate_speech(audio):
+    translated_text = translate(audio)
+    synthesised_speech = synthesise(translated_text)
+    synthesised_speech = (synthesised_speech.numpy() * max_range).astype(np.int16)
+    return 16000, synthesised_speech

 # Define the Gradio interface
-iface =
+iface = gradio.Interface(fn=translate_speech, inputs=gradio.inputs.Audio(source="microphone", type="numpy"), outputs="audio")
 iface.launch()
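Note on the new version: as committed, translate_speech still fails at runtime. max_range is never defined, so the int16 conversion raises a NameError, and Gradio's type="numpy" microphone input delivers a (sample_rate, data) tuple that is handed to translate, and then to asr_processor, unmodified. Below is a minimal sketch of a corrected handler that reuses the commit's translate and synthesise; the int16 peak value for max_range, the mono/normalisation steps, and the 16 kHz assumption are editorial additions, not part of the commit.

import numpy as np

# Sketch only: assumes translate() and synthesise() are defined as in the
# commit above, and that the Wav2Vec2 checkpoint expects 16 kHz mono audio.
def translate_speech(audio):
    sample_rate, audio_signal = audio            # type="numpy" passes a (rate, data) tuple
    audio_signal = audio_signal.astype(np.float32)
    if audio_signal.ndim > 1:                    # stereo -> mono, as the old version intended
        audio_signal = audio_signal.mean(axis=1)
    peak = np.abs(audio_signal).max()
    if peak > 0:                                 # normalise int16 mic samples to [-1, 1]
        audio_signal = audio_signal / peak
    # If sample_rate != 16000, the waveform should also be resampled here before ASR.

    translated_text = translate(audio_signal)
    synthesised_speech = synthesise(translated_text)

    max_range = np.iinfo(np.int16).max           # assumed definition; the commit omits it
    waveform = synthesised_speech.detach().numpy().squeeze()
    return 16000, (waveform * max_range).astype(np.int16)

Separately, the gradio.inputs namespace is deprecated and has been removed in recent Gradio releases; on Gradio 4.x the input component would be spelled gradio.Audio(sources=["microphone"], type="numpy") rather than gradio.inputs.Audio(source="microphone", type="numpy").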