rayl-aoit commited on
Commit
d0aa8cd
·
verified ·
1 Parent(s): 219e01a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -8
app.py CHANGED
@@ -27,7 +27,7 @@ tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
27
  tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
28
 
29
  # Function to convert audio to text using ASR
30
- def transcribe(audio_filepath):
31
  if audio_filepath is None:
32
  raise gr.Error("Please provide some input audio.")
33
 
@@ -44,9 +44,9 @@ def transcribe(audio_filepath):
44
  duration = len(data) / SAMPLE_RATE
45
  manifest_data = {
46
  "audio_filepath": converted_audio_filepath,
47
- "taskname": "asr",
48
  "source_lang": "en",
49
- "target_lang": "en",
50
  "pnc": "no",
51
  "answer": "predict",
52
  "duration": str(duration),
@@ -56,9 +56,9 @@ def transcribe(audio_filepath):
56
  fout.write(json.dumps(manifest_data))
57
 
58
  if duration < 40:
59
- transcription = canary_model.transcribe(manifest_filepath)[0]
60
  else:
61
- transcription = get_buffered_pred_feat_multitaskAED(
62
  frame_asr,
63
  canary_model.cfg.preprocessor,
64
  model_stride_in_secs,
@@ -66,7 +66,7 @@ def transcribe(audio_filepath):
66
  manifest=manifest_filepath,
67
  )[0].text
68
 
69
- return transcription
70
 
71
  # Function to convert text to speech using TTS
72
  def gen_speech(text):
@@ -81,9 +81,10 @@ def gen_speech(text):
81
 
82
  # Root function for Gradio interface
83
  def start_process(audio_filepath):
84
- transcription = transcribe(audio_filepath)
85
  print("Done transcribing")
86
- translation = "working in progress"
 
87
  audio_output_filepath = gen_speech(transcription)
88
  print("Done speaking")
89
  return transcription, translation, audio_output_filepath
 
27
  tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
28
 
29
  # Function to convert audio to text using ASR
30
+ def gen_text(audio_filepath, action):
31
  if audio_filepath is None:
32
  raise gr.Error("Please provide some input audio.")
33
 
 
44
  duration = len(data) / SAMPLE_RATE
45
  manifest_data = {
46
  "audio_filepath": converted_audio_filepath,
47
+ "taskname": action,
48
  "source_lang": "en",
49
+ "target_lang": "en" if action=="asr" else "fr",
50
  "pnc": "no",
51
  "answer": "predict",
52
  "duration": str(duration),
 
56
  fout.write(json.dumps(manifest_data))
57
 
58
  if duration < 40:
59
+ predicted_text = canary_model.transcribe(manifest_filepath)[0]
60
  else:
61
+ predicted_text = get_buffered_pred_feat_multitaskAED(
62
  frame_asr,
63
  canary_model.cfg.preprocessor,
64
  model_stride_in_secs,
 
66
  manifest=manifest_filepath,
67
  )[0].text
68
 
69
+ return predicted_text
70
 
71
  # Function to convert text to speech using TTS
72
  def gen_speech(text):
 
81
 
82
  # Root function for Gradio interface
83
  def start_process(audio_filepath):
84
+ transcription = gen_text(audio_filepath, "asr")
85
  print("Done transcribing")
86
+ translation = gen_text(audio_filepath, "ast")
87
+ print("Done translation")
88
  audio_output_filepath = gen_speech(transcription)
89
  print("Done speaking")
90
  return transcription, translation, audio_output_filepath