aka7774 committed
Commit 7935c8d
1 Parent(s): 3faeacc

Upload 3 files

Files changed (3)
  1. app.py +20 -6
  2. fn.py +10 -2
  3. main.py +9 -0
app.py CHANGED
@@ -3,12 +3,26 @@ import gradio as gr
 
 fn.load_model()
 
-demo = gr.Interface(
-    fn=fn.speech_to_text,
-    inputs=[
-        gr.Audio(sources="upload", type="filepath"),
-    ],
-    outputs=["text", "text"])
+with gr.Blocks() as demo:
+    audio = gr.Audio(sources="upload", type="filepath")
+    model = gr.Dropdown(value='large-v3', choices=["tiny", "base", "small", "medium", "large", "large-v2", "large-v3"])
+    run_button = gr.Button(value='Run')
+    prompt = gr.Textbox(label='prompt')
+    set_button = gr.Button(value='Set Prompt')
+    text_only = gr.Textbox(label='output')
+    text_with_timestamps = gr.Textbox(label='timestamps')
+
+    run_button.click(
+        fn=fn.speech_to_text,
+        inputs=[audio, model],
+        outputs=[text_only, text_with_timestamps],
+    )
+
+    set_button.click(
+        fn=fn.set_prompt,
+        inputs=[prompt],
+        outputs=[],
+    )
 
 if __name__ == '__main__':
     demo.launch()
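The single gr.Interface becomes a Blocks layout with two explicit event handlers: Run sends the uploaded audio and the model choice to fn.speech_to_text, and Set Prompt forwards the textbox value to fn.set_prompt. A minimal sketch (not part of the commit) of the same calls without launching the UI, assuming fn.py is importable and using a hypothetical sample.wav:

import fn

fn.load_model()

# What the "Set Prompt" button triggers; the prompt text is an example.
fn.set_prompt("Proper nouns: Gradio, Whisper.")

# What the "Run" button triggers: returns the plain transcription and a
# timestamped variant, matching the two output textboxes in app.py.
text_only, text_with_timestamps = fn.speech_to_text("sample.wav", "large-v3")
print(text_only)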
fn.py CHANGED
@@ -10,6 +10,7 @@ device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
 model = None
 pipe = None
+initial_prompt = None
 
 def load_model():
     global model, pipe
@@ -28,14 +29,21 @@ def load_model():
         device=device,
     )
 
+def set_prompt(prompt):
+    global initial_prompt
+    initial_prompt = prompt
+
 def speech_to_text(audio_file, _model_size = None):
-    global model, pipe
+    global model, pipe, initial_prompt
 
     if not model:
         load_model()
 
     # run inference
-    result = pipe(audio_file)
+    generate_kwargs = {}
+    if initial_prompt:
+        generate_kwargs['prompt_ids'] = pipe.tokenizer.get_prompt_ids(initial_prompt, return_tensors="pt").to(device)
+    result = pipe(audio_file, generate_kwargs=generate_kwargs)
 
     try:
         res = json.dumps(result)
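The new initial_prompt plumbing relies on WhisperTokenizer.get_prompt_ids, which prepends the <|startofprev|> token and tokenizes the text so the decoder treats it as prior context, biasing the transcription toward the prompt's spelling and vocabulary. A minimal sketch (not part of the commit) of what that call produces, assuming the openai/whisper-large-v3 checkpoint:

from transformers import WhisperTokenizer

tok = WhisperTokenizer.from_pretrained("openai/whisper-large-v3")

# get_prompt_ids prepends <|startofprev|> and tokenizes the text; with
# return_tensors="pt" it yields a torch tensor of token ids.
prompt_ids = tok.get_prompt_ids("Gradio, Whisper", return_tensors="pt")
print(prompt_ids)

# fn.speech_to_text moves these ids to the model's device and passes them
# via pipe(audio_file, generate_kwargs={"prompt_ids": prompt_ids}).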
main.py CHANGED
@@ -40,3 +40,12 @@ async def transcribe_audio(file: UploadFile = Form(...)):
         return {"transcription": text_only, "text_with_timestamps": text_with_timestamps}
     except Exception as e:
         return {"error": str(e)}
+
+@app.post("/set_prompt")
+async def set_prompt(prompt: str):
+    try:
+        fn.set_prompt(prompt)
+
+        return {"status": 0}
+    except Exception as e:
+        return {"error": str(e)}