datnth1709 committed on
Commit
468a13d
1 Parent(s): 1b245f0

update source

.gitignore ADDED
@@ -0,0 +1,7 @@
+ # Ignore everything in this directory
+ __pycache__
+ .idea
+ .git
+ .vs
+ .vscode
+ .ipynb_checkpoints
README.md CHANGED
@@ -1,12 +1,13 @@
  ---
- title: Realtime Translation
- emoji: 📚
- colorFrom: blue
+ title: FantasticFour S2T MT Demo
+ emoji: 🐠
+ colorFrom: red
  colorTo: gray
  sdk: gradio
  sdk_version: 3.3.1
  app_file: app.py
  pinned: false
+ license: apache-2.0
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,278 @@
+ import gradio as gr
+ import nltk
+ import librosa
+ from optimum.onnxruntime import ORTModelForSeq2SeqLM
+
+ from transformers import pipeline, TranslationPipeline, AutoTokenizer
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, Wav2Vec2Tokenizer
+ from transformers.file_utils import cached_path, hf_bucket_url
+ import os, zipfile
+ from datasets import load_dataset
+ import torch
+ import kenlm
+ import torchaudio
+ from pyctcdecode import Alphabet, BeamSearchDecoderCTC, LanguageModel
+ device = torch.device(0 if torch.cuda.is_available() else "cpu")
+
+ """Vietnamese speech2text"""
+ cache_dir = './cache/'
+ processor = Wav2Vec2Processor.from_pretrained("nguyenvulebinh/wav2vec2-base-vietnamese-250h", cache_dir=cache_dir)
+ vi_model = Wav2Vec2ForCTC.from_pretrained("nguyenvulebinh/wav2vec2-base-vietnamese-250h", cache_dir=cache_dir)
+ lm_file = hf_bucket_url("nguyenvulebinh/wav2vec2-base-vietnamese-250h", filename='vi_lm_4grams.bin.zip')
+ lm_file = cached_path(lm_file, cache_dir=cache_dir)
+ with zipfile.ZipFile(lm_file, 'r') as zip_ref:
+     zip_ref.extractall(cache_dir)
+ lm_file = cache_dir + 'vi_lm_4grams.bin'
+
27
+ def get_decoder_ngram_model(tokenizer, ngram_lm_path):
+     vocab_dict = tokenizer.get_vocab()
+     sort_vocab = sorted((value, key) for (key, value) in vocab_dict.items())
+     vocab = [x[1] for x in sort_vocab][:-2]
+     vocab_list = vocab
+     # convert ctc blank character representation
+     vocab_list[tokenizer.pad_token_id] = ""
+     # replace special characters
+     vocab_list[tokenizer.unk_token_id] = ""
+     # vocab_list[tokenizer.bos_token_id] = ""
+     # vocab_list[tokenizer.eos_token_id] = ""
+     # convert space character representation
+     vocab_list[tokenizer.word_delimiter_token_id] = " "
+     # specify ctc blank char index, since conventionally it is the last entry of the logit matrix
+     alphabet = Alphabet.build_alphabet(vocab_list, ctc_token_idx=tokenizer.pad_token_id)
+     lm_model = kenlm.Model(ngram_lm_path)
+     decoder = BeamSearchDecoderCTC(alphabet,
+                                    language_model=LanguageModel(lm_model))
+     return decoder
+ ngram_lm_model = get_decoder_ngram_model(processor.tokenizer, lm_file)
47
+
+ # define function to read in sound file
+ def speech_file_to_array_fn(path, max_seconds=10):
+     batch = {"file": path}
+     speech_array, sampling_rate = torchaudio.load(batch["file"])
+     if sampling_rate != 16000:
+         transform = torchaudio.transforms.Resample(orig_freq=sampling_rate,
+                                                    new_freq=16000)
+         speech_array = transform(speech_array)
+     speech_array = speech_array[0]
+     if max_seconds > 0:
+         speech_array = speech_array[:max_seconds * 16000]
+     batch["speech"] = speech_array.numpy()
+     batch["sampling_rate"] = 16000
+     return batch
62
+
+ # tokenize
+ def speech2text_vi(audio):
+     # read in sound file
+     # load dummy dataset and read soundfiles
+     ds = speech_file_to_array_fn(audio.name)
+     # infer model
+     input_values = processor(
+         ds["speech"],
+         sampling_rate=ds["sampling_rate"],
+         return_tensors="pt"
+     ).input_values
+     # decode ctc output
+     logits = vi_model(input_values).logits[0]
+     pred_ids = torch.argmax(logits, dim=-1)
+     greedy_search_output = processor.decode(pred_ids)
+     beam_search_output = ngram_lm_model.decode(logits.cpu().detach().numpy(), beam_width=500)
+     return beam_search_output
+
81
+
+ """English speech2text"""
+ nltk.download("punkt")
+ # Loading the model and the tokenizer
+ model_name = "facebook/wav2vec2-base-960h"
+ eng_tokenizer = Wav2Vec2Tokenizer.from_pretrained(model_name)
+ eng_model = Wav2Vec2ForCTC.from_pretrained(model_name)
+
+ def load_data(input_file):
+     """ Function for resampling to ensure that the speech input is sampled at 16 kHz.
+     """
+     # read the file
+     speech, sample_rate = librosa.load(input_file)
+     # make it 1-D
+     if len(speech.shape) > 1:
+         speech = speech[:, 0] + speech[:, 1]
+     # Resampling at 16 kHz since wav2vec2-base-960h is pretrained and fine-tuned on speech audio sampled at 16 kHz.
+     if sample_rate != 16000:
+         speech = librosa.resample(speech, orig_sr=sample_rate, target_sr=16000)
+     return speech
101
+
+ def correct_casing(input_sentence):
+     """ This function is for correcting the casing of the generated transcribed text
+     """
+     sentences = nltk.sent_tokenize(input_sentence)
+     return ' '.join([s.replace(s[0], s[0].capitalize(), 1) for s in sentences])
+
+
+ def speech2text_en(input_file):
+     """This function generates transcripts for the provided audio input
+     """
+     speech = load_data(input_file)
+     # Tokenize
+     input_values = eng_tokenizer(speech, return_tensors="pt").input_values
+     # Take logits
+     logits = eng_model(input_values).logits
+     # Take argmax
+     predicted_ids = torch.argmax(logits, dim=-1)
+     # Get the words from predicted word ids
+     transcription = eng_tokenizer.decode(predicted_ids[0])
+     # Output is all upper case
+     transcription = correct_casing(transcription.lower())
+     return transcription
124
+
+
+ """Machine translation"""
+ vien_model_checkpoint = "datnth1709/finetuned_HelsinkiNLP-opus-mt-vi-en_PhoMT"
+ envi_model_checkpoint = "datnth1709/finetuned_HelsinkiNLP-opus-mt-en-vi_PhoMT"
+ # vien_translator = pipeline("translation", model=vien_model_checkpoint)
+ # envi_translator = pipeline("translation", model=envi_model_checkpoint)
+
+ vien_tokenizer = AutoTokenizer.from_pretrained(vien_model_checkpoint, return_tensors="pt")
+ vien_model = ORTModelForSeq2SeqLM.from_pretrained(vien_model_checkpoint)
+ vien_translator = TranslationPipeline(model=vien_model, tokenizer=vien_tokenizer, clean_up_tokenization_spaces=True, device=device)
+
+ envi_tokenizer = AutoTokenizer.from_pretrained(envi_model_checkpoint, return_tensors="pt")
+ envi_model = ORTModelForSeq2SeqLM.from_pretrained(envi_model_checkpoint)
+ envi_translator = TranslationPipeline(model=envi_model, tokenizer=envi_tokenizer, clean_up_tokenization_spaces=True, device=device)
+
+
+ def translate_vi2en(Vietnamese):
+     return vien_translator(Vietnamese)[0]['translation_text']
+
+ def translate_en2vi(English):
+     return envi_translator(English)[0]['translation_text']
146
+
+
+
+
+ """ Inference"""
+ def inference_vien(audio):
+     vi_text = speech2text_vi(audio)
+     en_text = translate_vi2en(vi_text)
+     return vi_text, en_text
+
+ def inference_envi(audio):
+     en_text = speech2text_en(audio)
+     vi_text = translate_en2vi(en_text)
+     return en_text, vi_text
+
161
+ def transcribe_vi(audio, state_vi="", state_en=""):
+     ds = speech_file_to_array_fn(audio.name)
+     # infer model
+     input_values = processor(
+         ds["speech"],
+         sampling_rate=ds["sampling_rate"],
+         return_tensors="pt"
+     ).input_values
+     # decode ctc output
+     logits = vi_model(input_values).logits[0]
+     pred_ids = torch.argmax(logits, dim=-1)
+     greedy_search_output = processor.decode(pred_ids)
+     beam_search_output = ngram_lm_model.decode(logits.cpu().detach().numpy(), beam_width=500)
+     state_vi += beam_search_output + " "
+     en_text = translate_vi2en(beam_search_output)
+     state_en += en_text + " "
+     return state_vi, state_en
+
+ def transcribe_en(audio, state_en="", state_vi=""):
+     speech = load_data(audio)
+     # Tokenize
+     input_values = eng_tokenizer(speech, return_tensors="pt").input_values
+     # Take logits
+     logits = eng_model(input_values).logits
+     # Take argmax
+     predicted_ids = torch.argmax(logits, dim=-1)
+     # Get the words from predicted word ids
+     transcription = eng_tokenizer.decode(predicted_ids[0])
+     # Output is all upper case
+     transcription = correct_casing(transcription.lower())
+     state_en += transcription + "+"
+     vi_text = translate_en2vi(transcription)
+     state_vi += vi_text + "+"
+     return state_en, state_vi
+
196
+ def transcribe_vi_1(audio, state_en=""):
+     ds = speech_file_to_array_fn(audio.name)
+     # infer model
+     input_values = processor(
+         ds["speech"],
+         sampling_rate=ds["sampling_rate"],
+         return_tensors="pt"
+     ).input_values
+     # decode ctc output
+     logits = vi_model(input_values).logits[0]
+     pred_ids = torch.argmax(logits, dim=-1)
+     greedy_search_output = processor.decode(pred_ids)
+     beam_search_output = ngram_lm_model.decode(logits.cpu().detach().numpy(), beam_width=500)
+     en_text = translate_vi2en(beam_search_output)
+     state_en += en_text + " "
+     return state_en, state_en
+
+ def transcribe_en_1(audio, state_vi=""):
+     speech = load_data(audio)
+     # Tokenize
+     input_values = eng_tokenizer(speech, return_tensors="pt").input_values
+     # Take logits
+     logits = eng_model(input_values).logits
+     # Take argmax
+     predicted_ids = torch.argmax(logits, dim=-1)
+     # Get the words from predicted word ids
+     transcription = eng_tokenizer.decode(predicted_ids[0])
+     # Output is all upper case
+     transcription = correct_casing(transcription.lower())
+     vi_text = translate_en2vi(transcription)
+     state_vi += vi_text + "+"
+     return state_vi, state_vi
228
+
+ """Gradio demo"""
+
+ vi_example_text = ["Có phải bạn đang muốn tìm mua nhà ở ngoại ô thành phố Hồ Chí Minh không?",
+                    "Ánh mắt ta chạm nhau. Chỉ muốn ngắm anh lâu thật lâu.",
+                    "Nếu như một câu nói có thể khiến em vui."]
+ vi_example_voice = [['vi_speech_01.wav'], ['vi_speech_02.wav'], ['vi_speech_03.wav']]
+
+ en_example_text = ["According to a study by Statista, the global AI market is set to grow up to 54 percent every single year.",
+                    "As one of the world's greatest cities, Air New Zealand is proud to add the Big Apple to its list of 29 international destinations.",
+                    "And yet, earlier this month, I found myself at Halloween Horror Nights at Universal Orlando Resort, one of the most popular Halloween events in the US among hardcore horror buffs."
+                    ]
+ en_example_voice = [['en_speech_01.wav'], ['en_speech_02.wav'], ['en_speech_03.wav']]
+
+
243
+ with gr.Blocks() as demo:
+     with gr.Tabs():
+         with gr.TabItem("Vi-En Realtime Translation"):
+             gr.Interface(
+                 fn=transcribe_vi_1,
+                 inputs=[
+                     gr.Audio(source="microphone", label="Input Vietnamese Audio", type="file", streaming=True),
+                     "state",
+                 ],
+                 outputs=[
+                     "text",
+                     "state",
+
+                 ],
+                 examples=vi_example_voice,
+                 live=True).launch()
+
+
+     with gr.Tabs():
+         with gr.TabItem("En-Vi Realtime Translation"):
+             gr.Interface(
+                 fn=transcribe_en_1,
+                 inputs=[
+                     gr.Audio(source="microphone", label="Input English Audio", type="filepath", streaming=True),
+                     "state",
+                 ],
+                 outputs=[
+                     "text",
+                     "state",
+
+                 ],
+                 examples=en_example_voice,
+                 live=True).launch()
+
+ if __name__ == "__main__":
+     demo.launch()
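The app.py added above wires the speech and translation models into the Gradio UI, but its helper functions can also be exercised directly for a quick local sanity check. A minimal sketch, assuming the packages in requirements.txt plus optimum[onnxruntime] are installed and the checkpoints download successfully; en_speech_01.wav is one of the sample files added in this commit:

import app  # importing app.py loads all models, which can take several minutes

# Text-only round trip through the two translation pipelines defined in app.py
print(app.translate_vi2en("Xin chào, bạn khỏe không?"))
print(app.translate_en2vi("Hello, how are you?"))

# English speech -> English transcript -> Vietnamese translation
en_text, vi_text = app.inference_envi("en_speech_01.wav")
print(en_text)
print(vi_text)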
en_speech_01.wav ADDED
Binary file (816 kB).
 
en_speech_02.wav ADDED
Binary file (238 kB).
 
en_speech_03.wav ADDED
Binary file (751 kB).
 
packages.txt ADDED
@@ -0,0 +1 @@
+ libsndfile1
requirements.txt ADDED
@@ -0,0 +1,16 @@
+ torch==1.9.0
+ torchaudio==0.9.0
+ transformers==4.9.2
+ transformers[sentencepiece]
+ datasets==1.11.0
+ pyctcdecode==v0.1.0
+ speechbrain
+ pydub
+ kenlm
+ pyctcdecode
+ soundfile
+ ffmpeg-python
+ gradio
+ nltk
+ librosa
+ https://github.com/kpu/kenlm/archive/master.zip
vi_speech_01.wav ADDED
Binary file (120 kB).
 
vi_speech_02.wav ADDED
Binary file (49.6 kB).
 
vi_speech_03.wav ADDED
Binary file (76.8 kB).