EricPeter committed on
Commit 3bd8d03
1 Parent(s): 753bd06

Upload 3 files

Files changed (3)
  1. app.py +44 -0
  2. requirements.txt +4 -0
  3. stitched_model.py +27 -0
app.py ADDED
@@ -0,0 +1,44 @@
+ import gradio as gr
+ import torch
+ import librosa
+ from stitched_model import CombinedModel
+
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+
+ # Luganda speech recogniser chained to a Sunbird multilingual-to-English translation model (see stitched_model.py)
+ model = CombinedModel("ak3ra/wav2vec2-sunbird-speech-lug", "Sunbird/sunbird-mul-en-mbart-merged", device="cpu")
+
+
+ def transcribe(audio_file_mic=None, audio_file_upload=None):
+     if audio_file_mic:
+         audio_file = audio_file_mic
+     elif audio_file_upload:
+         audio_file = audio_file_upload
+     else:
+         # The interface has two outputs, so pair the message with an empty translation
+         return "Please upload an audio file or record one", ""
+
+     # Load at the native sample rate, then make sure the audio is 16kHz
+     speech, sample_rate = librosa.load(audio_file, sr=None)
+     if sample_rate != 16000:
+         speech = librosa.resample(speech, orig_sr=sample_rate, target_sr=16000)
+     speech = torch.tensor([speech])
+
+     with torch.no_grad():
+         transcription, translation = model({"audio": speech})
+
+     return transcription, translation[0]
+
+
+ description = '''Luganda to English Speech Translation'''
+
+ iface = gr.Interface(fn=transcribe,
+                      inputs=[
+                          gr.Audio(source="microphone", type="filepath", label="Record Audio"),
+                          gr.Audio(source="upload", type="filepath", label="Upload Audio")],
+                      outputs=[gr.Textbox(label="Transcription"),
+                               gr.Textbox(label="Translation")
+                      ],
+                      description=description
+                      )
+ iface.launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ transformers[torch]
+ librosa
+ sentencepiece
+ jiwer
stitched_model.py ADDED
@@ -0,0 +1,27 @@
+ import torch
+ from torch import nn
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, AutoTokenizer, AutoModelForSeq2SeqLM
+
+
+ class CombinedModel(nn.Module):
+     """Chains a Wav2Vec2 CTC speech recogniser with a seq2seq translation model."""
+     def __init__(self, stt_model_name, nmt_model_name, device="cuda"):
+         super().__init__()
+         self.stt_processor = Wav2Vec2Processor.from_pretrained(stt_model_name)
+         self.stt_model = Wav2Vec2ForCTC.from_pretrained(stt_model_name).to(device)
+         self.nmt_tokenizer = AutoTokenizer.from_pretrained(nmt_model_name)
+         self.nmt_model = AutoModelForSeq2SeqLM.from_pretrained(nmt_model_name).to(device)
+         self.device = device
+
+     def forward(self, batch, *args, **kwargs):
+         device = self.device
+         # Use the STT model to transcribe the audio to text
+         audio = torch.as_tensor(batch["audio"][0])
+         input_features = self.stt_processor(audio, sampling_rate=16000, return_tensors="pt", max_length=110000, padding=True, truncation=True)
+         stt_output = self.stt_model(input_features.input_values.to(device), attention_mask=input_features.attention_mask.to(device))
+         transcription = self.stt_processor.decode(torch.squeeze(stt_output.logits.argmax(dim=-1)))
+         # Translate the transcription with the NMT model
+         input_nmt_tokens = self.nmt_tokenizer(transcription, return_tensors="pt", padding=True, truncation=True)
+         output_nmt_output = self.nmt_model.generate(input_ids=input_nmt_tokens.input_ids.to(device), attention_mask=input_nmt_tokens.attention_mask.to(device))
+         decoded_nmt_output = self.nmt_tokenizer.batch_decode(output_nmt_output, skip_special_tokens=True)
+         return transcription, decoded_nmt_output
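
For context, a minimal sketch of how CombinedModel can be exercised outside the Gradio interface, mirroring what app.py does. The checkpoint names are the ones used above; "sample.wav" is a hypothetical local recording, not a file included in this commit.

# Illustration only (not part of this commit): run CombinedModel on a local recording
import librosa
import torch
from stitched_model import CombinedModel

model = CombinedModel("ak3ra/wav2vec2-sunbird-speech-lug", "Sunbird/sunbird-mul-en-mbart-merged", device="cpu")

# "sample.wav" is a placeholder path; librosa resamples it to the 16 kHz the model expects
speech, _ = librosa.load("sample.wav", sr=16000)
with torch.no_grad():
    transcription, translation = model({"audio": torch.tensor([speech])})
print("Transcription:", transcription)
print("Translation:", translation[0])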