aadnk commited on
Commit
f55c594
1 Parent(s): 764bdf1

Adding support for word timestamps

Browse files
Files changed (7) hide show
  1. app.py +28 -12
  2. cli.py +14 -2
  3. config.json5 +10 -1
  4. src/config.py +11 -1
  5. src/utils.py +117 -8
  6. src/vad.py +8 -0
  7. src/whisper/whisperContainer.py +3 -2
app.py CHANGED
@@ -100,13 +100,17 @@ class WhisperTranscriber:
100
  vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
101
  initial_prompt: str, temperature: float, best_of: int, beam_size: int, patience: float, length_penalty: float, suppress_tokens: str,
102
  condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
103
- compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float):
 
 
 
104
 
105
  return self.transcribe_webui_full_progress(modelName, languageName, urlData, multipleFiles, microphoneData, task,
106
  vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
107
  initial_prompt, temperature, best_of, beam_size, patience, length_penalty, suppress_tokens,
108
  condition_on_previous_text, fp16, temperature_increment_on_fallback,
109
- compression_ratio_threshold, logprob_threshold, no_speech_threshold)
 
110
 
111
  # Entry function for the full tab with progress
112
  def transcribe_webui_full_progress(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
@@ -114,6 +118,9 @@ class WhisperTranscriber:
114
  initial_prompt: str, temperature: float, best_of: int, beam_size: int, patience: float, length_penalty: float, suppress_tokens: str,
115
  condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
116
  compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float,
 
 
 
117
  progress=gr.Progress()):
118
 
119
  # Handle temperature_increment_on_fallback
@@ -128,13 +135,15 @@ class WhisperTranscriber:
128
  initial_prompt=initial_prompt, temperature=temperature, best_of=best_of, beam_size=beam_size, patience=patience, length_penalty=length_penalty, suppress_tokens=suppress_tokens,
129
  condition_on_previous_text=condition_on_previous_text, fp16=fp16,
130
  compression_ratio_threshold=compression_ratio_threshold, logprob_threshold=logprob_threshold, no_speech_threshold=no_speech_threshold,
 
131
  progress=progress)
132
 
133
  def transcribe_webui(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
134
- vadOptions: VadOptions, progress: gr.Progress = None, **decodeOptions: dict):
 
135
  try:
136
  sources = self.__get_source(urlData, multipleFiles, microphoneData)
137
-
138
  try:
139
  selectedLanguage = languageName.lower() if len(languageName) > 0 else None
140
  selectedModel = modelName if modelName is not None else "base"
@@ -185,7 +194,7 @@ class WhisperTranscriber:
185
  # Update progress
186
  current_progress += source_audio_duration
187
 
188
- source_download, source_text, source_vtt = self.write_result(result, filePrefix, outputDirectory)
189
 
190
  if len(sources) > 1:
191
  # Add new line separators
@@ -359,7 +368,7 @@ class WhisperTranscriber:
359
 
360
  return config
361
 
362
- def write_result(self, result: dict, source_name: str, output_dir: str):
363
  if not os.path.exists(output_dir):
364
  os.makedirs(output_dir)
365
 
@@ -368,8 +377,8 @@ class WhisperTranscriber:
368
  languageMaxLineWidth = self.__get_max_line_width(language)
369
 
370
  print("Max line width " + str(languageMaxLineWidth))
371
- vtt = self.__get_subs(result["segments"], "vtt", languageMaxLineWidth)
372
- srt = self.__get_subs(result["segments"], "srt", languageMaxLineWidth)
373
 
374
  output_files = []
375
  output_files.append(self.__create_file(srt, output_dir, source_name + "-subs.srt"));
@@ -394,13 +403,13 @@ class WhisperTranscriber:
394
  # 80 latin characters should fit on a 1080p/720p screen
395
  return 80
396
 
397
- def __get_subs(self, segments: Iterator[dict], format: str, maxLineWidth: int) -> str:
398
  segmentStream = StringIO()
399
 
400
  if format == 'vtt':
401
- write_vtt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
402
  elif format == 'srt':
403
- write_srt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
404
  else:
405
  raise Exception("Unknown format " + format)
406
 
@@ -501,7 +510,14 @@ def create_ui(app_config: ApplicationConfig):
501
  gr.Number(label="Temperature increment on fallback", value=app_config.temperature_increment_on_fallback),
502
  gr.Number(label="Compression ratio threshold", value=app_config.compression_ratio_threshold),
503
  gr.Number(label="Logprob threshold", value=app_config.logprob_threshold),
504
- gr.Number(label="No speech threshold", value=app_config.no_speech_threshold)
 
 
 
 
 
 
 
505
  ], outputs=[
506
  gr.File(label="Download"),
507
  gr.Text(label="Transcription"),
100
  vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
101
  initial_prompt: str, temperature: float, best_of: int, beam_size: int, patience: float, length_penalty: float, suppress_tokens: str,
102
  condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
103
+ compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float,
104
+ # Word timestamps
105
+ word_timestamps: bool, prepend_punctuations: str,
106
+ append_punctuations: str, highlight_words: bool = False):
107
 
108
  return self.transcribe_webui_full_progress(modelName, languageName, urlData, multipleFiles, microphoneData, task,
109
  vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
110
  initial_prompt, temperature, best_of, beam_size, patience, length_penalty, suppress_tokens,
111
  condition_on_previous_text, fp16, temperature_increment_on_fallback,
112
+ compression_ratio_threshold, logprob_threshold, no_speech_threshold,
113
+ word_timestamps, prepend_punctuations, append_punctuations, highlight_words)
114
 
115
  # Entry function for the full tab with progress
116
  def transcribe_webui_full_progress(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
118
  initial_prompt: str, temperature: float, best_of: int, beam_size: int, patience: float, length_penalty: float, suppress_tokens: str,
119
  condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
120
  compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float,
121
+ # Word timestamps
122
+ word_timestamps: bool, prepend_punctuations: str,
123
+ append_punctuations: str, highlight_words: bool = False,
124
  progress=gr.Progress()):
125
 
126
  # Handle temperature_increment_on_fallback
135
  initial_prompt=initial_prompt, temperature=temperature, best_of=best_of, beam_size=beam_size, patience=patience, length_penalty=length_penalty, suppress_tokens=suppress_tokens,
136
  condition_on_previous_text=condition_on_previous_text, fp16=fp16,
137
  compression_ratio_threshold=compression_ratio_threshold, logprob_threshold=logprob_threshold, no_speech_threshold=no_speech_threshold,
138
+ word_timestamps=word_timestamps, prepend_punctuations=prepend_punctuations, append_punctuations=append_punctuations, highlight_words=highlight_words,
139
  progress=progress)
140
 
141
  def transcribe_webui(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
142
+ vadOptions: VadOptions, progress: gr.Progress = None, highlight_words: bool = False,
143
+ **decodeOptions: dict):
144
  try:
145
  sources = self.__get_source(urlData, multipleFiles, microphoneData)
146
+
147
  try:
148
  selectedLanguage = languageName.lower() if len(languageName) > 0 else None
149
  selectedModel = modelName if modelName is not None else "base"
194
  # Update progress
195
  current_progress += source_audio_duration
196
 
197
+ source_download, source_text, source_vtt = self.write_result(result, filePrefix, outputDirectory, highlight_words)
198
 
199
  if len(sources) > 1:
200
  # Add new line separators
368
 
369
  return config
370
 
371
+ def write_result(self, result: dict, source_name: str, output_dir: str, highlight_words: bool = False):
372
  if not os.path.exists(output_dir):
373
  os.makedirs(output_dir)
374
 
377
  languageMaxLineWidth = self.__get_max_line_width(language)
378
 
379
  print("Max line width " + str(languageMaxLineWidth))
380
+ vtt = self.__get_subs(result["segments"], "vtt", languageMaxLineWidth, highlight_words=highlight_words)
381
+ srt = self.__get_subs(result["segments"], "srt", languageMaxLineWidth, highlight_words=highlight_words)
382
 
383
  output_files = []
384
  output_files.append(self.__create_file(srt, output_dir, source_name + "-subs.srt"));
403
  # 80 latin characters should fit on a 1080p/720p screen
404
  return 80
405
 
406
+ def __get_subs(self, segments: Iterator[dict], format: str, maxLineWidth: int, highlight_words: bool = False) -> str:
407
  segmentStream = StringIO()
408
 
409
  if format == 'vtt':
410
+ write_vtt(segments, file=segmentStream, maxLineWidth=maxLineWidth, highlight_words=highlight_words)
411
  elif format == 'srt':
412
+ write_srt(segments, file=segmentStream, maxLineWidth=maxLineWidth, highlight_words=highlight_words)
413
  else:
414
  raise Exception("Unknown format " + format)
415
 
510
  gr.Number(label="Temperature increment on fallback", value=app_config.temperature_increment_on_fallback),
511
  gr.Number(label="Compression ratio threshold", value=app_config.compression_ratio_threshold),
512
  gr.Number(label="Logprob threshold", value=app_config.logprob_threshold),
513
+ gr.Number(label="No speech threshold", value=app_config.no_speech_threshold),
514
+
515
+ # Word timestamps
516
+ gr.Checkbox(label="Word Timestamps", value=app_config.word_timestamps),
517
+ gr.Text(label="Word Timestamps - Prepend Punctuations", value=app_config.prepend_punctuations),
518
+ gr.Text(label="Word Timestamps - Append Punctuations", value=app_config.append_punctuations),
519
+ gr.Checkbox(label="Word Timestamps - Highlight Words", value=app_config.highlight_words),
520
+
521
  ], outputs=[
522
  gr.File(label="Download"),
523
  gr.Text(label="Transcription"),
cli.py CHANGED
@@ -95,6 +95,17 @@ def cli():
95
  parser.add_argument("--no_speech_threshold", type=optional_float, default=app_config.no_speech_threshold, \
96
  help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
97
 
 
 
 
 
 
 
 
 
 
 
 
98
  args = parser.parse_args().__dict__
99
  model_name: str = args.pop("model")
100
  model_dir: str = args.pop("model_dir")
@@ -126,6 +137,7 @@ def cli():
126
  auto_parallel = args.pop("auto_parallel")
127
 
128
  compute_type = args.pop("compute_type")
 
129
 
130
  transcriber = WhisperTranscriber(delete_uploaded_files=False, vad_cpu_cores=vad_cpu_cores, app_config=app_config)
131
  transcriber.set_parallel_devices(args.pop("vad_parallel_devices"))
@@ -133,7 +145,7 @@ def cli():
133
 
134
  model = create_whisper_container(whisper_implementation=whisper_implementation, model_name=model_name,
135
  device=device, compute_type=compute_type, download_root=model_dir, models=app_config.models)
136
-
137
  if (transcriber._has_parallel_devices()):
138
  print("Using parallel devices:", transcriber.parallel_device_list)
139
 
@@ -158,7 +170,7 @@ def cli():
158
 
159
  result = transcriber.transcribe_file(model, source_path, temperature=temperature, vadOptions=vadOptions, **args)
160
 
161
- transcriber.write_result(result, source_name, output_dir)
162
 
163
  transcriber.close()
164
 
95
  parser.add_argument("--no_speech_threshold", type=optional_float, default=app_config.no_speech_threshold, \
96
  help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
97
 
98
+ parser.add_argument("--word_timestamps", type=str2bool, default=app_config.word_timestamps,
99
+ help="(experimental) extract word-level timestamps and refine the results based on them")
100
+ parser.add_argument("--prepend_punctuations", type=str, default=app_config.prepend_punctuations,
101
+ help="if word_timestamps is True, merge these punctuation symbols with the next word")
102
+ parser.add_argument("--append_punctuations", type=str, default=app_config.append_punctuations,
103
+ help="if word_timestamps is True, merge these punctuation symbols with the previous word")
104
+ parser.add_argument("--highlight_words", type=str2bool, default=app_config.highlight_words,
105
+ help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
106
+ parser.add_argument("--threads", type=optional_int, default=0,
107
+ help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
108
+
109
  args = parser.parse_args().__dict__
110
  model_name: str = args.pop("model")
111
  model_dir: str = args.pop("model_dir")
137
  auto_parallel = args.pop("auto_parallel")
138
 
139
  compute_type = args.pop("compute_type")
140
+ highlight_words = args.pop("highlight_words")
141
 
142
  transcriber = WhisperTranscriber(delete_uploaded_files=False, vad_cpu_cores=vad_cpu_cores, app_config=app_config)
143
  transcriber.set_parallel_devices(args.pop("vad_parallel_devices"))
145
 
146
  model = create_whisper_container(whisper_implementation=whisper_implementation, model_name=model_name,
147
  device=device, compute_type=compute_type, download_root=model_dir, models=app_config.models)
148
+
149
  if (transcriber._has_parallel_devices()):
150
  print("Using parallel devices:", transcriber.parallel_device_list)
151
 
170
 
171
  result = transcriber.transcribe_file(model, source_path, temperature=temperature, vadOptions=vadOptions, **args)
172
 
173
+ transcriber.write_result(result, source_name, output_dir, highlight_words)
174
 
175
  transcriber.close()
176
 
config.json5 CHANGED
@@ -128,5 +128,14 @@
128
  // If the average log probability is lower than this value, treat the decoding as failed
129
  "logprob_threshold": -1.0,
130
  // If the probability of the <no-speech> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence
131
- "no_speech_threshold": 0.6
 
 
 
 
 
 
 
 
 
132
  }
128
  // If the average log probability is lower than this value, treat the decoding as failed
129
  "logprob_threshold": -1.0,
130
  // If the probability of the <no-speech> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence
131
+ "no_speech_threshold": 0.6,
132
+
133
+ // (experimental) extract word-level timestamps and refine the results based on them
134
+ "word_timestamps": false,
135
+ // if word_timestamps is True, merge these punctuation symbols with the next word
136
+ "prepend_punctuations": "\"\'“¿([{-",
137
+ // if word_timestamps is True, merge these punctuation symbols with the previous word
138
+ "append_punctuations": "\"\'.。,,!!??::”)]}、",
139
+ // (requires --word_timestamps True) underline each word as it is spoken in srt and vtt
140
+ "highlight_words": false,
141
  }
src/config.py CHANGED
@@ -58,7 +58,11 @@ class ApplicationConfig:
58
  condition_on_previous_text: bool = True, fp16: bool = True,
59
  compute_type: str = "float16",
60
  temperature_increment_on_fallback: float = 0.2, compression_ratio_threshold: float = 2.4,
61
- logprob_threshold: float = -1.0, no_speech_threshold: float = 0.6):
 
 
 
 
62
 
63
  self.models = models
64
 
@@ -104,6 +108,12 @@ class ApplicationConfig:
104
  self.logprob_threshold = logprob_threshold
105
  self.no_speech_threshold = no_speech_threshold
106
 
 
 
 
 
 
 
107
  def get_model_names(self):
108
  return [ x.name for x in self.models ]
109
 
58
  condition_on_previous_text: bool = True, fp16: bool = True,
59
  compute_type: str = "float16",
60
  temperature_increment_on_fallback: float = 0.2, compression_ratio_threshold: float = 2.4,
61
+ logprob_threshold: float = -1.0, no_speech_threshold: float = 0.6,
62
+ # Word timestamp settings
63
+ word_timestamps: bool = False, prepend_punctuations: str = "\"\'“¿([{-",
64
+ append_punctuations: str = "\"\'.。,,!!??::”)]}、",
65
+ highlight_words: bool = False):
66
 
67
  self.models = models
68
 
108
  self.logprob_threshold = logprob_threshold
109
  self.no_speech_threshold = no_speech_threshold
110
 
111
+ # Word timestamp settings
112
+ self.word_timestamps = word_timestamps
113
+ self.prepend_punctuations = prepend_punctuations
114
+ self.append_punctuations = append_punctuations
115
+ self.highlight_words = highlight_words
116
+
117
  def get_model_names(self):
118
  return [ x.name for x in self.models ]
119
 
src/utils.py CHANGED
@@ -3,7 +3,7 @@ import unicodedata
3
  import re
4
 
5
  import zlib
6
- from typing import Iterator, TextIO
7
  import tqdm
8
 
9
  import urllib3
@@ -56,10 +56,14 @@ def write_txt(transcript: Iterator[dict], file: TextIO):
56
  print(segment['text'].strip(), file=file, flush=True)
57
 
58
 
59
- def write_vtt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
 
 
 
60
  print("WEBVTT\n", file=file)
61
- for segment in transcript:
62
- text = process_text(segment['text'], maxLineWidth).replace('-->', '->')
 
63
 
64
  print(
65
  f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
@@ -68,8 +72,8 @@ def write_vtt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
68
  flush=True,
69
  )
70
 
71
-
72
- def write_srt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
73
  """
74
  Write a transcript to a file in SRT format.
75
  Example usage:
@@ -81,8 +85,10 @@ def write_srt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
81
  with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
82
  write_srt(result["segments"], file=srt)
83
  """
84
- for i, segment in enumerate(transcript, start=1):
85
- text = process_text(segment['text'].strip(), maxLineWidth).replace('-->', '->')
 
 
86
 
87
  # write srt lines
88
  print(
@@ -94,6 +100,109 @@ def write_srt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
94
  flush=True,
95
  )
96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  def process_text(text: str, maxLineWidth=None):
98
  if (maxLineWidth is None or maxLineWidth < 0):
99
  return text
3
  import re
4
 
5
  import zlib
6
+ from typing import Iterator, TextIO, Union
7
  import tqdm
8
 
9
  import urllib3
56
  print(segment['text'].strip(), file=file, flush=True)
57
 
58
 
59
+ def write_vtt(transcript: Iterator[dict], file: TextIO,
60
+ maxLineWidth=None, highlight_words: bool = False):
61
+ iterator = __subtitle_preprocessor_iterator(transcript, maxLineWidth, highlight_words)
62
+
63
  print("WEBVTT\n", file=file)
64
+
65
+ for segment in iterator:
66
+ text = segment['text'].replace('-->', '->')
67
 
68
  print(
69
  f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
72
  flush=True,
73
  )
74
 
75
+ def write_srt(transcript: Iterator[dict], file: TextIO,
76
+ maxLineWidth=None, highlight_words: bool = False):
77
  """
78
  Write a transcript to a file in SRT format.
79
  Example usage:
85
  with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
86
  write_srt(result["segments"], file=srt)
87
  """
88
+ iterator = __subtitle_preprocessor_iterator(transcript, maxLineWidth, highlight_words)
89
+
90
+ for i, segment in enumerate(iterator, start=1):
91
+ text = segment['text'].replace('-->', '->')
92
 
93
  # write srt lines
94
  print(
100
  flush=True,
101
  )
102
 
103
+ def __subtitle_preprocessor_iterator(transcript: Iterator[dict], maxLineWidth: int = None, highlight_words: bool = False):
104
+ for segment in transcript:
105
+ words = segment.get('words', [])
106
+
107
+ if len(words) == 0:
108
+ # Yield the segment as-is
109
+ if maxLineWidth is None or maxLineWidth < 0:
110
+ yield segment
111
+
112
+ # Yield the segment with processed text
113
+ yield {
114
+ 'start': segment['start'],
115
+ 'end': segment['end'],
116
+ 'text': process_text(segment['text'].strip(), maxLineWidth)
117
+ }
118
+
119
+ subtitle_start = segment['start']
120
+ subtitle_end = segment['end']
121
+
122
+ text_words = [ this_word["word"] for this_word in words ]
123
+ subtitle_text = __join_words(text_words, maxLineWidth)
124
+
125
+ # Iterate over the words in the segment
126
+ if highlight_words:
127
+ last = subtitle_start
128
+
129
+ for i, this_word in enumerate(words):
130
+ start = this_word['start']
131
+ end = this_word['end']
132
+
133
+ if last != start:
134
+ # Display the text up to this point
135
+ yield {
136
+ 'start': last,
137
+ 'end': start,
138
+ 'text': subtitle_text
139
+ }
140
+
141
+ # Display the text with the current word highlighted
142
+ yield {
143
+ 'start': start,
144
+ 'end': end,
145
+ 'text': __join_words(
146
+ [
147
+ {
148
+ "word": re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
149
+ if j == i
150
+ else word,
151
+ # The HTML tags <u> and </u> are not displayed,
152
+ # # so they should not be counted in the word length
153
+ "length": len(word)
154
+ } for j, word in enumerate(text_words)
155
+ ], maxLineWidth)
156
+ }
157
+ last = end
158
+
159
+ if last != subtitle_end:
160
+ # Display the last part of the text
161
+ yield {
162
+ 'start': last,
163
+ 'end': subtitle_end,
164
+ 'text': subtitle_text
165
+ }
166
+
167
+ # Just return the subtitle text
168
+ else:
169
+ yield {
170
+ 'start': subtitle_start,
171
+ 'end': subtitle_end,
172
+ 'text': subtitle_text
173
+ }
174
+
175
+ def __join_words(words: Iterator[Union[str, dict]], maxLineWidth: int = None):
176
+ if maxLineWidth is None or maxLineWidth < 0:
177
+ return " ".join(words)
178
+
179
+ lines = []
180
+ current_line = ""
181
+ current_length = 0
182
+
183
+ for entry in words:
184
+ # Either accept a string or a dict with a 'word' and 'length' field
185
+ if isinstance(entry, dict):
186
+ word = entry['word']
187
+ word_length = entry['length']
188
+ else:
189
+ word = entry
190
+ word_length = len(word)
191
+
192
+ if current_length > 0 and current_length + word_length > maxLineWidth:
193
+ lines.append(current_line)
194
+ current_line = ""
195
+ current_length = 0
196
+
197
+ current_length += word_length
198
+ # The word will be prefixed with a space by Whisper, so we don't need to add one here
199
+ current_line += word
200
+
201
+ if len(current_line) > 0:
202
+ lines.append(current_line)
203
+
204
+ return "\n".join(lines)
205
+
206
  def process_text(text: str, maxLineWidth=None):
207
  if (maxLineWidth is None or maxLineWidth < 0):
208
  return text
src/vad.py CHANGED
@@ -404,6 +404,14 @@ class AbstractTranscription(ABC):
404
  # Add to start and end
405
  new_segment['start'] = segment_start + adjust_seconds
406
  new_segment['end'] = segment_end + adjust_seconds
 
 
 
 
 
 
 
 
407
  result.append(new_segment)
408
  return result
409
 
404
  # Add to start and end
405
  new_segment['start'] = segment_start + adjust_seconds
406
  new_segment['end'] = segment_end + adjust_seconds
407
+
408
+ # Handle words
409
+ if ('words' in new_segment):
410
+ for word in new_segment['words']:
411
+ # Adjust start and end
412
+ word['start'] = word['start'] + adjust_seconds
413
+ word['end'] = word['end'] + adjust_seconds
414
+
415
  result.append(new_segment)
416
  return result
417
 
src/whisper/whisperContainer.py CHANGED
@@ -203,8 +203,9 @@ class WhisperCallback(AbstractWhisperCallback):
203
 
204
  initial_prompt = self._get_initial_prompt(self.initial_prompt, self.initial_prompt_mode, prompt, segment_index)
205
 
206
- return model.transcribe(audio, \
207
  language=self.language if self.language else detected_language, task=self.task, \
208
  initial_prompt=initial_prompt, \
209
  **decodeOptions
210
- )
 
203
 
204
  initial_prompt = self._get_initial_prompt(self.initial_prompt, self.initial_prompt_mode, prompt, segment_index)
205
 
206
+ result = model.transcribe(audio, \
207
  language=self.language if self.language else detected_language, task=self.task, \
208
  initial_prompt=initial_prompt, \
209
  **decodeOptions
210
+ )
211
+ return result