whisper-webui-translate / src /diarization /diarizationContainer.py
aadnk's picture
Ensure GPU memory in diarization can be cleaned up
18bb72f
raw
history blame
3.21 kB
from typing import List
from src.diarization.diarization import Diarization, DiarizationEntry
from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
from src.vadParallel import ParallelContext
class DiarizationContainer:
    """Hold a lazily loaded speaker-diarization model and optionally run it in a worker process.

    When ``enable_daemon_process`` is true, ``run`` executes diarization inside a
    single-process pool obtained from a ``ParallelContext``. Its main purpose is to let the
    GPU memory used by the model be cleaned up, rather than to parallelize work. Otherwise
    the model is loaded and run directly in the calling process.
    """

    def __init__(self, auth_token: str = None, enable_daemon_process: bool = True, auto_cleanup_timeout_seconds=60, cache: ModelCache = None):
        """
        Args:
            auth_token: Token passed to ``Diarization`` whenever a model is constructed.
            enable_daemon_process: If true, ``run`` executes diarization in a worker process
                instead of in the calling process.
            auto_cleanup_timeout_seconds: Forwarded to the worker ``ParallelContext`` as its
                auto-cleanup timeout. NOTE(review): presumably the idle time in seconds before
                the worker pool is closed; confirm against ``ParallelContext``.
            cache: Optional model cache consulted by ``get_model``. When None, a fresh
                ``Diarization`` is created instead.
        """
        self.auth_token = auth_token
        self.enable_daemon_process = enable_daemon_process
        self.auto_cleanup_timeout_seconds = auto_cleanup_timeout_seconds
        # Created lazily on the first ``run`` call when ``enable_daemon_process`` is set.
        self.diarization_context: ParallelContext = None
        self.cache = cache
        # Loaded lazily by ``get_model``.
        self.model = None

    def run(self, audio_file, **kwargs):
        """Diarize ``audio_file`` and return the result as a list.

        Runs in a single worker process when ``enable_daemon_process`` is set, and directly
        in this process otherwise. Extra keyword arguments are forwarded to the model's
        ``run`` method via ``execute``.
        """
        # Create parallel context if needed
        if self.diarization_context is None and self.enable_daemon_process:
            # Number of processes is set to 1 as we mainly use this in order to clean up GPU memory.
            # Pass the configured timeout through; previously it was stored but never used.
            self.diarization_context = ParallelContext(
                num_processes=1,
                auto_cleanup_timeout_seconds=self.auto_cleanup_timeout_seconds,
            )

        # Capture the context once, so the pool is returned to the same context it came from.
        context = self.diarization_context

        # Run directly
        if context is None:
            return self.execute(audio_file, **kwargs)

        # Otherwise run in a separate process. ``self.execute`` is a bound method, so ``self``
        # is pickled via ``__getstate__`` and rebuilt in the worker via ``__setstate__``.
        pool = context.get_pool()
        try:
            return pool.apply(self.execute, (audio_file,), kwargs)
        finally:
            context.return_pool(pool)

    def mark_speakers(self, diarization_result: List[DiarizationEntry], whisper_result: dict):
        """Delegate to ``Diarization.mark_speakers`` to combine diarization and Whisper results.

        Uses the already-loaded model if there is one. Otherwise a throwaway ``Diarization``
        is created for the call.
        """
        if self.model is not None:
            return self.model.mark_speakers(diarization_result, whisper_result)

        # Create a new diarization model (calling mark_speakers will not initialize pyannote.audio)
        model = Diarization(self.auth_token)
        return model.mark_speakers(diarization_result, whisper_result)

    def get_model(self):
        """Return the diarization model, creating it on first use (through the cache when set)."""
        # Lazy load the model
        if self.model is None:
            if self.cache is not None:
                print("Loading diarization model from cache")
                self.model = self.cache.get("diarization", lambda: Diarization(self.auth_token))
            else:
                print("Loading diarization model")
                self.model = Diarization(self.auth_token)
        return self.model

    def execute(self, audio_file, **kwargs):
        """Run the model on ``audio_file`` in the current process and return a list of entries."""
        model = self.get_model()

        # We must use list() here to force the iterator to run, as generators are not picklable
        result = list(model.run(audio_file, **kwargs))
        return result

    def cleanup(self):
        """Close the worker ``ParallelContext``, if one was created.

        The context object itself is kept, so a later ``run`` call reuses it.
        NOTE(review): this assumes ``ParallelContext`` can hand out a fresh pool after
        ``close()``; confirm.
        """
        if self.diarization_context is not None:
            self.diarization_context.close()

    def __getstate__(self):
        """Pickle only the configuration.

        The context, the cache and the loaded model are not transferred to the worker process.
        """
        return {
            "auth_token": self.auth_token,
            "enable_daemon_process": self.enable_daemon_process,
            "auto_cleanup_timeout_seconds": self.auto_cleanup_timeout_seconds
        }

    def __setstate__(self, state):
        """Rebuild from ``__getstate__`` output.

        The unpickled copy starts with no context and no model, and uses
        ``GLOBAL_MODEL_CACHE``. Presumably this lets repeated calls in the same worker
        process reuse an already-loaded model.
        """
        self.auth_token = state["auth_token"]
        self.enable_daemon_process = state["enable_daemon_process"]
        self.auto_cleanup_timeout_seconds = state["auto_cleanup_timeout_seconds"]
        self.diarization_context = None
        self.cache = GLOBAL_MODEL_CACHE
        self.model = None