MorenoLQ commited on
Commit
7fce27b
1 Parent(s): 856bef6

Updated for file upload and missing inputs

Browse files
Files changed (3) hide show
  1. app.py +66 -45
  2. demo_example_1.mp3 +0 -0
  3. gradio_queue.db +0 -0
app.py CHANGED
@@ -25,61 +25,78 @@ DICT_MODELS = {
25
  MODELS = sorted(DICT_MODELS.keys())
26
  CACHED_MODELS_BY_ID = {}
27
 
28
- def run(input_file, model_name, decoding_type, history):
29
-
30
- logger.info(f"Running ASR {model_name}-{decoding_type} for {input_file}")
31
-
32
- history = history or []
 
 
 
 
 
 
 
 
33
 
 
 
34
  model = DICT_MODELS.get(model_name)
35
-
36
- if model is None:
37
- history.append({
38
- "error_message": f"Model size {model_size} not found for {language} language :("
39
- })
40
- elif decoding_type == "Guided by Language Model" and not model["has_lm"]:
41
  history.append({
42
- "error_message": f"LM not available for {language} language :("
 
 
 
43
  })
44
  else:
45
 
46
- # model_instance = AutoModelForCTC.from_pretrained(model["model_id"])
47
- model_instance = CACHED_MODELS_BY_ID.get(model["model_id"], None)
48
- if model_instance is None:
49
- model_instance = AutoModelForCTC.from_pretrained(model["model_id"])
50
- CACHED_MODELS_BY_ID[model["model_id"]] = model_instance
51
 
52
- if decoding_type == "Guided by Language Model":
53
- processor = Wav2Vec2ProcessorWithLM.from_pretrained(model["model_id"])
54
- asr = pipeline("automatic-speech-recognition", model=model_instance, tokenizer=processor.tokenizer,
55
- feature_extractor=processor.feature_extractor, decoder=processor.decoder)
 
 
 
 
 
 
 
 
56
  else:
57
- processor = Wav2Vec2Processor.from_pretrained(model["model_id"])
58
- asr = pipeline("automatic-speech-recognition", model=model_instance, tokenizer=processor.tokenizer,
59
- feature_extractor=processor.feature_extractor, decoder=None)
60
 
61
- transcription = asr(input_file, chunk_length_s=5, stride_length_s=1)["text"]
 
 
 
 
62
 
63
- logger.info(f"Transcription for {input_file}: {transcription}")
 
 
 
 
 
 
 
64
 
65
- history.append({
66
- "model_id": model["model_id"],
67
- "decoding_type": decoding_type,
68
- "transcription": transcription,
69
- "error_message": None
70
- })
71
 
72
- html_output = "<div class='result'>"
73
- for item in history:
74
- if item["error_message"] is not None:
75
- html_output += f"<div class='result_item result_item_error'>{item['error_message']}</div>"
76
- else:
77
- url_suffix = " + Guided by Language Model" if item["decoding_type"] == "Guided by Language Model" else ""
78
- html_output += "<div class='result_item result_item_success'>"
79
- html_output += f'<strong><a target="_blank" href="https://huggingface.co/{item["model_id"]}">{item["model_id"]}{url_suffix}</a></strong><br/><br/>'
80
- html_output += f'{item["transcription"]}<br/>'
81
- html_output += "</div>"
82
- html_output += "</div>"
83
 
84
  return html_output, history
85
 
@@ -87,7 +104,8 @@ def run(input_file, model_name, decoding_type, history):
87
  gr.Interface(
88
  run,
89
  inputs=[
90
- gr.inputs.Audio(source="microphone", type="filepath", label="Record something..."),
 
91
  gr.inputs.Radio(label="Model", choices=MODELS),
92
  gr.inputs.Radio(label="Decoding type", choices=["Standard", "Guided by Language Model"]),
93
  "state"
@@ -106,5 +124,8 @@ gr.Interface(
106
  """,
107
  allow_screenshot=False,
108
  allow_flagging="never",
109
- theme="huggingface"
 
 
 
110
  ).launch(enable_queue=True)
 
25
  MODELS = sorted(DICT_MODELS.keys())
26
  CACHED_MODELS_BY_ID = {}
27
 
28
+ def build_html(history):
29
+ html_output = "<div class='result'>"
30
+ for item in history:
31
+ if item["error_message"] is not None:
32
+ html_output += f"<div class='result_item result_item_error'>{item['error_message']}</div>"
33
+ else:
34
+ url_suffix = " + Guided by Language Model" if item["decoding_type"] == "Guided by Language Model" else ""
35
+ html_output += "<div class='result_item result_item_success'>"
36
+ html_output += f'<strong><a target="_blank" href="https://huggingface.co/{item["model_id"]}">{item["model_id"]}{url_suffix}</a></strong><br/><br/>'
37
+ html_output += f'{item["transcription"]}<br/>'
38
+ html_output += "</div>"
39
+ html_output += "</div>"
40
+ return html_output
41
 
42
+ def run(uploaded_file, input_file, model_name, decoding_type, history):
43
+
44
  model = DICT_MODELS.get(model_name)
45
+ history = history or []
46
+
47
+ if uploaded_file is None and input_file is None:
 
 
 
48
  history.append({
49
+ "model_id": model["model_id"],
50
+ "decoding_type": decoding_type,
51
+ "transcription": "",
52
+ "error_message": "No input provided."
53
  })
54
  else:
55
 
56
+ if input_file is None:
57
+ input_file = uploaded_file
 
 
 
58
 
59
+ logger.info(f"Running ASR {model_name}-{decoding_type} for {input_file}")
60
+
61
+ history = history or []
62
+
63
+ if model is None:
64
+ history.append({
65
+ "error_message": f"Model size {model_size} not found for {language} language :("
66
+ })
67
+ elif decoding_type == "Guided by Language Model" and not model["has_lm"]:
68
+ history.append({
69
+ "error_message": f"LM not available for {language} language :("
70
+ })
71
  else:
 
 
 
72
 
73
+ # model_instance = AutoModelForCTC.from_pretrained(model["model_id"])
74
+ model_instance = CACHED_MODELS_BY_ID.get(model["model_id"], None)
75
+ if model_instance is None:
76
+ model_instance = AutoModelForCTC.from_pretrained(model["model_id"])
77
+ CACHED_MODELS_BY_ID[model["model_id"]] = model_instance
78
 
79
+ if decoding_type == "Guided by Language Model":
80
+ processor = Wav2Vec2ProcessorWithLM.from_pretrained(model["model_id"])
81
+ asr = pipeline("automatic-speech-recognition", model=model_instance, tokenizer=processor.tokenizer,
82
+ feature_extractor=processor.feature_extractor, decoder=processor.decoder)
83
+ else:
84
+ processor = Wav2Vec2Processor.from_pretrained(model["model_id"])
85
+ asr = pipeline("automatic-speech-recognition", model=model_instance, tokenizer=processor.tokenizer,
86
+ feature_extractor=processor.feature_extractor, decoder=None)
87
 
88
+ transcription = asr(input_file, chunk_length_s=5, stride_length_s=1)["text"]
 
 
 
 
 
89
 
90
+ logger.info(f"Transcription for {input_file}: {transcription}")
91
+
92
+ history.append({
93
+ "model_id": model["model_id"],
94
+ "decoding_type": decoding_type,
95
+ "transcription": transcription,
96
+ "error_message": None
97
+ })
98
+
99
+ html_output = build_html(history)
 
100
 
101
  return html_output, history
102
 
 
104
  gr.Interface(
105
  run,
106
  inputs=[
107
+ gr.inputs.Audio(source="upload", type='filepath', optional=True),
108
+ gr.inputs.Audio(source="microphone", type="filepath", label="Record something...", optional=True),
109
  gr.inputs.Radio(label="Model", choices=MODELS),
110
  gr.inputs.Radio(label="Decoding type", choices=["Standard", "Guided by Language Model"]),
111
  "state"
 
124
  """,
125
  allow_screenshot=False,
126
  allow_flagging="never",
127
+ theme="huggingface",
128
+ examples = [
129
+ ['demo_example_1.mp3', 'demo_example_1.mp3', 'robust-300m', 'Guided by Language Model']
130
+ ]
131
  ).launch(enable_queue=True)
demo_example_1.mp3 ADDED
Binary file (121 kB). View file
 
gradio_queue.db CHANGED
Binary files a/gradio_queue.db and b/gradio_queue.db differ