aadnk commited on
Commit
18bb72f
1 Parent(s): a1da02d

Ensure GPU memory in diarization can be cleaned up

Browse files
app.py CHANGED
@@ -15,6 +15,7 @@ import torch
15
 
16
  from src.config import VAD_INITIAL_PROMPT_MODE_VALUES, ApplicationConfig, VadInitialPromptMode
17
  from src.diarization.diarization import Diarization
 
18
  from src.hooks.progressListener import ProgressListener
19
  from src.hooks.subTaskProgressListener import SubTaskProgressListener
20
  from src.hooks.whisperProgressHook import create_progress_listener_handle
@@ -74,7 +75,10 @@ class WhisperTranscriber:
74
  self.deleteUploadedFiles = delete_uploaded_files
75
  self.output_dir = output_dir
76
 
77
- self.diarization: Diarization = None
 
 
 
78
  self.app_config = app_config
79
 
80
  def set_parallel_devices(self, vad_parallel_devices: str):
@@ -88,6 +92,17 @@ class WhisperTranscriber:
88
  self.vad_cpu_cores = min(os.cpu_count(), MAX_AUTO_CPU_CORES)
89
  print("[Auto parallel] Using GPU devices " + str(self.parallel_device_list) + " and " + str(self.vad_cpu_cores) + " CPU cores for VAD/transcription.")
90
 
 
 
 
 
 
 
 
 
 
 
 
91
  # Entry function for the simple tab
92
  def transcribe_webui_simple(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
93
  vad, vadMergeWindow, vadMaxMergeSize,
@@ -108,9 +123,9 @@ class WhisperTranscriber:
108
  vadOptions = VadOptions(vad, vadMergeWindow, vadMaxMergeSize, self.app_config.vad_padding, self.app_config.vad_prompt_window, self.app_config.vad_initial_prompt_mode)
109
 
110
  if diarization:
111
- self.diarization = Diarization(auth_token=self.app_config.auth_token, num_speakers=diarization_speakers)
112
  else:
113
- self.diarization = None
114
 
115
  return self.transcribe_webui(modelName, languageName, urlData, multipleFiles, microphoneData, task, vadOptions,
116
  word_timestamps=word_timestamps, highlight_words=highlight_words, progress=progress)
@@ -157,10 +172,10 @@ class WhisperTranscriber:
157
 
158
  # Set diarization
159
  if diarization:
160
- self.diarization = Diarization(auth_token=self.app_config.auth_token, num_speakers=diarization_speakers,
161
- min_speakers=diarization_min_speakers, max_speakers=diarization_max_speakers)
162
  else:
163
- self.diarization = None
164
 
165
  return self.transcribe_webui(modelName, languageName, urlData, multipleFiles, microphoneData, task, vadOptions,
166
  initial_prompt=initial_prompt, temperature=temperature, best_of=best_of, beam_size=beam_size, patience=patience, length_penalty=length_penalty, suppress_tokens=suppress_tokens,
@@ -226,9 +241,9 @@ class WhisperTranscriber:
226
  current_progress += source_audio_duration
227
 
228
  # Diarization
229
- if self.diarization:
230
  print("Diarizing ", source.source_path)
231
- diarization_result = list(self.diarization.run(source.source_path))
232
 
233
  # Print result
234
  print("Diarization result: ")
@@ -494,6 +509,10 @@ class WhisperTranscriber:
494
  if (self.cpu_parallel_context is not None):
495
  self.cpu_parallel_context.close()
496
 
 
 
 
 
497
 
498
  def create_ui(app_config: ApplicationConfig):
499
  ui = WhisperTranscriber(app_config.input_audio_max_duration, app_config.vad_process_timeout, app_config.vad_cpu_cores,
 
15
 
16
  from src.config import VAD_INITIAL_PROMPT_MODE_VALUES, ApplicationConfig, VadInitialPromptMode
17
  from src.diarization.diarization import Diarization
18
+ from src.diarization.diarizationContainer import DiarizationContainer
19
  from src.hooks.progressListener import ProgressListener
20
  from src.hooks.subTaskProgressListener import SubTaskProgressListener
21
  from src.hooks.whisperProgressHook import create_progress_listener_handle
 
75
  self.deleteUploadedFiles = delete_uploaded_files
76
  self.output_dir = output_dir
77
 
78
+ # Support for diarization
79
+ self.diarization: DiarizationContainer = None
80
+ # Dictionary with parameters to pass to diarization.run - if None, diarization is not enabled
81
+ self.diarization_kwargs = None
82
  self.app_config = app_config
83
 
84
  def set_parallel_devices(self, vad_parallel_devices: str):
 
92
  self.vad_cpu_cores = min(os.cpu_count(), MAX_AUTO_CPU_CORES)
93
  print("[Auto parallel] Using GPU devices " + str(self.parallel_device_list) + " and " + str(self.vad_cpu_cores) + " CPU cores for VAD/transcription.")
94
 
95
+ def set_diarization(self, auth_token: str, enable_daemon_process: bool = True, **kwargs):
96
+ if self.diarization is None:
97
+ self.diarization = DiarizationContainer(auth_token=auth_token, enable_daemon_process=enable_daemon_process,
98
+ auto_cleanup_timeout_seconds=self.vad_process_timeout, cache=self.model_cache)
99
+ # Set parameters
100
+ self.diarization_kwargs = kwargs
101
+
102
+ def unset_diarization(self):
103
+ self.diarization.cleanup()
104
+ self.diarization_kwargs = None
105
+
106
  # Entry function for the simple tab
107
  def transcribe_webui_simple(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
108
  vad, vadMergeWindow, vadMaxMergeSize,
 
123
  vadOptions = VadOptions(vad, vadMergeWindow, vadMaxMergeSize, self.app_config.vad_padding, self.app_config.vad_prompt_window, self.app_config.vad_initial_prompt_mode)
124
 
125
  if diarization:
126
+ self.set_diarization(auth_token=self.app_config.auth_token, num_speakers=diarization_speakers)
127
  else:
128
+ self.unset_diarization()
129
 
130
  return self.transcribe_webui(modelName, languageName, urlData, multipleFiles, microphoneData, task, vadOptions,
131
  word_timestamps=word_timestamps, highlight_words=highlight_words, progress=progress)
 
172
 
173
  # Set diarization
174
  if diarization:
175
+ self.set_diarization(auth_token=self.app_config.auth_token, num_speakers=diarization_speakers,
176
+ min_speakers=diarization_min_speakers, max_speakers=diarization_max_speakers)
177
  else:
178
+ self.unset_diarization()
179
 
180
  return self.transcribe_webui(modelName, languageName, urlData, multipleFiles, microphoneData, task, vadOptions,
181
  initial_prompt=initial_prompt, temperature=temperature, best_of=best_of, beam_size=beam_size, patience=patience, length_penalty=length_penalty, suppress_tokens=suppress_tokens,
 
241
  current_progress += source_audio_duration
242
 
243
  # Diarization
244
+ if self.diarization and self.diarization_kwargs:
245
  print("Diarizing ", source.source_path)
246
+ diarization_result = list(self.diarization.run(source.source_path, **self.diarization_kwargs))
247
 
248
  # Print result
249
  print("Diarization result: ")
 
509
  if (self.cpu_parallel_context is not None):
510
  self.cpu_parallel_context.close()
511
 
512
+ # Cleanup diarization
513
+ if (self.diarization is not None):
514
+ self.diarization.cleanup()
515
+ self.diarization = None
516
 
517
  def create_ui(app_config: ApplicationConfig):
518
  ui = WhisperTranscriber(app_config.input_audio_max_duration, app_config.vad_process_timeout, app_config.vad_cpu_cores,
cli.py CHANGED
@@ -162,7 +162,7 @@ def cli():
162
  transcriber.set_auto_parallel(auto_parallel)
163
 
164
  if diarization:
165
- transcriber.set_diarization(Diarization(auth_token=auth_token, num_speakers=num_speakers, min_speakers=min_speakers, max_speakers=max_speakers))
166
 
167
  model = create_whisper_container(whisper_implementation=whisper_implementation, model_name=model_name,
168
  device=device, compute_type=compute_type, download_root=model_dir, models=app_config.models)
 
162
  transcriber.set_auto_parallel(auto_parallel)
163
 
164
  if diarization:
165
+ transcriber.set_diarization(auth_token=auth_token, enable_daemon_process=False, num_speakers=num_speakers, min_speakers=min_speakers, max_speakers=max_speakers)
166
 
167
  model = create_whisper_container(whisper_implementation=whisper_implementation, model_name=model_name,
168
  device=device, compute_type=compute_type, download_root=model_dir, models=app_config.models)
src/diarization/diarization.py CHANGED
@@ -1,4 +1,5 @@
1
  import argparse
 
2
  import json
3
  import os
4
  from pathlib import Path
@@ -8,9 +9,6 @@ import torch
8
 
9
  import ffmpeg
10
 
11
- from src.diarization.transcriptLoader import load_transcript
12
- from src.utils import write_srt
13
-
14
  class DiarizationEntry:
15
  def __init__(self, start, end, speaker):
16
  self.start = start
@@ -28,7 +26,7 @@ class DiarizationEntry:
28
  }
29
 
30
  class Diarization:
31
- def __init__(self, auth_token=None, **kwargs):
32
  if auth_token is None:
33
  auth_token = os.environ.get("HK_ACCESS_TOKEN")
34
  if auth_token is None:
@@ -37,7 +35,6 @@ class Diarization:
37
  self.auth_token = auth_token
38
  self.initialized = False
39
  self.pipeline = None
40
- self.pipeline_kwargs = kwargs
41
 
42
  @staticmethod
43
  def has_libraries():
@@ -54,6 +51,7 @@ class Diarization:
54
  from pyannote.audio import Pipeline
55
 
56
  self.pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1", use_auth_token=self.auth_token)
 
57
 
58
  # Load GPU mode if available
59
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -63,7 +61,7 @@ class Diarization:
63
  else:
64
  print("Diarization - using CPU")
65
 
66
- def run(self, audio_file):
67
  self.initialize()
68
  audio_file_obj = Path(audio_file)
69
 
@@ -78,7 +76,7 @@ class Diarization:
78
  except ffmpeg.Error as e:
79
  print(f"Error occurred during audio conversion: {e.stderr}")
80
 
81
- diarization = self.pipeline(target_file, **self.pipeline_kwargs)
82
 
83
  if target_file != audio_file:
84
  # Delete temp file
@@ -148,6 +146,9 @@ def _write_file(input_file: str, output_path: str, output_extension: str, file_w
148
  print(f"Output saved to {effective_path}")
149
 
150
  def main():
 
 
 
151
  parser = argparse.ArgumentParser(description='Add speakers to a SRT file or Whisper JSON file using pyannote/speaker-diarization.')
152
  parser.add_argument('audio_file', type=str, help='Input audio file')
153
  parser.add_argument('whisper_file', type=str, help='Input Whisper JSON/SRT file')
@@ -166,8 +167,8 @@ def main():
166
  # Read whisper JSON or SRT file
167
  whisper_result = load_transcript(args.whisper_file)
168
 
169
- diarization = Diarization(auth_token=args.auth_token, num_speakers=args.num_speakers, min_speakers=args.min_speakers, max_speakers=args.max_speakers)
170
- diarization_result = list(diarization.run(args.audio_file))
171
 
172
  # Print result
173
  print("Diarization result:")
@@ -185,4 +186,10 @@ def main():
185
  lambda f: write_srt(marked_whisper_result["segments"], f, maxLineWidth=args.max_line_width))
186
 
187
  if __name__ == "__main__":
188
- main()
 
 
 
 
 
 
 
1
  import argparse
2
+ import gc
3
  import json
4
  import os
5
  from pathlib import Path
 
9
 
10
  import ffmpeg
11
 
 
 
 
12
  class DiarizationEntry:
13
  def __init__(self, start, end, speaker):
14
  self.start = start
 
26
  }
27
 
28
  class Diarization:
29
+ def __init__(self, auth_token=None):
30
  if auth_token is None:
31
  auth_token = os.environ.get("HK_ACCESS_TOKEN")
32
  if auth_token is None:
 
35
  self.auth_token = auth_token
36
  self.initialized = False
37
  self.pipeline = None
 
38
 
39
  @staticmethod
40
  def has_libraries():
 
51
  from pyannote.audio import Pipeline
52
 
53
  self.pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1", use_auth_token=self.auth_token)
54
+ self.initialized = True
55
 
56
  # Load GPU mode if available
57
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
61
  else:
62
  print("Diarization - using CPU")
63
 
64
+ def run(self, audio_file, **kwargs):
65
  self.initialize()
66
  audio_file_obj = Path(audio_file)
67
 
 
76
  except ffmpeg.Error as e:
77
  print(f"Error occurred during audio conversion: {e.stderr}")
78
 
79
+ diarization = self.pipeline(target_file, **kwargs)
80
 
81
  if target_file != audio_file:
82
  # Delete temp file
 
146
  print(f"Output saved to {effective_path}")
147
 
148
  def main():
149
+ from src.utils import write_srt
150
+ from src.diarization.transcriptLoader import load_transcript
151
+
152
  parser = argparse.ArgumentParser(description='Add speakers to a SRT file or Whisper JSON file using pyannote/speaker-diarization.')
153
  parser.add_argument('audio_file', type=str, help='Input audio file')
154
  parser.add_argument('whisper_file', type=str, help='Input Whisper JSON/SRT file')
 
167
  # Read whisper JSON or SRT file
168
  whisper_result = load_transcript(args.whisper_file)
169
 
170
+ diarization = Diarization(auth_token=args.auth_token)
171
+ diarization_result = list(diarization.run(args.audio_file, num_speakers=args.num_speakers, min_speakers=args.min_speakers, max_speakers=args.max_speakers))
172
 
173
  # Print result
174
  print("Diarization result:")
 
186
  lambda f: write_srt(marked_whisper_result["segments"], f, maxLineWidth=args.max_line_width))
187
 
188
  if __name__ == "__main__":
189
+ main()
190
+
191
+ #test = Diarization()
192
+ #print("Initializing")
193
+ #test.initialize()
194
+
195
+ #input("Press Enter to continue...")
src/diarization/diarizationContainer.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ from src.diarization.diarization import Diarization, DiarizationEntry
3
+ from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
4
+ from src.vadParallel import ParallelContext
5
+
6
+ class DiarizationContainer:
7
+ def __init__(self, auth_token: str = None, enable_daemon_process: bool = True, auto_cleanup_timeout_seconds=60, cache: ModelCache = None):
8
+ self.auth_token = auth_token
9
+ self.enable_daemon_process = enable_daemon_process
10
+ self.auto_cleanup_timeout_seconds = auto_cleanup_timeout_seconds
11
+ self.diarization_context: ParallelContext = None
12
+ self.cache = cache
13
+ self.model = None
14
+
15
+ def run(self, audio_file, **kwargs):
16
+ # Create parallel context if needed
17
+ if self.diarization_context is None and self.enable_daemon_process:
18
+ # Number of processes is set to 1 as we mainly use this in order to clean up GPU memory
19
+ self.diarization_context = ParallelContext(num_processes=1)
20
+
21
+ # Run directly
22
+ if self.diarization_context is None:
23
+ return self.execute(audio_file, **kwargs)
24
+
25
+ # Otherwise run in a separate process
26
+ pool = self.diarization_context.get_pool()
27
+
28
+ try:
29
+ result = pool.apply(self.execute, (audio_file,), kwargs)
30
+ return result
31
+ finally:
32
+ self.diarization_context.return_pool(pool)
33
+
34
+ def mark_speakers(self, diarization_result: List[DiarizationEntry], whisper_result: dict):
35
+ if self.model is not None:
36
+ return self.model.mark_speakers(diarization_result, whisper_result)
37
+
38
+ # Create a new diarization model (calling mark_speakers will not initialize pyannote.audio)
39
+ model = Diarization(self.auth_token)
40
+ return model.mark_speakers(diarization_result, whisper_result)
41
+
42
+ def get_model(self):
43
+ # Lazy load the model
44
+ if (self.model is None):
45
+ if self.cache:
46
+ print("Loading diarization model from cache")
47
+ self.model = self.cache.get("diarization", lambda : Diarization(self.auth_token))
48
+ else:
49
+ print("Loading diarization model")
50
+ self.model = Diarization(self.auth_token)
51
+ return self.model
52
+
53
+ def execute(self, audio_file, **kwargs):
54
+ model = self.get_model()
55
+
56
+ # We must use list() here to force the iterator to run, as generators are not picklable
57
+ result = list(model.run(audio_file, **kwargs))
58
+ return result
59
+
60
+ def cleanup(self):
61
+ if self.diarization_context is not None:
62
+ self.diarization_context.close()
63
+
64
+ def __getstate__(self):
65
+ return {
66
+ "auth_token": self.auth_token,
67
+ "enable_daemon_process": self.enable_daemon_process,
68
+ "auto_cleanup_timeout_seconds": self.auto_cleanup_timeout_seconds
69
+ }
70
+
71
+ def __setstate__(self, state):
72
+ self.auth_token = state["auth_token"]
73
+ self.enable_daemon_process = state["enable_daemon_process"]
74
+ self.auto_cleanup_timeout_seconds = state["auto_cleanup_timeout_seconds"]
75
+ self.diarization_context = None
76
+ self.cache = GLOBAL_MODEL_CACHE
77
+ self.model = None