# whisper-webui-translate / src/diarization/diarizationContainer.py
# Last commit (aadnk, 18bb72f): Ensure GPU memory in diarization can be cleaned up
from typing import List
from src.diarization.diarization import Diarization, DiarizationEntry
from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
from src.vadParallel import ParallelContext
class DiarizationContainer:
    """Wrapper around a lazily created ``Diarization`` model.

    The container can either run diarization directly in the current process,
    or hand the work to a pool obtained from a ``ParallelContext`` when
    ``enable_daemon_process`` is set.

    Attributes set by ``__init__``:
        auth_token: Token passed to ``Diarization`` when a model is created.
        enable_daemon_process: When truthy, ``run`` goes through a
            ``ParallelContext`` pool instead of calling ``execute`` directly.
        auto_cleanup_timeout_seconds: Stored and included in the pickled
            state; not otherwise read inside this class.
        diarization_context: The ``ParallelContext`` created on first use by
            ``run`` (``None`` until then).
        cache: Optional ``ModelCache`` consulted by ``get_model``.
        model: The loaded ``Diarization`` instance (``None`` until loaded).
    """

    def __init__(self, auth_token: str = None, enable_daemon_process: bool = True, auto_cleanup_timeout_seconds=60, cache: ModelCache = None):
        self.auth_token = auth_token
        self.enable_daemon_process = enable_daemon_process
        self.auto_cleanup_timeout_seconds = auto_cleanup_timeout_seconds
        self.diarization_context: ParallelContext = None
        self.cache = cache
        self.model = None

    def run(self, audio_file, **kwargs):
        """Run diarization on ``audio_file`` and return the result as a list.

        Extra keyword arguments are forwarded unchanged to ``execute`` (and
        from there to the model's ``run``).
        """
        # Create parallel context if needed
        if self.diarization_context is None and self.enable_daemon_process:
            # Number of processes is set to 1 as we mainly use this in order to clean up GPU memory
            # NOTE(review): auto_cleanup_timeout_seconds is not passed along here;
            # confirm whether ParallelContext should receive it.
            self.diarization_context = ParallelContext(num_processes=1)

        # Run directly
        if self.diarization_context is None:
            return self.execute(audio_file, **kwargs)

        # Otherwise run in a separate process
        pool = self.diarization_context.get_pool()
        try:
            return pool.apply(self.execute, (audio_file,), kwargs)
        finally:
            # Always hand the pool back, even if the call raised
            self.diarization_context.return_pool(pool)

    def mark_speakers(self, diarization_result: List[DiarizationEntry], whisper_result: dict):
        """Delegate to ``Diarization.mark_speakers`` and return its result.

        Uses the already loaded model when there is one; otherwise a fresh
        ``Diarization`` is created just for this call and is not stored.
        """
        if self.model is not None:
            return self.model.mark_speakers(diarization_result, whisper_result)

        # Create a new diarization model (calling mark_speakers will not initialize pyannote.audio)
        model = Diarization(self.auth_token)
        return model.mark_speakers(diarization_result, whisper_result)

    def get_model(self):
        """Return the diarization model, creating it on first use.

        When ``self.cache`` is truthy the model is obtained through
        ``cache.get("diarization", factory)``; otherwise it is constructed
        directly. Either way the result is kept in ``self.model``.
        """
        # Lazy load the model
        if self.model is None:
            if self.cache:
                print("Loading diarization model from cache")
                self.model = self.cache.get("diarization", lambda: Diarization(self.auth_token))
            else:
                print("Loading diarization model")
                self.model = Diarization(self.auth_token)
        return self.model

    def execute(self, audio_file, **kwargs):
        """Run the model on ``audio_file`` and return the entries as a list."""
        model = self.get_model()

        # We must use list() here to force the iterator to run, as generators are not picklable
        return list(model.run(audio_file, **kwargs))

    def cleanup(self):
        """Close the parallel context, if one has been created."""
        if self.diarization_context is not None:
            self.diarization_context.close()

    def __getstate__(self):
        """Return the picklable state: configuration values only.

        The context, cache and model are intentionally left out.
        """
        return {
            "auth_token": self.auth_token,
            "enable_daemon_process": self.enable_daemon_process,
            "auto_cleanup_timeout_seconds": self.auto_cleanup_timeout_seconds,
        }

    def __setstate__(self, state):
        """Restore configuration from ``state`` and reset runtime members.

        The restored instance has no context and no model, and uses
        ``GLOBAL_MODEL_CACHE`` as its cache.
        """
        self.auth_token = state["auth_token"]
        self.enable_daemon_process = state["enable_daemon_process"]
        self.auto_cleanup_timeout_seconds = state["auto_cleanup_timeout_seconds"]
        self.diarization_context = None
        self.cache = GLOBAL_MODEL_CACHE
        self.model = None