aadnk commited on
Commit
cddfab9
1 Parent(s): 3cd7d59

Add an extra interface for performing diarization

Browse files
Files changed (2) hide show
  1. app.py +80 -13
  2. cli.py +4 -4
app.py CHANGED
@@ -1,7 +1,7 @@
1
  from datetime import datetime
2
  import json
3
  import math
4
- from typing import Iterator, Union
5
  import argparse
6
 
7
  from io import StringIO
@@ -16,14 +16,14 @@ import torch
16
  from src.config import VAD_INITIAL_PROMPT_MODE_VALUES, ApplicationConfig, VadInitialPromptMode
17
  from src.diarization.diarization import Diarization
18
  from src.diarization.diarizationContainer import DiarizationContainer
 
19
  from src.hooks.progressListener import ProgressListener
20
  from src.hooks.subTaskProgressListener import SubTaskProgressListener
21
- from src.hooks.whisperProgressHook import create_progress_listener_handle
22
  from src.languages import get_language_names
23
  from src.modelCache import ModelCache
24
  from src.prompts.jsonPromptStrategy import JsonPromptStrategy
25
  from src.prompts.prependPromptStrategy import PrependPromptStrategy
26
- from src.source import get_audio_source_collection
27
  from src.vadParallel import ParallelContext, ParallelTranscription
28
 
29
  # External programs
@@ -101,7 +101,8 @@ class WhisperTranscriber:
101
  self.diarization_kwargs = kwargs
102
 
103
  def unset_diarization(self):
104
- self.diarization.cleanup()
 
105
  self.diarization_kwargs = None
106
 
107
  # Entry function for the simple tab
@@ -185,19 +186,59 @@ class WhisperTranscriber:
185
  word_timestamps=word_timestamps, prepend_punctuations=prepend_punctuations, append_punctuations=append_punctuations, highlight_words=highlight_words,
186
  progress=progress)
187
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
  def transcribe_webui(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
189
  vadOptions: VadOptions, progress: gr.Progress = None, highlight_words: bool = False,
 
190
  **decodeOptions: dict):
191
  try:
192
  sources = self.__get_source(urlData, multipleFiles, microphoneData)
193
 
 
 
 
194
  try:
195
  selectedLanguage = languageName.lower() if len(languageName) > 0 else None
196
  selectedModel = modelName if modelName is not None else "base"
197
 
198
- model = create_whisper_container(whisper_implementation=self.app_config.whisper_implementation,
199
- model_name=selectedModel, compute_type=self.app_config.compute_type,
200
- cache=self.model_cache, models=self.app_config.models)
 
 
 
201
 
202
  # Result
203
  download = []
@@ -234,8 +275,12 @@ class WhisperTranscriber:
234
  sub_task_start=current_progress,
235
  sub_task_total=source_audio_duration)
236
 
237
- # Transcribe
238
- result = self.transcribe_file(model, source.source_path, selectedLanguage, task, vadOptions, scaled_progress_listener, **decodeOptions)
 
 
 
 
239
  filePrefix = slugify(source_prefix + source.get_short_name(), allow_unicode=True)
240
 
241
  # Update progress
@@ -363,6 +408,10 @@ class WhisperTranscriber:
363
  result = whisperCallable.invoke(audio_path, 0, None, None, progress_listener=progressListener)
364
 
365
  # Diarization
 
 
 
 
366
  if self.diarization and self.diarization_kwargs:
367
  print("Diarizing ", audio_path)
368
  diarization_result = list(self.diarization.run(audio_path, **self.diarization_kwargs))
@@ -373,9 +422,9 @@ class WhisperTranscriber:
373
  print(f" start={entry.start:.1f}s stop={entry.end:.1f}s speaker_{entry.speaker}")
374
 
375
  # Add speakers to result
376
- result = self.diarization.mark_speakers(diarization_result, result)
377
 
378
- return result
379
 
380
  def _create_progress_listener(self, progress: gr.Progress):
381
  if (progress is None):
@@ -449,7 +498,7 @@ class WhisperTranscriber:
449
  os.makedirs(output_dir)
450
 
451
  text = result["text"]
452
- language = result["language"]
453
  languageMaxLineWidth = self.__get_max_line_width(language)
454
 
455
  print("Max line width " + str(languageMaxLineWidth))
@@ -635,7 +684,25 @@ def create_ui(app_config: ApplicationConfig):
635
  gr.Text(label="Segments")
636
  ])
637
 
638
- demo = gr.TabbedInterface([simple_transcribe, full_transcribe], tab_names=["Simple", "Full"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
639
 
640
  # Queue up the demo
641
  if is_queue_mode:
 
1
  from datetime import datetime
2
  import json
3
  import math
4
+ from typing import Callable, Iterator, Union
5
  import argparse
6
 
7
  from io import StringIO
 
16
  from src.config import VAD_INITIAL_PROMPT_MODE_VALUES, ApplicationConfig, VadInitialPromptMode
17
  from src.diarization.diarization import Diarization
18
  from src.diarization.diarizationContainer import DiarizationContainer
19
+ from src.diarization.transcriptLoader import load_transcript
20
  from src.hooks.progressListener import ProgressListener
21
  from src.hooks.subTaskProgressListener import SubTaskProgressListener
 
22
  from src.languages import get_language_names
23
  from src.modelCache import ModelCache
24
  from src.prompts.jsonPromptStrategy import JsonPromptStrategy
25
  from src.prompts.prependPromptStrategy import PrependPromptStrategy
26
+ from src.source import AudioSource, get_audio_source_collection
27
  from src.vadParallel import ParallelContext, ParallelTranscription
28
 
29
  # External programs
 
101
  self.diarization_kwargs = kwargs
102
 
103
  def unset_diarization(self):
104
+ if self.diarization is not None:
105
+ self.diarization.cleanup()
106
  self.diarization_kwargs = None
107
 
108
  # Entry function for the simple tab
 
186
  word_timestamps=word_timestamps, prepend_punctuations=prepend_punctuations, append_punctuations=append_punctuations, highlight_words=highlight_words,
187
  progress=progress)
188
 
189
+ # Perform diarization given a specific input audio file and whisper file
190
+ def perform_extra(self, languageName, urlData, singleFile, whisper_file: str,
191
+ highlight_words: bool = False,
192
+ diarization: bool = False, diarization_speakers: int = 2, diarization_min_speakers = 1, diarization_max_speakers = 5, progress=gr.Progress()):
193
+
194
+ if whisper_file is None:
195
+ raise ValueError("whisper_file is required")
196
+
197
+ # Set diarization
198
+ if diarization:
199
+ self.set_diarization(auth_token=self.app_config.auth_token, num_speakers=diarization_speakers,
200
+ min_speakers=diarization_min_speakers, max_speakers=diarization_max_speakers)
201
+ else:
202
+ self.unset_diarization()
203
+
204
+ def custom_transcribe_file(source: AudioSource):
205
+ result = load_transcript(whisper_file.name)
206
+
207
+ # Set language if not set
208
+ if not "language" in result:
209
+ result["language"] = languageName
210
+
211
+ # Mark speakers
212
+ result = self._handle_diarization(source.source_path, result)
213
+ return result
214
+
215
+ multipleFiles = [singleFile] if singleFile else None
216
+
217
+ # Will return download, text, vtt
218
+ return self.transcribe_webui("base", "", urlData, multipleFiles, None, None, None,
219
+ progress=progress,highlight_words=highlight_words,
220
+ override_transcribe_file=custom_transcribe_file, override_max_sources=1)
221
+
222
  def transcribe_webui(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
223
  vadOptions: VadOptions, progress: gr.Progress = None, highlight_words: bool = False,
224
+ override_transcribe_file: Callable[[AudioSource], dict] = None, override_max_sources = None,
225
  **decodeOptions: dict):
226
  try:
227
  sources = self.__get_source(urlData, multipleFiles, microphoneData)
228
 
229
+ if override_max_sources is not None and len(sources) > override_max_sources:
230
+ raise ValueError("Maximum number of sources is " + str(override_max_sources) + ", but " + str(len(sources)) + " were provided")
231
+
232
  try:
233
  selectedLanguage = languageName.lower() if len(languageName) > 0 else None
234
  selectedModel = modelName if modelName is not None else "base"
235
 
236
+ if override_transcribe_file is None:
237
+ model = create_whisper_container(whisper_implementation=self.app_config.whisper_implementation,
238
+ model_name=selectedModel, compute_type=self.app_config.compute_type,
239
+ cache=self.model_cache, models=self.app_config.models)
240
+ else:
241
+ model = None
242
 
243
  # Result
244
  download = []
 
275
  sub_task_start=current_progress,
276
  sub_task_total=source_audio_duration)
277
 
278
+ # Transcribe using the override function if specified
279
+ if override_transcribe_file is None:
280
+ result = self.transcribe_file(model, source.source_path, selectedLanguage, task, vadOptions, scaled_progress_listener, **decodeOptions)
281
+ else:
282
+ result = override_transcribe_file(source)
283
+
284
  filePrefix = slugify(source_prefix + source.get_short_name(), allow_unicode=True)
285
 
286
  # Update progress
 
408
  result = whisperCallable.invoke(audio_path, 0, None, None, progress_listener=progressListener)
409
 
410
  # Diarization
411
+ result = self._handle_diarization(audio_path, result)
412
+ return result
413
+
414
+ def _handle_diarization(self, audio_path: str, input: dict):
415
  if self.diarization and self.diarization_kwargs:
416
  print("Diarizing ", audio_path)
417
  diarization_result = list(self.diarization.run(audio_path, **self.diarization_kwargs))
 
422
  print(f" start={entry.start:.1f}s stop={entry.end:.1f}s speaker_{entry.speaker}")
423
 
424
  # Add speakers to result
425
+ input = self.diarization.mark_speakers(diarization_result, input)
426
 
427
+ return input
428
 
429
  def _create_progress_listener(self, progress: gr.Progress):
430
  if (progress is None):
 
498
  os.makedirs(output_dir)
499
 
500
  text = result["text"]
501
+ language = result["language"] if "language" in result else None
502
  languageMaxLineWidth = self.__get_max_line_width(language)
503
 
504
  print("Max line width " + str(languageMaxLineWidth))
 
684
  gr.Text(label="Segments")
685
  ])
686
 
687
+ perform_extra_interface = gr.Interface(fn=ui.perform_extra,
688
+ description="Perform additional processing on a given JSON or SRT file", article=ui_article, inputs=[
689
+ gr.Dropdown(choices=sorted(get_language_names()), label="Language", value=app_config.language),
690
+ gr.Text(label="URL (YouTube, etc.)"),
691
+ gr.File(label="Upload Audio File", file_count="single"),
692
+ gr.File(label="Upload JSON/SRT File", file_count="single"),
693
+ gr.Checkbox(label="Word Timestamps - Highlight Words", value=app_config.highlight_words),
694
+
695
+ *common_diarization_inputs(),
696
+ gr.Number(label="Diarization - Min Speakers", precision=0, value=app_config.diarization_min_speakers, interactive=has_diarization_libs),
697
+ gr.Number(label="Diarization - Max Speakers", precision=0, value=app_config.diarization_max_speakers, interactive=has_diarization_libs),
698
+
699
+ ], outputs=[
700
+ gr.File(label="Download"),
701
+ gr.Text(label="Transcription"),
702
+ gr.Text(label="Segments")
703
+ ])
704
+
705
+ demo = gr.TabbedInterface([simple_transcribe, full_transcribe, perform_extra_interface], tab_names=["Simple", "Full", "Extra"])
706
 
707
  # Queue up the demo
708
  if is_queue_mode:
cli.py CHANGED
@@ -108,12 +108,12 @@ def cli():
108
  help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
109
 
110
  # Diarization
111
- parser.add_argument('--auth_token', type=str, default=None, help='HuggingFace API Token (optional)')
112
  parser.add_argument("--diarization", type=str2bool, default=app_config.diarization, \
113
  help="whether to perform speaker diarization")
114
- parser.add_argument("--diarization_num_speakers", type=int, default=None, help="Number of speakers")
115
- parser.add_argument("--diarization_min_speakers", type=int, default=None, help="Minimum number of speakers")
116
- parser.add_argument("--diarization_max_speakers", type=int, default=None, help="Maximum number of speakers")
117
 
118
  args = parser.parse_args().__dict__
119
  model_name: str = args.pop("model")
 
108
  help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
109
 
110
  # Diarization
111
+ parser.add_argument('--auth_token', type=str, default=app_config.auth_token, help='HuggingFace API Token (optional)')
112
  parser.add_argument("--diarization", type=str2bool, default=app_config.diarization, \
113
  help="whether to perform speaker diarization")
114
+ parser.add_argument("--diarization_num_speakers", type=int, default=app_config.diarization_speakers, help="Number of speakers")
115
+ parser.add_argument("--diarization_min_speakers", type=int, default=app_config.diarization_min_speakers, help="Minimum number of speakers")
116
+ parser.add_argument("--diarization_max_speakers", type=int, default=app_config.diarization_max_speakers, help="Maximum number of speakers")
117
 
118
  args = parser.parse_args().__dict__
119
  model_name: str = args.pop("model")