asahi417 commited on
Commit
fc18a2b
1 Parent(s): da4f293

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -5
app.py CHANGED
@@ -34,10 +34,12 @@ pipe = pipeline(
34
 
35
 
36
 
37
- def transcribe(inputs):
38
  if inputs is None:
39
  raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")
40
  generate_kwargs = {"language": "japanese", "task": "transcribe"}
 
 
41
  return pipe(inputs, batch_size=BATCH_SIZE, generate_kwargs=generate_kwargs)["text"]
42
 
43
 
@@ -71,7 +73,7 @@ def download_yt_audio(yt_url, filename):
71
  raise gr.Error(str(err))
72
 
73
 
74
- def yt_transcribe(yt_url, max_filesize=75.0):
75
  html_embed_str = _return_yt_html_embed(yt_url)
76
  with tempfile.TemporaryDirectory() as tmpdirname:
77
  filepath = os.path.join(tmpdirname, "video.mp4")
@@ -81,6 +83,8 @@ def yt_transcribe(yt_url, max_filesize=75.0):
81
  inputs = ffmpeg_read(inputs, pipe.feature_extractor.sampling_rate)
82
  inputs = {"array": inputs, "sampling_rate": pipe.feature_extractor.sampling_rate}
83
  generate_kwargs = {"language": "japanese", "task": "transcribe"}
 
 
84
  text = pipe(inputs, batch_size=BATCH_SIZE, generate_kwargs=generate_kwargs)["text"]
85
  return html_embed_str, text
86
 
@@ -88,7 +92,10 @@ def yt_transcribe(yt_url, max_filesize=75.0):
88
  demo = gr.Blocks()
89
  mf_transcribe = gr.Interface(
90
  fn=transcribe,
91
- inputs=[gr.inputs.Audio(source="microphone", type="filepath", optional=True)],
 
 
 
92
  outputs="text",
93
  layout="horizontal",
94
  theme="huggingface",
@@ -99,7 +106,10 @@ mf_transcribe = gr.Interface(
99
 
100
  file_transcribe = gr.Interface(
101
  fn=transcribe,
102
- inputs=[gr.inputs.Audio(source="upload", type="filepath", optional=True, label="Audio file")],
 
 
 
103
  outputs="text",
104
  layout="horizontal",
105
  theme="huggingface",
@@ -109,7 +119,10 @@ file_transcribe = gr.Interface(
109
  )
110
  yt_transcribe = gr.Interface(
111
  fn=yt_transcribe,
112
- inputs=[gr.inputs.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL")],
 
 
 
113
  outputs=["html", "text"],
114
  layout="horizontal",
115
  theme="huggingface",
 
34
 
35
 
36
 
37
+ def transcribe(inputs, prompt):
38
  if inputs is None:
39
  raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")
40
  generate_kwargs = {"language": "japanese", "task": "transcribe"}
41
+ if prompt:
42
+ generate_kwargs['prompt_ids'] = pipe.tokenizer.get_prompt_ids(prompt, return_tensors='pt').to(device)
43
  return pipe(inputs, batch_size=BATCH_SIZE, generate_kwargs=generate_kwargs)["text"]
44
 
45
 
 
73
  raise gr.Error(str(err))
74
 
75
 
76
+ def yt_transcribe(yt_url, prompt, max_filesize=75.0):
77
  html_embed_str = _return_yt_html_embed(yt_url)
78
  with tempfile.TemporaryDirectory() as tmpdirname:
79
  filepath = os.path.join(tmpdirname, "video.mp4")
 
83
  inputs = ffmpeg_read(inputs, pipe.feature_extractor.sampling_rate)
84
  inputs = {"array": inputs, "sampling_rate": pipe.feature_extractor.sampling_rate}
85
  generate_kwargs = {"language": "japanese", "task": "transcribe"}
86
+ if prompt:
87
+ generate_kwargs['prompt_ids'] = pipe.tokenizer.get_prompt_ids(prompt, return_tensors='pt').to(device)
88
  text = pipe(inputs, batch_size=BATCH_SIZE, generate_kwargs=generate_kwargs)["text"]
89
  return html_embed_str, text
90
 
 
92
  demo = gr.Blocks()
93
  mf_transcribe = gr.Interface(
94
  fn=transcribe,
95
+ inputs=[
96
+ gr.inputs.Audio(source="microphone", type="filepath", optional=True),
97
+ gr.inputs.Textbox(lines=1, placeholder="Prompt", value="")
98
+ ],
99
  outputs="text",
100
  layout="horizontal",
101
  theme="huggingface",
 
106
 
107
  file_transcribe = gr.Interface(
108
  fn=transcribe,
109
+ inputs=[
110
+ gr.inputs.Audio(source="upload", type="filepath", optional=True, label="Audio file"),
111
+ gr.inputs.Textbox(lines=1, placeholder="Prompt", value="")
112
+ ],
113
  outputs="text",
114
  layout="horizontal",
115
  theme="huggingface",
 
119
  )
120
  yt_transcribe = gr.Interface(
121
  fn=yt_transcribe,
122
+ inputs=[
123
+ gr.inputs.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL"),
124
+ gr.inputs.Textbox(lines=1, placeholder="Prompt", value="")
125
+ ],
126
  outputs=["html", "text"],
127
  layout="horizontal",
128
  theme="huggingface",