|
import os |
|
import torch |
|
from typing import List, Union, BinaryIO, Optional |
|
import numpy as np |
|
import time |
|
import logging |
|
|
|
from modules.utils.paths import DIARIZATION_MODELS_DIR |
|
from modules.diarize.diarize_pipeline import DiarizationPipeline, assign_word_speakers |
|
from modules.diarize.audio_loader import load_audio |
|
|
|
|
|
class Diarizer:
    """Speaker-diarization post-processor for transcribed segments.

    Holds a lazily created ``DiarizationPipeline`` and uses it to prefix the
    text of every transcribed segment with a speaker label.
    """

    def __init__(self,
                 model_dir: str = DIARIZATION_MODELS_DIR
                 ):
        """
        Parameters
        ----------
        model_dir: str
            Directory passed to the pipeline as its cache directory.
            It is created if it does not exist yet.
        """
        self.device = self.get_device()
        self.available_device = self.get_available_device()
        self.compute_type = "float16"
        self.model_dir = model_dir
        os.makedirs(self.model_dir, exist_ok=True)
        # Built on demand by update_pipe(); stays None until the first successful build.
        self.pipe = None

    def run(self,
            audio: Union[str, BinaryIO, np.ndarray],
            transcribed_result: List[dict],
            use_auth_token: str,
            device: Optional[str] = None
            ):
        """
        Diarize transcribed result as a post-processing

        Parameters
        ----------
        audio: Union[str, BinaryIO, np.ndarray]
            Audio input. This can be file path or binary type.
        transcribed_result: List[dict]
            transcribed result through whisper.
        use_auth_token: str
            Huggingface token with READ permission. This is only needed the first time you download the model.
            You must manually go to the website https://huggingface.co/pyannote/speaker-diarization-3.1 and agree to their TOS to download the model.
        device: Optional[str]
            Device for diarization. Defaults to the device currently stored on this instance.

        Returns
        ----------
        segments_result: List[dict]
            list of dicts that includes start, end timestamps and transcribed text.
            Each text is prefixed with "<speaker>: ". If the diarization pipeline could not be
            prepared, ``transcribed_result`` is returned as given, without speaker prefixes.
        elapsed_time: float
            elapsed time for running
        """
        start_time = time.time()

        if device is None:
            device = self.device

        # (Re)build the pipeline on first use or when the caller asks for another device.
        if device != self.device or self.pipe is None:
            self.update_pipe(
                device=device,
                use_auth_token=use_auth_token
            )

        if self.pipe is None:
            # update_pipe() bails out (after printing the reason) when no model can be loaded.
            # Calling the missing pipeline would raise "TypeError: 'NoneType' object is not
            # callable", so hand the segments back un-diarized instead.
            return transcribed_result, time.time() - start_time

        audio = load_audio(audio)

        diarization_segments = self.pipe(audio)
        diarized_result = assign_word_speakers(
            diarization_segments,
            {"segments": transcribed_result}
        )

        for segment in diarized_result["segments"]:
            # Segments without an assigned speaker are labelled with the literal string "None".
            speaker = segment.get("speaker", "None").replace("_", " ")
            segment["text"] = speaker + ": " + segment["text"].strip()

        elapsed_time = time.time() - start_time
        return diarized_result["segments"], elapsed_time

    def update_pipe(self,
                    use_auth_token: str,
                    device: str
                    ) -> None:
        """
        Set pipeline for diarization

        ``self.pipe`` and ``self.device`` are only updated after the new pipeline has been
        created, so a failed update leaves the previous pipeline/device pair untouched.

        Parameters
        ----------
        use_auth_token: str
            Huggingface token with READ permission. This is only needed the first time you download the model.
            You must manually go to the website https://huggingface.co/pyannote/speaker-diarization-3.1 and agree to their TOS to download the model.
        device: str
            Device for diarization.
        """
        os.makedirs(self.model_dir, exist_ok=True)

        # An empty model directory together with a missing token means there is nothing
        # to load and no way to download it: report and give up.
        if not use_auth_token and not os.listdir(self.model_dir):
            print(
                "\nFailed to diarize. You need huggingface token and agree to their requirements to download the diarization model.\n"
                "Go to \"https://huggingface.co/pyannote/speaker-diarization-3.1\" and follow their instructions to download the model.\n"
            )
            return

        # Silence this logger while the pipeline is constructed; always restore its
        # previous state, even when construction raises.
        logger = logging.getLogger("speechbrain.utils.train_logger")
        was_disabled = logger.disabled
        logger.disabled = True
        try:
            pipe = DiarizationPipeline(
                use_auth_token=use_auth_token,
                device=device,
                cache_dir=self.model_dir
            )
        finally:
            logger.disabled = was_disabled

        # Commit both together so self.device always describes self.pipe.
        self.pipe = pipe
        self.device = device

    @staticmethod
    def get_device() -> str:
        """Return the preferred torch device name: "cuda", then "mps", then "cpu"."""
        if torch.cuda.is_available():
            return "cuda"
        elif torch.backends.mps.is_available():
            return "mps"
        else:
            return "cpu"

    @staticmethod
    def get_available_device() -> List[str]:
        """Return the usable device names: always "cpu", plus "cuda" or else "mps" when available."""
        devices = ["cpu"]
        if torch.cuda.is_available():
            devices.append("cuda")
        elif torch.backends.mps.is_available():
            devices.append("mps")
        return devices