aadnk committed on
Commit
a1b1422
·
2 Parent(s): 6db3778 70d1342

Merge branch 'main' of https://huggingface.co/spaces/aadnk/whisper-webui

Browse files
app.py CHANGED
@@ -14,6 +14,8 @@ import numpy as np
14
  import torch
15
 
16
  from src.config import VAD_INITIAL_PROMPT_MODE_VALUES, ApplicationConfig, VadInitialPromptMode
 
 
17
  from src.hooks.progressListener import ProgressListener
18
  from src.hooks.subTaskProgressListener import SubTaskProgressListener
19
  from src.hooks.whisperProgressHook import create_progress_listener_handle
@@ -73,6 +75,10 @@ class WhisperTranscriber:
73
  self.deleteUploadedFiles = delete_uploaded_files
74
  self.output_dir = output_dir
75
 
 
 
 
 
76
  self.app_config = app_config
77
 
78
  def set_parallel_devices(self, vad_parallel_devices: str):
@@ -86,22 +92,41 @@ class WhisperTranscriber:
86
  self.vad_cpu_cores = min(os.cpu_count(), MAX_AUTO_CPU_CORES)
87
  print("[Auto parallel] Using GPU devices " + str(self.parallel_device_list) + " and " + str(self.vad_cpu_cores) + " CPU cores for VAD/transcription.")
88
 
 
 
 
 
 
 
 
 
 
 
 
89
  # Entry function for the simple tab
90
  def transcribe_webui_simple(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
91
  vad, vadMergeWindow, vadMaxMergeSize,
92
- word_timestamps: bool = False, highlight_words: bool = False):
 
93
  return self.transcribe_webui_simple_progress(modelName, languageName, urlData, multipleFiles, microphoneData, task,
94
  vad, vadMergeWindow, vadMaxMergeSize,
95
- word_timestamps, highlight_words)
 
96
 
97
  # Entry function for the simple tab progress
98
  def transcribe_webui_simple_progress(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
99
  vad, vadMergeWindow, vadMaxMergeSize,
100
  word_timestamps: bool = False, highlight_words: bool = False,
 
101
  progress=gr.Progress()):
102
 
103
  vadOptions = VadOptions(vad, vadMergeWindow, vadMaxMergeSize, self.app_config.vad_padding, self.app_config.vad_prompt_window, self.app_config.vad_initial_prompt_mode)
104
 
 
 
 
 
 
105
  return self.transcribe_webui(modelName, languageName, urlData, multipleFiles, microphoneData, task, vadOptions,
106
  word_timestamps=word_timestamps, highlight_words=highlight_words, progress=progress)
107
 
@@ -112,14 +137,18 @@ class WhisperTranscriber:
112
  word_timestamps: bool, highlight_words: bool, prepend_punctuations: str, append_punctuations: str,
113
  initial_prompt: str, temperature: float, best_of: int, beam_size: int, patience: float, length_penalty: float, suppress_tokens: str,
114
  condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
115
- compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float):
 
 
116
 
117
  return self.transcribe_webui_full_progress(modelName, languageName, urlData, multipleFiles, microphoneData, task,
118
  vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
119
  word_timestamps, highlight_words, prepend_punctuations, append_punctuations,
120
  initial_prompt, temperature, best_of, beam_size, patience, length_penalty, suppress_tokens,
121
  condition_on_previous_text, fp16, temperature_increment_on_fallback,
122
- compression_ratio_threshold, logprob_threshold, no_speech_threshold)
 
 
123
 
124
  # Entry function for the full tab with progress
125
  def transcribe_webui_full_progress(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
@@ -129,6 +158,8 @@ class WhisperTranscriber:
129
  initial_prompt: str, temperature: float, best_of: int, beam_size: int, patience: float, length_penalty: float, suppress_tokens: str,
130
  condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
131
  compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float,
 
 
132
  progress=gr.Progress()):
133
 
134
  # Handle temperature_increment_on_fallback
@@ -139,6 +170,13 @@ class WhisperTranscriber:
139
 
140
  vadOptions = VadOptions(vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode)
141
 
 
 
 
 
 
 
 
142
  return self.transcribe_webui(modelName, languageName, urlData, multipleFiles, microphoneData, task, vadOptions,
143
  initial_prompt=initial_prompt, temperature=temperature, best_of=best_of, beam_size=beam_size, patience=patience, length_penalty=length_penalty, suppress_tokens=suppress_tokens,
144
  condition_on_previous_text=condition_on_previous_text, fp16=fp16,
@@ -322,6 +360,19 @@ class WhisperTranscriber:
322
  else:
323
  # Default VAD
324
  result = whisperCallable.invoke(audio_path, 0, None, None, progress_listener=progressListener)
 
 
 
 
 
 
 
 
 
 
 
 
 
325
 
326
  return result
327
 
@@ -458,6 +509,10 @@ class WhisperTranscriber:
458
  if (self.cpu_parallel_context is not None):
459
  self.cpu_parallel_context.close()
460
 
 
 
 
 
461
 
462
  def create_ui(app_config: ApplicationConfig):
463
  ui = WhisperTranscriber(app_config.input_audio_max_duration, app_config.vad_process_timeout, app_config.vad_cpu_cores,
@@ -515,6 +570,17 @@ def create_ui(app_config: ApplicationConfig):
515
  gr.Checkbox(label="Word Timestamps - Highlight Words", value=app_config.highlight_words),
516
  ]
517
 
 
 
 
 
 
 
 
 
 
 
 
518
  is_queue_mode = app_config.queue_concurrency_count is not None and app_config.queue_concurrency_count > 0
519
 
520
  simple_transcribe = gr.Interface(fn=ui.transcribe_webui_simple_progress if is_queue_mode else ui.transcribe_webui_simple,
@@ -522,6 +588,7 @@ def create_ui(app_config: ApplicationConfig):
522
  *common_inputs(),
523
  *common_vad_inputs(),
524
  *common_word_timestamps_inputs(),
 
525
  ], outputs=[
526
  gr.File(label="Download"),
527
  gr.Text(label="Transcription"),
@@ -556,6 +623,11 @@ def create_ui(app_config: ApplicationConfig):
556
  gr.Number(label="Compression ratio threshold", value=app_config.compression_ratio_threshold),
557
  gr.Number(label="Logprob threshold", value=app_config.logprob_threshold),
558
  gr.Number(label="No speech threshold", value=app_config.no_speech_threshold),
 
 
 
 
 
559
  ], outputs=[
560
  gr.File(label="Download"),
561
  gr.Text(label="Transcription"),
 
14
  import torch
15
 
16
  from src.config import VAD_INITIAL_PROMPT_MODE_VALUES, ApplicationConfig, VadInitialPromptMode
17
+ from src.diarization.diarization import Diarization
18
+ from src.diarization.diarizationContainer import DiarizationContainer
19
  from src.hooks.progressListener import ProgressListener
20
  from src.hooks.subTaskProgressListener import SubTaskProgressListener
21
  from src.hooks.whisperProgressHook import create_progress_listener_handle
 
75
  self.deleteUploadedFiles = delete_uploaded_files
76
  self.output_dir = output_dir
77
 
78
+ # Support for diarization
79
+ self.diarization: DiarizationContainer = None
80
+ # Dictionary with parameters to pass to diarization.run - if None, diarization is not enabled
81
+ self.diarization_kwargs = None
82
  self.app_config = app_config
83
 
84
  def set_parallel_devices(self, vad_parallel_devices: str):
 
92
  self.vad_cpu_cores = min(os.cpu_count(), MAX_AUTO_CPU_CORES)
93
  print("[Auto parallel] Using GPU devices " + str(self.parallel_device_list) + " and " + str(self.vad_cpu_cores) + " CPU cores for VAD/transcription.")
94
 
95
def set_diarization(self, auth_token: str, enable_daemon_process: bool = True, **kwargs):
    """Enable diarization for subsequent transcriptions.

    A DiarizationContainer is created only if none exists yet; when one is already
    present it is reused as-is, so `auth_token` and `enable_daemon_process` only take
    effect on the call that creates it.

    Args:
        auth_token: Token forwarded to the DiarizationContainer.
        enable_daemon_process: Forwarded to the DiarizationContainer.
        **kwargs: Parameters later passed to diarization.run (e.g. num_speakers).
    """
    if self.diarization is None:
        container_options = dict(
            auth_token=auth_token,
            enable_daemon_process=enable_daemon_process,
            auto_cleanup_timeout_seconds=self.vad_process_timeout,
            cache=self.model_cache,
        )
        self.diarization = DiarizationContainer(**container_options)

    # A non-None kwargs dictionary is what marks diarization as enabled
    self.diarization_kwargs = kwargs
101
+
102
def unset_diarization(self):
    """Disable diarization for subsequent transcriptions.

    Cleans up the diarization container if one has been created and clears the
    stored run parameters. Safe to call when diarization was never enabled: the
    transcribe entry points call this on every request that has diarization
    switched off, at which point `self.diarization` may still be None.
    """
    if self.diarization is not None:
        self.diarization.cleanup()
    self.diarization_kwargs = None
105
+
106
  # Entry function for the simple tab
107
  def transcribe_webui_simple(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
108
  vad, vadMergeWindow, vadMaxMergeSize,
109
+ word_timestamps: bool = False, highlight_words: bool = False,
110
+ diarization: bool = False, diarization_speakers: int = 2):
111
  return self.transcribe_webui_simple_progress(modelName, languageName, urlData, multipleFiles, microphoneData, task,
112
  vad, vadMergeWindow, vadMaxMergeSize,
113
+ word_timestamps, highlight_words,
114
+ diarization, diarization_speakers)
115
 
116
  # Entry function for the simple tab progress
117
  def transcribe_webui_simple_progress(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
118
  vad, vadMergeWindow, vadMaxMergeSize,
119
  word_timestamps: bool = False, highlight_words: bool = False,
120
+ diarization: bool = False, diarization_speakers: int = 2,
121
  progress=gr.Progress()):
122
 
123
  vadOptions = VadOptions(vad, vadMergeWindow, vadMaxMergeSize, self.app_config.vad_padding, self.app_config.vad_prompt_window, self.app_config.vad_initial_prompt_mode)
124
 
125
+ if diarization:
126
+ self.set_diarization(auth_token=self.app_config.auth_token, num_speakers=diarization_speakers)
127
+ else:
128
+ self.unset_diarization()
129
+
130
  return self.transcribe_webui(modelName, languageName, urlData, multipleFiles, microphoneData, task, vadOptions,
131
  word_timestamps=word_timestamps, highlight_words=highlight_words, progress=progress)
132
 
 
137
  word_timestamps: bool, highlight_words: bool, prepend_punctuations: str, append_punctuations: str,
138
  initial_prompt: str, temperature: float, best_of: int, beam_size: int, patience: float, length_penalty: float, suppress_tokens: str,
139
  condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
140
+ compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float,
141
+ diarization: bool = False, diarization_speakers: int = 2,
142
+ diarization_min_speakers = 1, diarization_max_speakers = 5):
143
 
144
  return self.transcribe_webui_full_progress(modelName, languageName, urlData, multipleFiles, microphoneData, task,
145
  vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
146
  word_timestamps, highlight_words, prepend_punctuations, append_punctuations,
147
  initial_prompt, temperature, best_of, beam_size, patience, length_penalty, suppress_tokens,
148
  condition_on_previous_text, fp16, temperature_increment_on_fallback,
149
+ compression_ratio_threshold, logprob_threshold, no_speech_threshold,
150
+ diarization, diarization_speakers,
151
+ diarization_min_speakers, diarization_max_speakers)
152
 
153
  # Entry function for the full tab with progress
154
  def transcribe_webui_full_progress(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
 
158
  initial_prompt: str, temperature: float, best_of: int, beam_size: int, patience: float, length_penalty: float, suppress_tokens: str,
159
  condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
160
  compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float,
161
+ diarization: bool = False, diarization_speakers: int = 2,
162
+ diarization_min_speakers = 1, diarization_max_speakers = 5,
163
  progress=gr.Progress()):
164
 
165
  # Handle temperature_increment_on_fallback
 
170
 
171
  vadOptions = VadOptions(vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode)
172
 
173
+ # Set diarization
174
+ if diarization:
175
+ self.set_diarization(auth_token=self.app_config.auth_token, num_speakers=diarization_speakers,
176
+ min_speakers=diarization_min_speakers, max_speakers=diarization_max_speakers)
177
+ else:
178
+ self.unset_diarization()
179
+
180
  return self.transcribe_webui(modelName, languageName, urlData, multipleFiles, microphoneData, task, vadOptions,
181
  initial_prompt=initial_prompt, temperature=temperature, best_of=best_of, beam_size=beam_size, patience=patience, length_penalty=length_penalty, suppress_tokens=suppress_tokens,
182
  condition_on_previous_text=condition_on_previous_text, fp16=fp16,
 
360
  else:
361
  # Default VAD
362
  result = whisperCallable.invoke(audio_path, 0, None, None, progress_listener=progressListener)
363
+
364
+ # Diarization
365
+ if self.diarization and self.diarization_kwargs:
366
+ print("Diarizing ", audio_path)
367
+ diarization_result = list(self.diarization.run(audio_path, **self.diarization_kwargs))
368
+
369
+ # Print result
370
+ print("Diarization result: ")
371
+ for entry in diarization_result:
372
+ print(f" start={entry.start:.1f}s stop={entry.end:.1f}s speaker_{entry.speaker}")
373
+
374
+ # Add speakers to result
375
+ result = self.diarization.mark_speakers(diarization_result, result)
376
 
377
  return result
378
 
 
509
  if (self.cpu_parallel_context is not None):
510
  self.cpu_parallel_context.close()
511
 
512
+ # Cleanup diarization
513
+ if (self.diarization is not None):
514
+ self.diarization.cleanup()
515
+ self.diarization = None
516
 
517
  def create_ui(app_config: ApplicationConfig):
518
  ui = WhisperTranscriber(app_config.input_audio_max_duration, app_config.vad_process_timeout, app_config.vad_cpu_cores,
 
570
  gr.Checkbox(label="Word Timestamps - Highlight Words", value=app_config.highlight_words),
571
  ]
572
 
573
+ has_diarization_libs = Diarization.has_libraries()
574
+
575
+ if not has_diarization_libs:
576
+ print("Diarization libraries not found - disabling diarization")
577
+ app_config.diarization = False
578
+
579
+ common_diarization_inputs = lambda : [
580
+ gr.Checkbox(label="Diarization", value=app_config.diarization, interactive=has_diarization_libs),
581
+ gr.Number(label="Diarization - Speakers", precision=0, value=app_config.diarization_speakers, interactive=has_diarization_libs)
582
+ ]
583
+
584
  is_queue_mode = app_config.queue_concurrency_count is not None and app_config.queue_concurrency_count > 0
585
 
586
  simple_transcribe = gr.Interface(fn=ui.transcribe_webui_simple_progress if is_queue_mode else ui.transcribe_webui_simple,
 
588
  *common_inputs(),
589
  *common_vad_inputs(),
590
  *common_word_timestamps_inputs(),
591
+ *common_diarization_inputs(),
592
  ], outputs=[
593
  gr.File(label="Download"),
594
  gr.Text(label="Transcription"),
 
623
  gr.Number(label="Compression ratio threshold", value=app_config.compression_ratio_threshold),
624
  gr.Number(label="Logprob threshold", value=app_config.logprob_threshold),
625
  gr.Number(label="No speech threshold", value=app_config.no_speech_threshold),
626
+
627
+ *common_diarization_inputs(),
628
+ gr.Number(label="Diarization - Min Speakers", precision=0, value=app_config.diarization_min_speakers, interactive=has_diarization_libs),
629
+ gr.Number(label="Diarization - Max Speakers", precision=0, value=app_config.diarization_max_speakers, interactive=has_diarization_libs),
630
+
631
  ], outputs=[
632
  gr.File(label="Download"),
633
  gr.Text(label="Transcription"),
cli.py CHANGED
@@ -8,6 +8,7 @@ import numpy as np
8
  import torch
9
  from app import VadOptions, WhisperTranscriber
10
  from src.config import VAD_INITIAL_PROMPT_MODE_VALUES, ApplicationConfig, VadInitialPromptMode
 
11
  from src.download import download_url
12
  from src.languages import get_language_names
13
 
@@ -106,6 +107,14 @@ def cli():
106
  parser.add_argument("--threads", type=optional_int, default=0,
107
  help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
108
 
 
 
 
 
 
 
 
 
109
  args = parser.parse_args().__dict__
110
  model_name: str = args.pop("model")
111
  model_dir: str = args.pop("model_dir")
@@ -142,10 +151,19 @@ def cli():
142
  compute_type = args.pop("compute_type")
143
  highlight_words = args.pop("highlight_words")
144
 
 
 
 
 
 
 
145
  transcriber = WhisperTranscriber(delete_uploaded_files=False, vad_cpu_cores=vad_cpu_cores, app_config=app_config)
146
  transcriber.set_parallel_devices(args.pop("vad_parallel_devices"))
147
  transcriber.set_auto_parallel(auto_parallel)
148
 
 
 
 
149
  model = create_whisper_container(whisper_implementation=whisper_implementation, model_name=model_name,
150
  device=device, compute_type=compute_type, download_root=model_dir, models=app_config.models)
151
 
 
8
  import torch
9
  from app import VadOptions, WhisperTranscriber
10
  from src.config import VAD_INITIAL_PROMPT_MODE_VALUES, ApplicationConfig, VadInitialPromptMode
11
+ from src.diarization.diarization import Diarization
12
  from src.download import download_url
13
  from src.languages import get_language_names
14
 
 
107
  parser.add_argument("--threads", type=optional_int, default=0,
108
  help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
109
 
110
+ # Diarization
111
+ parser.add_argument('--auth_token', type=str, default=None, help='HuggingFace API Token (optional)')
112
+ parser.add_argument("--diarization", type=str2bool, default=app_config.diarization, \
113
+ help="whether to perform speaker diarization")
114
+ parser.add_argument("--diarization_num_speakers", type=int, default=None, help="Number of speakers")
115
+ parser.add_argument("--diarization_min_speakers", type=int, default=None, help="Minimum number of speakers")
116
+ parser.add_argument("--diarization_max_speakers", type=int, default=None, help="Maximum number of speakers")
117
+
118
  args = parser.parse_args().__dict__
119
  model_name: str = args.pop("model")
120
  model_dir: str = args.pop("model_dir")
 
151
  compute_type = args.pop("compute_type")
152
  highlight_words = args.pop("highlight_words")
153
 
154
+ auth_token = args.pop("auth_token")
155
+ diarization = args.pop("diarization")
156
+ num_speakers = args.pop("diarization_num_speakers")
157
+ min_speakers = args.pop("diarization_min_speakers")
158
+ max_speakers = args.pop("diarization_max_speakers")
159
+
160
  transcriber = WhisperTranscriber(delete_uploaded_files=False, vad_cpu_cores=vad_cpu_cores, app_config=app_config)
161
  transcriber.set_parallel_devices(args.pop("vad_parallel_devices"))
162
  transcriber.set_auto_parallel(auto_parallel)
163
 
164
+ if diarization:
165
+ transcriber.set_diarization(auth_token=auth_token, enable_daemon_process=False, num_speakers=num_speakers, min_speakers=min_speakers, max_speakers=max_speakers)
166
+
167
  model = create_whisper_container(whisper_implementation=whisper_implementation, model_name=model_name,
168
  device=device, compute_type=compute_type, download_root=model_dir, models=app_config.models)
169
 
config.json5 CHANGED
@@ -140,4 +140,15 @@
140
  "append_punctuations": "\"\'.。,,!!??::”)]}、",
141
  // (requires --word_timestamps True) underline each word as it is spoken in srt and vtt
142
  "highlight_words": false,
 
 
 
 
 
 
 
 
 
 
 
143
  }
 
140
  "append_punctuations": "\"\'.。,,!!??::”)]}、",
141
  // (requires --word_timestamps True) underline each word as it is spoken in srt and vtt
142
  "highlight_words": false,
143
+
144
+ // Diarization settings
145
+ "auth_token": null,
146
+ // Whether to perform speaker diarization
147
+ "diarization": false,
148
+ // The number of speakers to detect
149
+ "diarization_speakers": 2,
150
+ // The minimum number of speakers to detect
151
+ "diarization_min_speakers": 1,
152
+ // The maximum number of speakers to detect
153
+ "diarization_max_speakers": 5,
154
  }
docs/options.md CHANGED
@@ -80,6 +80,17 @@ number of seconds after the line has finished. For instance, if a line ends at 1
80
  Note that detected lines in gaps between speech sections will not be included in the prompt
81
  (if silero-vad or silero-vad-expand-into-gaps) is used.
82
 
 
 
 
 
 
 
 
 
 
 
 
83
  # Command Line Options
84
 
85
  Both `app.py` and `cli.py` also accept command line options, such as the ability to enable parallel execution on multiple
@@ -132,3 +143,11 @@ If the average log probability is lower than this value, treat the decoding as f
132
 
133
  ## No speech threshold
134
  If the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence. Default is 0.6.
 
 
 
 
 
 
 
 
 
80
  Note that detected lines in gaps between speech sections will not be included in the prompt
81
  (if silero-vad or silero-vad-expand-into-gaps) is used.
82
 
83
+ ## Diarization
84
+
85
+ If checked, Pyannote will be used to detect speakers in the audio, and label them as (SPEAKER 00), (SPEAKER 01), etc.
86
+
87
+ This requires a HuggingFace API key to function, which can be supplied with the `--auth_token` command line option for the CLI,
88
+ set in the `config.json5` file for the GUI, or provided via the `HK_ACCESS_TOKEN` environment variable.
89
+
90
+ ## Diarization - Speakers
91
+
92
+ The number of speakers to detect. If set to 0, Pyannote will attempt to detect the number of speakers automatically.
93
+
94
  # Command Line Options
95
 
96
  Both `app.py` and `cli.py` also accept command line options, such as the ability to enable parallel execution on multiple
 
143
 
144
  ## No speech threshold
145
  If the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence. Default is 0.6.
146
+
147
+ ## Diarization - Min Speakers
148
+
149
+ The minimum number of speakers for Pyannote to detect.
150
+
151
+ ## Diarization - Max Speakers
152
+
153
+ The maximum number of speakers for Pyannote to detect.
requirements-fasterWhisper.txt CHANGED
@@ -6,4 +6,10 @@ yt-dlp
6
  json5
7
  torch
8
  torchaudio
9
- more_itertools
 
 
 
 
 
 
 
6
  json5
7
  torch
8
  torchaudio
9
+ more_itertools
10
+
11
+ # Needed by diarization
12
+ intervaltree
13
+ srt
14
+ torch
15
+ https://github.com/pyannote/pyannote-audio/archive/refs/heads/develop.zip
requirements-whisper.txt CHANGED
@@ -6,4 +6,10 @@ gradio==3.38.0
6
  yt-dlp
7
  torchaudio
8
  altair
9
- json5
 
 
 
 
 
 
 
6
  yt-dlp
7
  torchaudio
8
  altair
9
+ json5
10
+
11
+ # Needed by diarization
12
+ intervaltree
13
+ srt
14
+ torch
15
+ https://github.com/pyannote/pyannote-audio/archive/refs/heads/develop.zip
requirements.txt CHANGED
@@ -6,4 +6,10 @@ yt-dlp
6
  json5
7
  torch
8
  torchaudio
9
- more_itertools
 
 
 
 
 
 
 
6
  json5
7
  torch
8
  torchaudio
9
+ more_itertools
10
+
11
+ # Needed by diarization
12
+ intervaltree
13
+ srt
14
+ torch
15
+ https://github.com/pyannote/pyannote-audio/archive/refs/heads/develop.zip
src/config.py CHANGED
@@ -69,7 +69,10 @@ class ApplicationConfig:
69
  # Word timestamp settings
70
  word_timestamps: bool = False, prepend_punctuations: str = "\"\'“¿([{-",
71
  append_punctuations: str = "\"\'.。,,!!??::”)]}、",
72
- highlight_words: bool = False):
 
 
 
73
 
74
  self.models = models
75
 
@@ -121,6 +124,13 @@ class ApplicationConfig:
121
  self.append_punctuations = append_punctuations
122
  self.highlight_words = highlight_words
123
 
 
 
 
 
 
 
 
124
  def get_model_names(self):
125
  return [ x.name for x in self.models ]
126
 
 
69
  # Word timestamp settings
70
  word_timestamps: bool = False, prepend_punctuations: str = "\"\'“¿([{-",
71
  append_punctuations: str = "\"\'.。,,!!??::”)]}、",
72
+ highlight_words: bool = False,
73
+ # Diarization
74
+ auth_token: str = None, diarization: bool = False, diarization_speakers: int = 2,
75
+ diarization_min_speakers: int = 1, diarization_max_speakers: int = 5):
76
 
77
  self.models = models
78
 
 
124
  self.append_punctuations = append_punctuations
125
  self.highlight_words = highlight_words
126
 
127
+ # Diarization settings
128
+ self.auth_token = auth_token
129
+ self.diarization = diarization
130
+ self.diarization_speakers = diarization_speakers
131
+ self.diarization_min_speakers = diarization_min_speakers
132
+ self.diarization_max_speakers = diarization_max_speakers
133
+
134
  def get_model_names(self):
135
  return [ x.name for x in self.models ]
136
 
src/diarization/diarization.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import gc
3
+ import json
4
+ import os
5
+ from pathlib import Path
6
+ import tempfile
7
+ from typing import TYPE_CHECKING, List
8
+ import torch
9
+
10
+ import ffmpeg
11
+
12
class DiarizationEntry:
    """One speaker turn produced by diarization: a start, an end and a speaker label."""

    def __init__(self, start, end, speaker):
        self.start, self.end, self.speaker = start, end, speaker

    def __repr__(self):
        return f"<DiarizationEntry start={self.start} end={self.end} speaker={self.speaker}>"

    def toJson(self):
        """Return the entry as a plain dictionary with the keys start, end and speaker."""
        return dict(start=self.start, end=self.end, speaker=self.speaker)
27
+
28
class Diarization:
    """Speaker diarization using the "pyannote/speaker-diarization@2.1" pipeline.

    The pipeline is loaded lazily by initialize() (called from run()), so creating an
    instance or calling mark_speakers() does not import pyannote.audio.
    """

    def __init__(self, auth_token=None):
        """Store the HuggingFace token used to load the pipeline.

        Args:
            auth_token: HuggingFace API token. When None, the HK_ACCESS_TOKEN
                environment variable is used instead.

        Raises:
            ValueError: If no token is given and HK_ACCESS_TOKEN is not set.
        """
        if auth_token is None:
            auth_token = os.environ.get("HK_ACCESS_TOKEN")
        if auth_token is None:
            raise ValueError("No HuggingFace API Token provided - please use the --auth_token argument or set the HK_ACCESS_TOKEN environment variable")

        self.auth_token = auth_token
        self.initialized = False
        self.pipeline = None

    @staticmethod
    def has_libraries():
        """Return True if both pyannote.audio and intervaltree can be imported."""
        try:
            import pyannote.audio
            import intervaltree
            return True
        except ImportError:
            return False

    def initialize(self):
        """Load the pyannote pipeline once, moving it to the GPU when CUDA is available.

        Raises:
            RuntimeError: If the pipeline could not be loaded.
        """
        if self.initialized:
            return
        from pyannote.audio import Pipeline

        pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1", use_auth_token=self.auth_token)

        # NOTE(review): pyannote appears to return None (rather than raise) when the model
        # cannot be downloaded, e.g. gated model or invalid token - fail here with a clear
        # message instead of a confusing error on first use.
        if pipeline is None:
            raise RuntimeError("Could not load the pyannote/speaker-diarization@2.1 pipeline - please check the HuggingFace API token and that the model's user conditions have been accepted")

        # Load GPU mode if available
        if torch.cuda.is_available():
            print("Diarization - using GPU")
            pipeline = pipeline.to(torch.device(0))
        else:
            print("Diarization - using CPU")

        # Only mark as initialized once the pipeline is actually usable
        self.pipeline = pipeline
        self.initialized = True

    def run(self, audio_file, **kwargs):
        """Diarize an audio file, yielding one DiarizationEntry per speaker turn.

        This is a generator: nothing (including pipeline loading) happens until it is
        iterated. Files that are not WAV, FLAC, OGG or MAT are first converted with
        ffmpeg to a single-channel WAV in a temporary directory, which is removed
        again once the pipeline has run - also if the pipeline fails.

        Args:
            audio_file: Path of the audio file to diarize.
            **kwargs: Passed unchanged to the pyannote pipeline call
                (e.g. num_speakers, min_speakers, max_speakers).

        Raises:
            ffmpeg.Error: If the audio conversion fails.
        """
        self.initialize()

        # Supported file types in soundfile is WAV, FLAC, OGG and MAT
        if Path(audio_file).suffix.lower() in (".wav", ".flac", ".ogg", ".mat"):
            diarization = self.pipeline(audio_file, **kwargs)
        else:
            # A private temporary directory avoids the race inherent in tempfile.mktemp
            with tempfile.TemporaryDirectory(prefix="diarization_") as temp_dir:
                target_file = os.path.join(temp_dir, "audio.wav")
                try:
                    ffmpeg.input(audio_file).output(target_file, ac=1).run()
                except ffmpeg.Error as e:
                    print(f"Error occurred during audio conversion: {e.stderr}")
                    # Without a converted file there is nothing for the pipeline to read
                    raise

                diarization = self.pipeline(target_file, **kwargs)

        # Yield result
        for turn, _, speaker in diarization.itertracks(yield_label=True):
            yield DiarizationEntry(turn.start, turn.end, speaker)

    def mark_speakers(self, diarization_result: "List[DiarizationEntry]", whisper_result: dict):
        """Return a copy of a Whisper result with speaker information added to its segments.

        Every segment that overlaps at least one diarization entry gets two extra keys:
        "longest_speaker" (the speaker whose single entry overlaps the segment the
        longest) and "speakers" (the JSON form of all overlapping entries, ordered by
        start time). Segments without any overlap are left as they are. The input
        `whisper_result` and its segment dictionaries are not modified.

        Args:
            diarization_result: Entries as produced by run().
            whisper_result: Dictionary with a "segments" list, where each segment has
                "start" and "end" values.
        """
        from intervaltree import IntervalTree

        # dict.copy() is shallow, so also copy each segment before adding keys to it
        result = whisper_result.copy()
        result["segments"] = [dict(segment) for segment in whisper_result["segments"]]

        # Create an interval tree from the diarization results
        tree = IntervalTree()
        for entry in diarization_result:
            # IntervalTree refuses null intervals (begin >= end); such a turn cannot
            # overlap any segment anyway, so leave it out instead of failing
            if entry.end > entry.start:
                tree[entry.start:entry.end] = entry

        # Iterate through each segment in the Whisper JSON
        for segment in result["segments"]:
            segment_start = segment["start"]
            segment_end = segment["end"]

            # Find overlapping speakers using the interval tree. The query returns a set,
            # so sort it to make the output (and tie-breaking) deterministic.
            overlapping_speakers = sorted(tree[segment_start:segment_end],
                                          key=lambda interval: (interval.begin, interval.end))

            # If no speakers overlap with this segment, skip it
            if not overlapping_speakers:
                continue

            # If multiple speakers overlap with this segment, choose the one with the longest duration
            longest_speaker = None
            longest_duration = 0

            for speaker_interval in overlapping_speakers:
                overlap_start = max(speaker_interval.begin, segment_start)
                overlap_end = min(speaker_interval.end, segment_end)
                overlap_duration = overlap_end - overlap_start

                if overlap_duration > longest_duration:
                    longest_speaker = speaker_interval.data.speaker
                    longest_duration = overlap_duration

            # Add speakers
            segment["longest_speaker"] = longest_speaker
            segment["speakers"] = [speaker_interval.data.toJson() for speaker_interval in overlapping_speakers]

        # The write_srt will use the longest_speaker if it exist, and add it to the text field

        return result
130
+
131
+ def _write_file(input_file: str, output_path: str, output_extension: str, file_writer: lambda f: None):
132
+ if input_file is None:
133
+ raise ValueError("input_file is required")
134
+ if file_writer is None:
135
+ raise ValueError("file_writer is required")
136
+
137
+ # Write file
138
+ if output_path is None:
139
+ effective_path = os.path.splitext(input_file)[0] + "_output" + output_extension
140
+ else:
141
+ effective_path = output_path
142
+
143
+ with open(effective_path, 'w+', encoding="utf-8") as f:
144
+ file_writer(f)
145
+
146
+ print(f"Output saved to {effective_path}")
147
+
148
def main():
    """Command line entry point: add speakers to an existing Whisper JSON/SRT transcript.

    Loads the transcript, runs diarization on the audio file, marks the speakers on
    the transcript segments and writes both a JSON and an SRT file.
    """
    from src.utils import write_srt
    from src.diarization.transcriptLoader import load_transcript

    parser = argparse.ArgumentParser(description='Add speakers to a SRT file or Whisper JSON file using pyannote/speaker-diarization.')
    parser.add_argument('audio_file', type=str, help='Input audio file')
    parser.add_argument('whisper_file', type=str, help='Input Whisper JSON/SRT file')
    parser.add_argument('--output_json_file', type=str, default=None, help='Output JSON file (optional)')
    parser.add_argument('--output_srt_file', type=str, default=None, help='Output SRT file (optional)')
    parser.add_argument('--auth_token', type=str, default=None, help='HuggingFace API Token (optional)')
    parser.add_argument("--max_line_width", type=int, default=40, help="Maximum line width for SRT file (default: 40)")

    # The three speaker-count options only differ in name and help text
    speaker_options = (
        ("--num_speakers", "Number of speakers"),
        ("--min_speakers", "Minimum number of speakers"),
        ("--max_speakers", "Maximum number of speakers"),
    )
    for option_name, option_help in speaker_options:
        parser.add_argument(option_name, type=int, default=None, help=option_help)

    args = parser.parse_args()

    print("\nReading whisper JSON from " + args.whisper_file)

    # Read whisper JSON or SRT file
    whisper_result = load_transcript(args.whisper_file)

    diarization = Diarization(auth_token=args.auth_token)
    speaker_turns = diarization.run(args.audio_file, num_speakers=args.num_speakers,
                                    min_speakers=args.min_speakers, max_speakers=args.max_speakers)
    diarization_result = list(speaker_turns)

    # Print result
    print("Diarization result:")
    for entry in diarization_result:
        print(f" start={entry.start:.1f}s stop={entry.end:.1f}s speaker_{entry.speaker}")

    marked_whisper_result = diarization.mark_speakers(diarization_result, whisper_result)

    def dump_json(f):
        json.dump(marked_whisper_result, f, indent=4, ensure_ascii=False)

    def dump_srt(f):
        write_srt(marked_whisper_result["segments"], f, maxLineWidth=args.max_line_width)

    # Write output JSON to file
    _write_file(args.whisper_file, args.output_json_file, ".json", dump_json)

    # Write SRT
    _write_file(args.whisper_file, args.output_srt_file, ".srt", dump_srt)
187
+
188
+ if __name__ == "__main__":
189
+ main()
190
+
191
+ #test = Diarization()
192
+ #print("Initializing")
193
+ #test.initialize()
194
+
195
+ #input("Press Enter to continue...")
src/diarization/diarizationContainer.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ from src.diarization.diarization import Diarization, DiarizationEntry
3
+ from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
4
+ from src.vadParallel import ParallelContext
5
+
6
class DiarizationContainer:
    """
    Wrapper around a lazily created Diarization model.

    Depending on enable_daemon_process, the model is either executed in the
    current process or through a pool obtained from a single-process
    ParallelContext.
    """
    def __init__(self, auth_token: str = None, enable_daemon_process: bool = True, auto_cleanup_timeout_seconds=60, cache: ModelCache = None):
        self.auth_token = auth_token
        self.enable_daemon_process = enable_daemon_process
        # NOTE(review): this value is stored and pickled, but nothing in this class
        # hands it to ParallelContext - confirm whether it should be forwarded.
        self.auto_cleanup_timeout_seconds = auto_cleanup_timeout_seconds
        self.diarization_context: ParallelContext = None
        self.cache = cache
        self.model = None

    def run(self, audio_file, **kwargs):
        """
        Diarize audio_file and return the result of execute().

        Any keyword arguments are passed through to execute() unchanged.
        """
        if self.diarization_context is None:
            if not self.enable_daemon_process:
                # No context and no daemon process requested: run in this process
                return self.execute(audio_file, **kwargs)

            # A single process is enough, as the context is mainly used to clean up GPU memory
            self.diarization_context = ParallelContext(num_processes=1)

        # Run in a separate process, always handing the pool back afterwards
        pool = self.diarization_context.get_pool()

        try:
            return pool.apply(self.execute, (audio_file,), kwargs)
        finally:
            self.diarization_context.return_pool(pool)

    def mark_speakers(self, diarization_result: List[DiarizationEntry], whisper_result: dict):
        """
        Delegate to Diarization.mark_speakers, reusing the loaded model when there is one.
        """
        model = self.model

        if model is None:
            # Use a throw-away instance (calling mark_speakers will not initialize pyannote.audio)
            model = Diarization(self.auth_token)

        return model.mark_speakers(diarization_result, whisper_result)

    def get_model(self):
        """
        Return the Diarization model, creating it (through the cache if one is set) on first use.
        """
        if self.model is not None:
            return self.model

        if self.cache:
            print("Loading diarization model from cache")
            self.model = self.cache.get("diarization", lambda : Diarization(self.auth_token))
        else:
            print("Loading diarization model")
            self.model = Diarization(self.auth_token)

        return self.model

    def execute(self, audio_file, **kwargs):
        """
        Run the model on audio_file and return the entries as a list.
        """
        # The iterator is drained into a list here, as generators are not picklable
        return list(self.get_model().run(audio_file, **kwargs))

    def cleanup(self):
        """
        Close the parallel context, if one has been created.
        """
        context = self.diarization_context

        if context is None:
            return

        context.close()

    def __getstate__(self):
        # Only the plain configuration values are pickled; the context, cache and model are left out
        return {
            name: getattr(self, name)
            for name in ("auth_token", "enable_daemon_process", "auto_cleanup_timeout_seconds")
        }

    def __setstate__(self, state):
        for name in ("auth_token", "enable_daemon_process", "auto_cleanup_timeout_seconds"):
            setattr(self, name, state[name])

        # Everything that was not pickled starts fresh, with the global cache as model cache
        self.diarization_context = None
        self.cache = GLOBAL_MODEL_CACHE
        self.model = None
src/diarization/requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ intervaltree
2
+ srt
3
+ torch
4
+ ffmpeg-python==0.2.0
5
+ https://github.com/pyannote/pyannote-audio/archive/refs/heads/develop.zip
src/diarization/transcriptLoader.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import json
3
+ from pathlib import Path
4
+
5
def load_transcript_json(transcript_file: str):
    """
    Read a Whisper JSON file and return the decoded JSON object

    The file is decoded as UTF-8 and returned as-is, without validation.
    A Whisper result typically looks like this:

        {
            "text": " And so my fellow Americans, ...",
            "segments": [
                {
                    "text": " And so my fellow Americans, ...",
                    "start": 0.0,
                    "end": 10.36,
                    "words": [
                        { "start": 0.0, "end": 0.56, "word": " And", "probability": 0.61767578125 },
                        { "start": 0.56, "end": 0.88, "word": " so", "probability": 0.9033203125 },
                        etc.

    # Parameters:
    transcript_file (str): Path to the Whisper JSON file
    """
    with open(transcript_file, "r", encoding="utf-8") as handle:
        return json.loads(handle.read())
39
+
40
+
41
def load_transcript_srt(subtitle_file: str):
    """
    Parse a SRT file into a Whisper JSON object

    Each subtitle becomes one segment with its content as "text", its start
    and end converted to seconds, and an empty "words" list. The top-level
    "text" is the content of all subtitles concatenated without a separator.

    # Parameters:
    subtitle_file (str): Path to the SRT file

    # Returns:
    dict: A dictionary with the keys "text" and "segments"
    """
    # Imported locally, so that srt is only required when an SRT file is actually loaded
    import srt

    # Read the whole file up front and give the parser a plain string, so that
    # parsing never depends on the file handle still being open
    with open(subtitle_file, "r", encoding="utf-8") as f:
        srt_text = f.read()

    # Example of a parsed entry:
    # Subtitle(index=1, start=datetime.timedelta(seconds=33, microseconds=843000), end=datetime.timedelta(seconds=38, microseconds=97000), content='...', proprietary='')
    segments = [
        {
            "text": sub.content,
            "start": sub.start.total_seconds(),
            "end": sub.end.total_seconds(),
            "words": []
        }
        for sub in srt.parse(srt_text)
    ]

    return {
        "text": "".join(segment["text"] for segment in segments),
        "segments": segments
    }
70
+
71
def load_transcript(file: str):
    """
    Load a transcript, picking the loader from the (case-insensitive) file extension

    # Parameters:
    file (str): Path to a ".json" or ".srt" transcript file

    # Raises:
    ValueError: If the extension is neither ".json" nor ".srt"
    """
    file_extension = Path(file).suffix.lower()

    # Reject unknown file types before touching any loader
    if file_extension not in (".json", ".srt"):
        raise ValueError(f"Unsupported file type: {file_extension}")

    loader = load_transcript_json if file_extension == ".json" else load_transcript_srt
    return loader(file)
src/utils.py CHANGED
@@ -102,17 +102,26 @@ def write_srt(transcript: Iterator[dict], file: TextIO,
102
 
103
  def __subtitle_preprocessor_iterator(transcript: Iterator[dict], maxLineWidth: int = None, highlight_words: bool = False):
104
  for segment in transcript:
105
- words = segment.get('words', [])
 
 
 
106
 
107
  if len(words) == 0:
108
  # Yield the segment as-is or processed
109
- if maxLineWidth is None or maxLineWidth < 0:
110
  yield segment
111
  else:
 
 
 
 
 
 
112
  yield {
113
  'start': segment['start'],
114
  'end': segment['end'],
115
- 'text': process_text(segment['text'].strip(), maxLineWidth)
116
  }
117
  # We are done
118
  continue
@@ -120,9 +129,17 @@ def __subtitle_preprocessor_iterator(transcript: Iterator[dict], maxLineWidth: i
120
  subtitle_start = segment['start']
121
  subtitle_end = segment['end']
122
 
 
 
 
 
 
 
 
 
123
  text_words = [ this_word["word"] for this_word in words ]
124
  subtitle_text = __join_words(text_words, maxLineWidth)
125
-
126
  # Iterate over the words in the segment
127
  if highlight_words:
128
  last = subtitle_start
 
102
 
103
  def __subtitle_preprocessor_iterator(transcript: Iterator[dict], maxLineWidth: int = None, highlight_words: bool = False):
104
  for segment in transcript:
105
+ words: list = segment.get('words', [])
106
+
107
+ # Append longest speaker ID if available
108
+ segment_longest_speaker = segment.get('longest_speaker', None)
109
 
110
  if len(words) == 0:
111
  # Yield the segment as-is or processed
112
+ if (maxLineWidth is None or maxLineWidth < 0) and segment_longest_speaker is None:
113
  yield segment
114
  else:
115
+ text = segment['text'].strip()
116
+
117
+ # Prepend the longest speaker ID if available
118
+ if segment_longest_speaker is not None:
119
+ text = f"({segment_longest_speaker}) {text}"
120
+
121
  yield {
122
  'start': segment['start'],
123
  'end': segment['end'],
124
+ 'text': process_text(text, maxLineWidth)
125
  }
126
  # We are done
127
  continue
 
129
  subtitle_start = segment['start']
130
  subtitle_end = segment['end']
131
 
132
+ if segment_longest_speaker is not None:
133
+ # Add the beginning
134
+ words.insert(0, {
135
+ 'start': subtitle_start,
136
+ 'end': subtitle_start,
137
+ 'word': f"({segment_longest_speaker})"
138
+ })
139
+
140
  text_words = [ this_word["word"] for this_word in words ]
141
  subtitle_text = __join_words(text_words, maxLineWidth)
142
+
143
  # Iterate over the words in the segment
144
  if highlight_words:
145
  last = subtitle_start