yuangongfdu committed on
Commit
73b127f
1 Parent(s): a24b835

Update app.py

Files changed (1)
  1. app.py +13 -6
app.py CHANGED
@@ -6,8 +6,12 @@ text = "[Github]"
 paper_link = "https://arxiv.org/pdf/2307.03183.pdf"
 paper_text = "[Paper]"
 
-model = whisper.load_model('large-v1')
-print('model loaded')
+model_large = whisper.load_model("large-v1")
+model_tiny = whisper.load_model("tiny")
+model_tiny_en = whisper.load_model("tiny.en")
+model_small = whisper.load_model("small")
+
+mdl_dict = {"tiny": model_tiny, "tiny.en": model_tiny_en, "small": model_small, "large": model_large}
 
 def round_time_resolution(time_resolution):
     multiple = float(time_resolution) / 0.4
@@ -15,7 +19,7 @@ def round_time_resolution(time_resolution):
     rounded_time_resolution = rounded_multiple * 0.4
     return rounded_time_resolution
 
-def predict(audio_path_m, audio_path_t, time_resolution):
+def predict(audio_path_m, audio_path_t, model_size, time_resolution):
     # print(audio_path_m, audio_path_t)
     # print(type(audio_path_m), type(audio_path_t))
     #return audio_path_m, audio_path_t
@@ -24,6 +28,7 @@ def predict(audio_path_m, audio_path_t, time_resolution):
     else:
         audio_path = audio_path_m or audio_path_t
     audio_tagging_time_resolution = round_time_resolution(time_resolution)
+    model = mdl_dict[model_size]
     result = model.transcribe(audio_path, at_time_res=audio_tagging_time_resolution)
     audio_tag_result = whisper.parse_at_label(result, language='follow_asr', top_k=5, p_threshold=-1, include_class_list=list(range(527)))
     asr_output = ""
@@ -32,15 +37,17 @@ def predict(audio_path_m, audio_path_t, time_resolution):
     at_output = ""
     for segment in audio_tag_result:
         print(segment)
-        at_output = at_output + format(segment['time']['start'], ".1f") + 's-' + format(segment['time']['end'], ".1f") + 's: ' + ','.join([x[0] for x in segment['audio tags']]) + '\n'
+        at_output = at_output + format(segment['time']['start'], ".1f") + 's-' + format(segment['time']['end'], ".1f") + 's: ' + ', '.join([x[0] for x in segment['audio tags']]) + '\n'
     print(at_output)
     return asr_output, at_output
 
 iface = gr.Interface(fn=predict,
-                     inputs=[gr.Audio(type="filepath", source='microphone', label='Please either upload an audio file or record using the microphone.', show_label=True), gr.Audio(type="filepath"), gr.Textbox(value='10', label='Time Resolution in Seconds (Must be an integer multiple of 0.4, e.g., 0.4, 2, 10)')],
+                     inputs=[gr.Audio(type="filepath", source='microphone', label='Please either upload an audio file or record using the microphone.', show_label=True), gr.Audio(type="filepath"),
+                             gr.Radio(["tiny", "tiny.en", "small", "large"], value='large', label="Model size", info="The larger the model, the better the performance and the slower the speed."),
+                             gr.Textbox(value='10', label='Time Resolution in Seconds (Must be an integer multiple of 0.4, e.g., 0.4, 2, 10)')],
                      outputs=[gr.Textbox(label="Speech Output"), gr.Textbox(label="Audio Tag Output")],
                      cache_examples=True,
                      title="Quick Demo of Whisper-AT",
                      description="We are glad to introduce Whisper-AT - A new joint audio tagging and speech recognition model. It outputs background sound labels in addition to text." + f"<a href='{paper_link}'>{paper_text}</a> " + f"<a href='{link}'>{text}</a> <br>" +
                                  "Whisper-AT is authored by Yuan Gong, Sameer Khurana, Leonid Karlinsky, and James Glass (MIT & MIT-IBM Watson AI Lab). It is an Interspeech 2023 paper.")
-iface.launch(debug=True)
+iface.launch(debug=True, share=True)
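For context, below is a minimal, self-contained sketch of how the pieces added in this commit fit together (per-request model selection via a dict, time-resolution rounding, and tag formatting), assuming the whisper-at package is importable as whisper. The helper name run_once and the reduced model list are illustrative only and are not part of the commit; this is a sketch, not the app itself.

# Sketch only: assumes the whisper-at package (imported as `whisper`).
# `run_once` and the two-entry model list are illustrative, not from the commit.
import whisper

# Load each checkpoint once so the prediction path can pick by name,
# mirroring mdl_dict in the updated app.py.
mdl_dict = {
    "tiny": whisper.load_model("tiny"),
    "large": whisper.load_model("large-v1"),
}

def round_time_resolution(time_resolution):
    # Snap the requested tagging resolution to the nearest multiple of 0.4 s.
    return round(float(time_resolution) / 0.4) * 0.4

def run_once(audio_path, model_size="large", time_resolution="10"):
    model = mdl_dict[model_size]                      # per-request model choice
    at_time_res = round_time_resolution(time_resolution)
    result = model.transcribe(audio_path, at_time_res=at_time_res)
    tags = whisper.parse_at_label(result, language="follow_asr", top_k=5,
                                  p_threshold=-1, include_class_list=list(range(527)))
    # Format each segment as "start s-end s: tag1, tag2, ..." like the app does.
    at_output = "\n".join(
        f"{seg['time']['start']:.1f}s-{seg['time']['end']:.1f}s: "
        + ", ".join(t[0] for t in seg["audio tags"])
        for seg in tags
    )
    return result["text"], at_output

Loading every checkpoint up front, as the commit does, trades startup time and memory for fast switching between model sizes at request time; loading lazily inside run_once would be the opposite trade-off.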