datnth1709 committed on
Commit
2f12a3f
1 Parent(s): e6ce204

update inference

Browse files
Files changed (2) hide show
  1. README.md +3 -3
  2. app.py +14 -152
README.md CHANGED
@@ -1,7 +1,7 @@
1
  ---
2
- title: FantasticFour S2T MT Demo
3
- emoji: 🐠
4
- colorFrom: red
5
  colorTo: gray
6
  sdk: gradio
7
  sdk_version: 3.3.1
 
1
  ---
2
+ title: Realtime S2T MT Demo
3
+ emoji: 🥑
4
+ colorFrom: blue
5
  colorTo: gray
6
  sdk: gradio
7
  sdk_version: 3.3.1
app.py CHANGED
@@ -77,65 +77,13 @@ def speech2text_vi(audio):
77
  return beam_search_output
78
 
79
 
80
- """English speech2text"""
81
- nltk.download("punkt")
82
- # Loading the model and the tokenizer
83
- model_name = "facebook/s2t-small-librispeech-asr"
84
- eng_tokenizer = Wav2Vec2Tokenizer.from_pretrained(model_name)
85
- eng_model = Wav2Vec2ForCTC.from_pretrained(model_name)
86
-
87
- def load_data(input_file):
88
- """ Function for resampling to ensure that the speech input is sampled at 16KHz.
89
- """
90
- # read the file
91
- speech, sample_rate = librosa.load(input_file)
92
- # make it 1-D
93
- if len(speech.shape) > 1:
94
- speech = speech[:, 0] + speech[:, 1]
95
- # Resampling at 16KHz since wav2vec2-base-960h is pretrained and fine-tuned on speech audio sampled at 16 KHz.
96
- if sample_rate != 16000:
97
- speech = librosa.resample(speech, sample_rate, 16000)
98
- return speech
99
-
100
- def correct_casing(input_sentence):
101
- """ This function is for correcting the casing of the generated transcribed text
102
- """
103
- sentences = nltk.sent_tokenize(input_sentence)
104
- return (' '.join([s.replace(s[0], s[0].capitalize(), 1) for s in sentences]))
105
-
106
-
107
- def speech2text_en(input_file):
108
- """This function generates transcripts for the provided audio input
109
- """
110
- speech = load_data(input_file)
111
- # Tokenize
112
- input_values = eng_tokenizer(speech, return_tensors="pt").input_values
113
- # Take logits
114
- logits = eng_model(input_values).logits
115
- # Take argmax
116
- predicted_ids = torch.argmax(logits, dim=-1)
117
- # Get the words from predicted word ids
118
- transcription = eng_tokenizer.decode(predicted_ids[0])
119
- # Output is all upper case
120
- transcription = correct_casing(transcription.lower())
121
- return transcription
122
-
123
-
124
  """Machine translation"""
125
  vien_model_checkpoint = "datnth1709/finetuned_HelsinkiNLP-opus-mt-vi-en_PhoMT"
126
- envi_model_checkpoint = "datnth1709/finetuned_HelsinkiNLP-opus-mt-en-vi_PhoMT"
127
  vien_translator = pipeline("translation", model=vien_model_checkpoint)
128
- envi_translator = pipeline("translation", model=envi_model_checkpoint)
129
-
130
 
131
  def translate_vi2en(Vietnamese):
132
  return vien_translator(Vietnamese)[0]['translation_text']
133
 
134
- def translate_en2vi(English):
135
- return envi_translator(English)[0]['translation_text']
136
-
137
-
138
-
139
 
140
  """ Inference"""
141
  def inference_vien(audio):
@@ -143,46 +91,6 @@ def inference_vien(audio):
143
  en_text = translate_vi2en(vi_text)
144
  return vi_text, en_text
145
 
146
- def inference_envi(audio):
147
- en_text = speech2text_en(audio)
148
- vi_text = translate_en2vi(en_text)
149
- return en_text, vi_text
150
-
151
- def transcribe_vi(audio, state_vi="", state_en=""):
152
- ds = speech_file_to_array_fn(audio.name)
153
- # infer model
154
- input_values = processor(
155
- ds["speech"],
156
- sampling_rate=ds["sampling_rate"],
157
- return_tensors="pt"
158
- ).input_values
159
- # decode ctc output
160
- logits = vi_model(input_values).logits[0]
161
- pred_ids = torch.argmax(logits, dim=-1)
162
- greedy_search_output = processor.decode(pred_ids)
163
- beam_search_output = ngram_lm_model.decode(logits.cpu().detach().numpy(), beam_width=500)
164
- state_vi += beam_search_output + " "
165
- en_text = translate_vi2en(beam_search_output)
166
- state_en += en_text + " "
167
- return state_vi, state_en
168
-
169
- def transcribe_en(audio, state_en="", state_vi=""):
170
- speech = load_data(audio)
171
- # Tokenize
172
- input_values = eng_tokenizer(speech, return_tensors="pt").input_values
173
- # Take logits
174
- logits = eng_model(input_values).logits
175
- # Take argmax
176
- predicted_ids = torch.argmax(logits, dim=-1)
177
- # Get the words from predicted word ids
178
- transcription = eng_tokenizer.decode(predicted_ids[0])
179
- # Output is all upper case
180
- transcription = correct_casing(transcription.lower())
181
- state_en += transcription + "+"
182
- vi_text = translate_en2vi(transcription)
183
- state_vi += vi_text + "+"
184
- return state_en, state_vi
185
-
186
  def transcribe_vi_1(audio, state_en=""):
187
  ds = speech_file_to_array_fn(audio.name)
188
  # infer model
@@ -200,69 +108,23 @@ def transcribe_vi_1(audio, state_en=""):
200
  state_en += en_text + " "
201
  return state_en, state_en
202
 
203
- def transcribe_en_1(audio, state_vi=""):
204
- speech = load_data(audio)
205
- # Tokenize
206
- input_values = eng_tokenizer(speech, return_tensors="pt").input_values
207
- # Take logits
208
- logits = eng_model(input_values).logits
209
- # Take argmax
210
- predicted_ids = torch.argmax(logits, dim=-1)
211
- # Get the words from predicted word ids
212
- transcription = eng_tokenizer.decode(predicted_ids[0])
213
- # Output is all upper case
214
- transcription = correct_casing(transcription.lower())
215
- vi_text = translate_en2vi(transcription)
216
- state_vi += vi_text + "+"
217
- return state_vi, state_vi
218
-
219
  """Gradio demo"""
220
-
221
  vi_example_text = ["Có phải bạn đang muốn tìm mua nhà ở ngoại ô thành phố Hồ Chí Minh không?",
222
  "Ánh mắt ta chạm nhau. Chỉ muốn ngắm anh lâu thật lâu.",
223
  "Nếu như một câu nói có thể khiến em vui."]
224
  vi_example_voice =[['vi_speech_01.wav'], ['vi_speech_02.wav'], ['vi_speech_03.wav']]
225
 
226
- en_example_text = ["According to a study by Statista, the global AI market is set to grow up to 54 percent every single year.",
227
- "As one of the world's greatest cities, Air New Zealand is proud to add the Big Apple to its list of 29 international destinations.",
228
- "And yet, earlier this month, I found myself at Halloween Horror Nights at Universal Orlando Resort, one of the most popular Halloween events in the US among hardcore horror buffs."
229
- ]
230
- en_example_voice =[['en_speech_01.wav'], ['en_speech_02.wav'], ['en_speech_03.wav']]
231
-
232
-
233
- with gr.Blocks() as demo:
234
- with gr.Tabs():
235
- with gr.TabItem("Vi-En Realtime Translation"):
236
- gr.Interface(
237
- fn=transcribe_vi_1,
238
- inputs=[
239
- gr.Audio(source="microphone", label="Input Vietnamese Audio", type="file", streaming=True),
240
- "state",
241
- ],
242
- outputs= [
243
- "text",
244
- "state",
245
-
246
- ],
247
- examples=vi_example_voice,
248
- live=True).launch()
249
-
250
-
251
- with gr.Tabs():
252
- with gr.TabItem("En-Vi Realtime Translation"):
253
- gr.Interface(
254
- fn=transcribe_en_1,
255
- inputs=[
256
- gr.Audio(source="microphone", label="Input English Audio", type="filepath", streaming=True),
257
- "state",
258
- ],
259
- outputs= [
260
- "text",
261
- "state",
262
-
263
- ],
264
- examples=en_example_voice,
265
- live=True).launch()
266
-
267
- if __name__ == "__main__":
268
- demo.launch()
 
77
  return beam_search_output
78
 
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  """Machine translation"""
81
  vien_model_checkpoint = "datnth1709/finetuned_HelsinkiNLP-opus-mt-vi-en_PhoMT"
 
82
  vien_translator = pipeline("translation", model=vien_model_checkpoint)
 
 
83
 
84
  def translate_vi2en(Vietnamese):
85
  return vien_translator(Vietnamese)[0]['translation_text']
86
 
 
 
 
 
 
87
 
88
  """ Inference"""
89
  def inference_vien(audio):
 
91
  en_text = translate_vi2en(vi_text)
92
  return vi_text, en_text
93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  def transcribe_vi_1(audio, state_en=""):
95
  ds = speech_file_to_array_fn(audio.name)
96
  # infer model
 
108
  state_en += en_text + " "
109
  return state_en, state_en
110
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  """Gradio demo"""
 
112
  vi_example_text = ["Có phải bạn đang muốn tìm mua nhà ở ngoại ô thành phố Hồ Chí Minh không?",
113
  "Ánh mắt ta chạm nhau. Chỉ muốn ngắm anh lâu thật lâu.",
114
  "Nếu như một câu nói có thể khiến em vui."]
115
  vi_example_voice =[['vi_speech_01.wav'], ['vi_speech_02.wav'], ['vi_speech_03.wav']]
116
 
117
+ with gr.TabItem("Vi-En Realtime Translation"):
118
+ gr.Interface(
119
+ fn=transcribe_vi_1,
120
+ inputs=[
121
+ gr.Audio(source="microphone", label="Input Vietnamese Audio", type="file", streaming=True),
122
+ "state",
123
+ ],
124
+ outputs= [
125
+ "text",
126
+ "state",
127
+
128
+ ],
129
+ examples=vi_example_voice,
130
+ live=True).launch()