Add an extra interface for performing diarization
Browse files
app.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
from datetime import datetime
|
2 |
import json
|
3 |
import math
|
4 |
-
from typing import Iterator, Union
|
5 |
import argparse
|
6 |
|
7 |
from io import StringIO
|
@@ -16,14 +16,14 @@ import torch
|
|
16 |
from src.config import VAD_INITIAL_PROMPT_MODE_VALUES, ApplicationConfig, VadInitialPromptMode
|
17 |
from src.diarization.diarization import Diarization
|
18 |
from src.diarization.diarizationContainer import DiarizationContainer
|
|
|
19 |
from src.hooks.progressListener import ProgressListener
|
20 |
from src.hooks.subTaskProgressListener import SubTaskProgressListener
|
21 |
-
from src.hooks.whisperProgressHook import create_progress_listener_handle
|
22 |
from src.languages import get_language_names
|
23 |
from src.modelCache import ModelCache
|
24 |
from src.prompts.jsonPromptStrategy import JsonPromptStrategy
|
25 |
from src.prompts.prependPromptStrategy import PrependPromptStrategy
|
26 |
-
from src.source import get_audio_source_collection
|
27 |
from src.vadParallel import ParallelContext, ParallelTranscription
|
28 |
|
29 |
# External programs
|
@@ -101,7 +101,8 @@ class WhisperTranscriber:
|
|
101 |
self.diarization_kwargs = kwargs
|
102 |
|
103 |
def unset_diarization(self):
|
104 |
-
self.diarization
|
|
|
105 |
self.diarization_kwargs = None
|
106 |
|
107 |
# Entry function for the simple tab
|
@@ -185,19 +186,59 @@ class WhisperTranscriber:
|
|
185 |
word_timestamps=word_timestamps, prepend_punctuations=prepend_punctuations, append_punctuations=append_punctuations, highlight_words=highlight_words,
|
186 |
progress=progress)
|
187 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
188 |
def transcribe_webui(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
|
189 |
vadOptions: VadOptions, progress: gr.Progress = None, highlight_words: bool = False,
|
|
|
190 |
**decodeOptions: dict):
|
191 |
try:
|
192 |
sources = self.__get_source(urlData, multipleFiles, microphoneData)
|
193 |
|
|
|
|
|
|
|
194 |
try:
|
195 |
selectedLanguage = languageName.lower() if len(languageName) > 0 else None
|
196 |
selectedModel = modelName if modelName is not None else "base"
|
197 |
|
198 |
-
|
199 |
-
|
200 |
-
|
|
|
|
|
|
|
201 |
|
202 |
# Result
|
203 |
download = []
|
@@ -234,8 +275,12 @@ class WhisperTranscriber:
|
|
234 |
sub_task_start=current_progress,
|
235 |
sub_task_total=source_audio_duration)
|
236 |
|
237 |
-
# Transcribe
|
238 |
-
|
|
|
|
|
|
|
|
|
239 |
filePrefix = slugify(source_prefix + source.get_short_name(), allow_unicode=True)
|
240 |
|
241 |
# Update progress
|
@@ -363,6 +408,10 @@ class WhisperTranscriber:
|
|
363 |
result = whisperCallable.invoke(audio_path, 0, None, None, progress_listener=progressListener)
|
364 |
|
365 |
# Diarization
|
|
|
|
|
|
|
|
|
366 |
if self.diarization and self.diarization_kwargs:
|
367 |
print("Diarizing ", audio_path)
|
368 |
diarization_result = list(self.diarization.run(audio_path, **self.diarization_kwargs))
|
@@ -373,9 +422,9 @@ class WhisperTranscriber:
|
|
373 |
print(f" start={entry.start:.1f}s stop={entry.end:.1f}s speaker_{entry.speaker}")
|
374 |
|
375 |
# Add speakers to result
|
376 |
-
|
377 |
|
378 |
-
return
|
379 |
|
380 |
def _create_progress_listener(self, progress: gr.Progress):
|
381 |
if (progress is None):
|
@@ -449,7 +498,7 @@ class WhisperTranscriber:
|
|
449 |
os.makedirs(output_dir)
|
450 |
|
451 |
text = result["text"]
|
452 |
-
language = result["language"]
|
453 |
languageMaxLineWidth = self.__get_max_line_width(language)
|
454 |
|
455 |
print("Max line width " + str(languageMaxLineWidth))
|
@@ -635,7 +684,25 @@ def create_ui(app_config: ApplicationConfig):
|
|
635 |
gr.Text(label="Segments")
|
636 |
])
|
637 |
|
638 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
639 |
|
640 |
# Queue up the demo
|
641 |
if is_queue_mode:
|
|
|
1 |
from datetime import datetime
|
2 |
import json
|
3 |
import math
|
4 |
+
from typing import Callable, Iterator, Union
|
5 |
import argparse
|
6 |
|
7 |
from io import StringIO
|
|
|
16 |
from src.config import VAD_INITIAL_PROMPT_MODE_VALUES, ApplicationConfig, VadInitialPromptMode
|
17 |
from src.diarization.diarization import Diarization
|
18 |
from src.diarization.diarizationContainer import DiarizationContainer
|
19 |
+
from src.diarization.transcriptLoader import load_transcript
|
20 |
from src.hooks.progressListener import ProgressListener
|
21 |
from src.hooks.subTaskProgressListener import SubTaskProgressListener
|
|
|
22 |
from src.languages import get_language_names
|
23 |
from src.modelCache import ModelCache
|
24 |
from src.prompts.jsonPromptStrategy import JsonPromptStrategy
|
25 |
from src.prompts.prependPromptStrategy import PrependPromptStrategy
|
26 |
+
from src.source import AudioSource, get_audio_source_collection
|
27 |
from src.vadParallel import ParallelContext, ParallelTranscription
|
28 |
|
29 |
# External programs
|
|
|
101 |
self.diarization_kwargs = kwargs
|
102 |
|
103 |
def unset_diarization(self):
    """Turn diarization off for subsequent runs.

    If a diarization object is currently held, its ``cleanup()`` method is
    called. The stored diarization keyword arguments are then cleared.
    The ``self.diarization`` reference itself is left in place.
    """
    current = self.diarization
    if current is not None:
        current.cleanup()

    self.diarization_kwargs = None
|
107 |
|
108 |
# Entry function for the simple tab
|
|
|
186 |
word_timestamps=word_timestamps, prepend_punctuations=prepend_punctuations, append_punctuations=append_punctuations, highlight_words=highlight_words,
|
187 |
progress=progress)
|
188 |
|
189 |
+
# Perform diarization given a specific input audio file and whisper file
|
190 |
+
def perform_extra(self, languageName, urlData, singleFile, whisper_file: str,
                  highlight_words: bool = False,
                  diarization: bool = False, diarization_speakers: int = 2, diarization_min_speakers = 1, diarization_max_speakers = 5, progress=gr.Progress()):
    """Run diarization on an audio source using an existing transcript file.

    No Whisper transcription is performed here: the transcript is loaded
    from ``whisper_file`` and passed through ``_handle_diarization`` for
    the audio source, then handed to ``transcribe_webui`` through its
    ``override_transcribe_file`` hook.

    Parameters:
        languageName: Stored as the transcript's "language" entry when the
            loaded transcript has none.
        urlData: Forwarded to ``transcribe_webui`` as the URL source.
        singleFile: Optional uploaded audio file; wrapped in a
            one-element list when truthy, otherwise ``None`` is forwarded.
        whisper_file: The uploaded transcript. Although annotated ``str``,
            the body reads ``whisper_file.name``, so an object exposing a
            ``name`` attribute is expected.
        highlight_words: Forwarded to ``transcribe_webui``.
        diarization: When true, diarization is configured with the speaker
            settings below; otherwise it is unset.
        diarization_speakers, diarization_min_speakers,
        diarization_max_speakers: Passed to ``set_diarization``.
        progress: Progress object forwarded to ``transcribe_webui``.

    Returns:
        Whatever ``transcribe_webui`` returns (download, text, vtt).

    Raises:
        ValueError: If ``whisper_file`` is ``None``.
    """
    if whisper_file is None:
        raise ValueError("whisper_file is required")

    # Configure (or disable) diarization before any source is processed
    if not diarization:
        self.unset_diarization()
    else:
        self.set_diarization(auth_token=self.app_config.auth_token, num_speakers=diarization_speakers,
                             min_speakers=diarization_min_speakers, max_speakers=diarization_max_speakers)

    def transcript_from_file(source: AudioSource):
        # Replaces the Whisper step: read the supplied transcript instead
        transcript = load_transcript(whisper_file.name)

        # Fall back to the language chosen in the UI if the file has none
        if "language" not in transcript:
            transcript["language"] = languageName

        # Mark speakers
        return self._handle_diarization(source.source_path, transcript)

    uploaded_files = None if not singleFile else [singleFile]

    # Will return download, text, vtt
    return self.transcribe_webui("base", "", urlData, uploaded_files, None, None, None,
                                 progress=progress, highlight_words=highlight_words,
                                 override_transcribe_file=transcript_from_file, override_max_sources=1)
|
221 |
+
|
222 |
def transcribe_webui(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
|
223 |
vadOptions: VadOptions, progress: gr.Progress = None, highlight_words: bool = False,
|
224 |
+
override_transcribe_file: Callable[[AudioSource], dict] = None, override_max_sources = None,
|
225 |
**decodeOptions: dict):
|
226 |
try:
|
227 |
sources = self.__get_source(urlData, multipleFiles, microphoneData)
|
228 |
|
229 |
+
if override_max_sources is not None and len(sources) > override_max_sources:
|
230 |
+
raise ValueError("Maximum number of sources is " + str(override_max_sources) + ", but " + str(len(sources)) + " were provided")
|
231 |
+
|
232 |
try:
|
233 |
selectedLanguage = languageName.lower() if len(languageName) > 0 else None
|
234 |
selectedModel = modelName if modelName is not None else "base"
|
235 |
|
236 |
+
if override_transcribe_file is None:
|
237 |
+
model = create_whisper_container(whisper_implementation=self.app_config.whisper_implementation,
|
238 |
+
model_name=selectedModel, compute_type=self.app_config.compute_type,
|
239 |
+
cache=self.model_cache, models=self.app_config.models)
|
240 |
+
else:
|
241 |
+
model = None
|
242 |
|
243 |
# Result
|
244 |
download = []
|
|
|
275 |
sub_task_start=current_progress,
|
276 |
sub_task_total=source_audio_duration)
|
277 |
|
278 |
+
# Transcribe using the override function if specified
|
279 |
+
if override_transcribe_file is None:
|
280 |
+
result = self.transcribe_file(model, source.source_path, selectedLanguage, task, vadOptions, scaled_progress_listener, **decodeOptions)
|
281 |
+
else:
|
282 |
+
result = override_transcribe_file(source)
|
283 |
+
|
284 |
filePrefix = slugify(source_prefix + source.get_short_name(), allow_unicode=True)
|
285 |
|
286 |
# Update progress
|
|
|
408 |
result = whisperCallable.invoke(audio_path, 0, None, None, progress_listener=progressListener)
|
409 |
|
410 |
# Diarization
|
411 |
+
result = self._handle_diarization(audio_path, result)
|
412 |
+
return result
|
413 |
+
|
414 |
+
def _handle_diarization(self, audio_path: str, input: dict):
|
415 |
if self.diarization and self.diarization_kwargs:
|
416 |
print("Diarizing ", audio_path)
|
417 |
diarization_result = list(self.diarization.run(audio_path, **self.diarization_kwargs))
|
|
|
422 |
print(f" start={entry.start:.1f}s stop={entry.end:.1f}s speaker_{entry.speaker}")
|
423 |
|
424 |
# Add speakers to result
|
425 |
+
input = self.diarization.mark_speakers(diarization_result, input)
|
426 |
|
427 |
+
return input
|
428 |
|
429 |
def _create_progress_listener(self, progress: gr.Progress):
|
430 |
if (progress is None):
|
|
|
498 |
os.makedirs(output_dir)
|
499 |
|
500 |
text = result["text"]
|
501 |
+
language = result["language"] if "language" in result else None
|
502 |
languageMaxLineWidth = self.__get_max_line_width(language)
|
503 |
|
504 |
print("Max line width " + str(languageMaxLineWidth))
|
|
|
684 |
gr.Text(label="Segments")
|
685 |
])
|
686 |
|
687 |
+
perform_extra_interface = gr.Interface(fn=ui.perform_extra,
|
688 |
+
description="Perform additional processing on a given JSON or SRT file", article=ui_article, inputs=[
|
689 |
+
gr.Dropdown(choices=sorted(get_language_names()), label="Language", value=app_config.language),
|
690 |
+
gr.Text(label="URL (YouTube, etc.)"),
|
691 |
+
gr.File(label="Upload Audio File", file_count="single"),
|
692 |
+
gr.File(label="Upload JSON/SRT File", file_count="single"),
|
693 |
+
gr.Checkbox(label="Word Timestamps - Highlight Words", value=app_config.highlight_words),
|
694 |
+
|
695 |
+
*common_diarization_inputs(),
|
696 |
+
gr.Number(label="Diarization - Min Speakers", precision=0, value=app_config.diarization_min_speakers, interactive=has_diarization_libs),
|
697 |
+
gr.Number(label="Diarization - Max Speakers", precision=0, value=app_config.diarization_max_speakers, interactive=has_diarization_libs),
|
698 |
+
|
699 |
+
], outputs=[
|
700 |
+
gr.File(label="Download"),
|
701 |
+
gr.Text(label="Transcription"),
|
702 |
+
gr.Text(label="Segments")
|
703 |
+
])
|
704 |
+
|
705 |
+
demo = gr.TabbedInterface([simple_transcribe, full_transcribe, perform_extra_interface], tab_names=["Simple", "Full", "Extra"])
|
706 |
|
707 |
# Queue up the demo
|
708 |
if is_queue_mode:
|
cli.py
CHANGED
@@ -108,12 +108,12 @@ def cli():
|
|
108 |
help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
109 |
|
110 |
# Diarization
|
111 |
-
parser.add_argument('--auth_token', type=str, default=
|
112 |
parser.add_argument("--diarization", type=str2bool, default=app_config.diarization, \
|
113 |
help="whether to perform speaker diarization")
|
114 |
-
parser.add_argument("--diarization_num_speakers", type=int, default=
|
115 |
-
parser.add_argument("--diarization_min_speakers", type=int, default=
|
116 |
-
parser.add_argument("--diarization_max_speakers", type=int, default=
|
117 |
|
118 |
args = parser.parse_args().__dict__
|
119 |
model_name: str = args.pop("model")
|
|
|
108 |
help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
109 |
|
110 |
# Diarization
|
111 |
+
parser.add_argument('--auth_token', type=str, default=app_config.auth_token, help='HuggingFace API Token (optional)')
|
112 |
parser.add_argument("--diarization", type=str2bool, default=app_config.diarization, \
|
113 |
help="whether to perform speaker diarization")
|
114 |
+
parser.add_argument("--diarization_num_speakers", type=int, default=app_config.diarization_speakers, help="Number of speakers")
|
115 |
+
parser.add_argument("--diarization_min_speakers", type=int, default=app_config.diarization_min_speakers, help="Minimum number of speakers")
|
116 |
+
parser.add_argument("--diarization_max_speakers", type=int, default=app_config.diarization_max_speakers, help="Maximum number of speakers")
|
117 |
|
118 |
args = parser.parse_args().__dict__
|
119 |
model_name: str = args.pop("model")
|