asahi417 commited on
Commit
c962a1e
1 Parent(s): 0f5d4d0

add punctuator and timestamped output

Browse files
Files changed (2) hide show
  1. app.py +54 -18
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,20 +1,28 @@
 
 
 
1
  import re
 
 
2
 
3
  import torch
4
  import gradio as gr
5
  import yt_dlp as youtube_dl
6
  from transformers import pipeline
7
  from transformers.pipelines.audio_utils import ffmpeg_read
 
8
 
9
- import tempfile
10
- import os
11
 
 
12
  MODEL_NAME = "kotoba-tech/kotoba-whisper-v1.0"
13
  BATCH_SIZE = 16
14
  CHUNK_LENGTH_S = 15
15
  FILE_LIMIT_MB = 1000
16
  YT_LENGTH_LIMIT_S = 3600 # limit to 1 hour YouTube files
 
17
 
 
 
18
  if torch.cuda.is_available():
19
  torch_dtype = torch.bfloat16
20
  device = "cuda:0"
@@ -24,6 +32,7 @@ else:
24
  device = "cpu"
25
  model_kwargs = {}
26
 
 
27
  pipe = pipeline(
28
  task="automatic-speech-recognition",
29
  model=MODEL_NAME,
@@ -35,21 +44,52 @@ pipe = pipeline(
35
  )
36
 
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- def transcribe(inputs, prompt):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  if inputs is None:
41
  raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")
42
- generate_kwargs = {"language": "japanese", "task": "transcribe"}
43
- prompt = "。" if not prompt else prompt
44
- generate_kwargs['prompt_ids'] = pipe.tokenizer.get_prompt_ids(prompt, return_tensors='pt').to(device)
45
- text = pipe(inputs, generate_kwargs=generate_kwargs)['text']
46
- # currently the pipeline for ASR appends the prompt at the beginning of the transcription, so remove it
47
- return re.sub(rf"\A\s*{prompt}\s*", "", text)
48
 
49
  def _return_yt_html_embed(yt_url):
50
  video_id = yt_url.split("?v=")[-1]
51
  return f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{video_id}"> </iframe> </center>'
52
 
 
53
  def download_yt_audio(yt_url, filename):
54
  info_loader = youtube_dl.YoutubeDL()
55
  try:
@@ -76,7 +116,7 @@ def download_yt_audio(yt_url, filename):
76
  raise gr.Error(str(err))
77
 
78
 
79
- def yt_transcribe(yt_url, prompt, max_filesize=75.0):
80
  html_embed_str = _return_yt_html_embed(yt_url)
81
  with tempfile.TemporaryDirectory() as tmpdirname:
82
  filepath = os.path.join(tmpdirname, "video.mp4")
@@ -85,12 +125,8 @@ def yt_transcribe(yt_url, prompt, max_filesize=75.0):
85
  inputs = f.read()
86
  inputs = ffmpeg_read(inputs, pipe.feature_extractor.sampling_rate)
87
  inputs = {"array": inputs, "sampling_rate": pipe.feature_extractor.sampling_rate}
88
- generate_kwargs = {"language": "japanese", "task": "transcribe"}
89
- prompt = "。" if not prompt else prompt
90
- generate_kwargs['prompt_ids'] = pipe.tokenizer.get_prompt_ids(prompt, return_tensors='pt').to(device)
91
- text = pipe(inputs, generate_kwargs=generate_kwargs)['text']
92
- # currently the pipeline for ASR appends the prompt at the beginning of the transcription, so remove it
93
- return html_embed_str, re.sub(rf"\A\s*{prompt}\s*", "", text)
94
 
95
 
96
  demo = gr.Blocks()
@@ -100,7 +136,7 @@ mf_transcribe = gr.Interface(
100
  gr.inputs.Audio(source="microphone", type="filepath", optional=True),
101
  gr.inputs.Textbox(lines=1, placeholder="Prompt", optional=True)
102
  ],
103
- outputs="text",
104
  layout="horizontal",
105
  theme="huggingface",
106
  title=f"Transcribe Audio with {os.path.basename(MODEL_NAME)}",
@@ -114,7 +150,7 @@ file_transcribe = gr.Interface(
114
  gr.inputs.Audio(source="upload", type="filepath", optional=True, label="Audio file"),
115
  gr.inputs.Textbox(lines=1, placeholder="Prompt", optional=True)
116
  ],
117
- outputs="text",
118
  layout="horizontal",
119
  theme="huggingface",
120
  title=f"Transcribe Audio with {os.path.basename(MODEL_NAME)}",
 
1
+ import os
2
+ import time
3
+ import tempfile
4
  import re
5
+ from math import floor
6
+ from typing import Optional
7
 
8
  import torch
9
  import gradio as gr
10
  import yt_dlp as youtube_dl
11
  from transformers import pipeline
12
  from transformers.pipelines.audio_utils import ffmpeg_read
13
+ from punctuators.models import PunctCapSegModelONNX
14
 
 
 
15
 
16
+ # configuration
17
  MODEL_NAME = "kotoba-tech/kotoba-whisper-v1.0"
18
  BATCH_SIZE = 16
19
  CHUNK_LENGTH_S = 15
20
  FILE_LIMIT_MB = 1000
21
  YT_LENGTH_LIMIT_S = 3600 # limit to 1 hour YouTube files
22
+ PUNCTUATOR = PunctCapSegModelONNX.from_pretrained("pcs_47lang")
23
 
24
+
25
+ # device setting
26
  if torch.cuda.is_available():
27
  torch_dtype = torch.bfloat16
28
  device = "cuda:0"
 
32
  device = "cpu"
33
  model_kwargs = {}
34
 
35
+ # define the pipeline
36
  pipe = pipeline(
37
  task="automatic-speech-recognition",
38
  model=MODEL_NAME,
 
44
  )
45
 
46
 
47
+ def format_time(start: Optional[float], end: Optional[float]):
48
+
49
+ def _format_time(seconds: Optional[float]):
50
+ if seconds is None:
51
+ return "complete "
52
+ minutes = floor(seconds / 60)
53
+ hours = floor(seconds / 3600)
54
+ seconds = seconds - hours * 3600 - minutes * 60
55
+ m_seconds = floor(round(seconds - floor(seconds), 3) * 10 ** 3)
56
+ seconds = floor(seconds)
57
+ return f'{hours:02}:{minutes:02}:{seconds:02}.{m_seconds:03}'
58
+
59
+ return f"[{_format_time(start)}-> {_format_time(end)}]:"
60
 
61
+
62
+ def get_prediction(inputs, prompt: Optional[str], punctuate_text: bool = True):
63
+ generate_kwargs = {"language": "japanese", "task": "transcribe"}
64
+ if prompt:
65
+ generate_kwargs['prompt_ids'] = pipe.tokenizer.get_prompt_ids(prompt, return_tensors='pt').to(device)
66
+ prediction = pipe(inputs, return_timestamps=True, generate_kwargs=generate_kwargs)
67
+ if punctuate_text:
68
+ text_edit = PUNCTUATOR.infer([c['text'] for c in prediction['chunks']])
69
+ prediction['chunks'] = [
70
+ {
71
+ 'timestamp': c['timestamp'],
72
+ 'text': "".join(e) if 'unk' not in "".join(e).lower() else c['text']
73
+ } for c, e in zip(prediction['chunks'], text_edit)
74
+ ]
75
+ text = "".join([c['text'] for c in prediction['chunks']])
76
+ text_timestamped = "\n".join([
77
+ f"{format_time(*c['timestamp'])} {c['text']}" for c in prediction['chunks']
78
+ ])
79
+ return text, text_timestamped
80
+
81
+
82
+ def transcribe(inputs, prompt, punctuate_text: bool = True):
83
  if inputs is None:
84
  raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")
85
+ return get_prediction(inputs, prompt, punctuate_text)
86
+
 
 
 
 
87
 
88
  def _return_yt_html_embed(yt_url):
89
  video_id = yt_url.split("?v=")[-1]
90
  return f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{video_id}"> </iframe> </center>'
91
 
92
+
93
  def download_yt_audio(yt_url, filename):
94
  info_loader = youtube_dl.YoutubeDL()
95
  try:
 
116
  raise gr.Error(str(err))
117
 
118
 
119
+ def yt_transcribe(yt_url, prompt, punctuate_text: bool = True):
120
  html_embed_str = _return_yt_html_embed(yt_url)
121
  with tempfile.TemporaryDirectory() as tmpdirname:
122
  filepath = os.path.join(tmpdirname, "video.mp4")
 
125
  inputs = f.read()
126
  inputs = ffmpeg_read(inputs, pipe.feature_extractor.sampling_rate)
127
  inputs = {"array": inputs, "sampling_rate": pipe.feature_extractor.sampling_rate}
128
+ text, text_timestamped = get_prediction(inputs, prompt, punctuate_text)
129
+ return html_embed_str, text, text_timestamped
 
 
 
 
130
 
131
 
132
  demo = gr.Blocks()
 
136
  gr.inputs.Audio(source="microphone", type="filepath", optional=True),
137
  gr.inputs.Textbox(lines=1, placeholder="Prompt", optional=True)
138
  ],
139
+ outputs=["text", "text"],
140
  layout="horizontal",
141
  theme="huggingface",
142
  title=f"Transcribe Audio with {os.path.basename(MODEL_NAME)}",
 
150
  gr.inputs.Audio(source="upload", type="filepath", optional=True, label="Audio file"),
151
  gr.inputs.Textbox(lines=1, placeholder="Prompt", optional=True)
152
  ],
153
+ outputs=["text", "text"],
154
  layout="horizontal",
155
  theme="huggingface",
156
  title=f"Transcribe Audio with {os.path.basename(MODEL_NAME)}",
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
  git+https://github.com/huggingface/transformers
2
  torch
3
  yt-dlp
 
 
1
  git+https://github.com/huggingface/transformers
2
  torch
3
  yt-dlp
4
+ punctuators