|
import os |
|
import dotenv |
|
from pyannote.audio import Pipeline |
|
import torch |
|
import torchaudio |
|
|
|
dotenv.load_dotenv() |
|
SUBTIFY_TOKEN = os.getenv("SUBTIFY_TOKEN") |
|
|
|
def diarize(audio_path: str, num_speakers: int = 0, min_speakers: int = 0, max_speakers: int = 0, device: str = "cpu"):

    """

    Diarize an audio file with the pyannote speaker-diarization pipeline.



    Args:

        audio_path (str): Path to the audio file to diarize.

        num_speakers (int): Exact number of speakers; forwarded to the
            pipeline only when > 0 (0 means "unset").

        min_speakers (int): Lower bound on the speaker count; forwarded
            only when > 0.

        max_speakers (int): Upper bound on the speaker count; forwarded
            only when > 0.

        device (str): Torch device spec (e.g. "cpu" or "cuda").



    Returns:

        The pipeline's diarization result (a pyannote annotation; use
        ``.to_rttm()`` + ``parse_rttm`` to get a list of segment dicts).

    """



    waveform, sample_rate = torchaudio.load(audio_path)



    # Only forward the speaker-count constraints the caller actually set;
    # the default of 0 means "let the pipeline decide".

    params = {}

    if num_speakers > 0:

        params["num_speakers"] = num_speakers

    if min_speakers > 0:

        params["min_speakers"] = min_speakers

    if max_speakers > 0:

        params["max_speakers"] = max_speakers



    pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1", use_auth_token=SUBTIFY_TOKEN)

    pipeline.to(torch.device(device))



    # Feed the pre-loaded waveform rather than the path so the file is
    # decoded only once.

    diarization = pipeline({"waveform": waveform, "sample_rate": sample_rate}, **params)



    return diarization
|
|
|
def parse_rttm(rttm_string):

    """

    Parse an RTTM string into a list of speaker segments.



    An RTTM line has the form::

        SPEAKER <file> <chan> <start> <dur> <NA> <NA> <speaker> <NA> <NA>

    Fields 3 (start), 4 (duration) and 7 (speaker) are extracted; the
    segment end is computed as start + duration.



    Args:

        rttm_string (str): The RTTM-formatted string to parse.



    Returns:

        list: Dicts with 'start', 'duration', 'end', and 'speaker' keys.
        An empty or whitespace-only input yields an empty list.

    """



    segments = []



    for line in rttm_string.strip().split('\n'):

        parts = line.split()

        # Skip blank or malformed lines; without this guard an empty
        # input string raises IndexError on parts[3], since
        # "".strip().split('\n') still produces [''].

        if len(parts) < 8:

            continue

        start = float(parts[3])

        duration = float(parts[4])

        segments.append({

            'start': start,

            'duration': duration,

            'end': start + duration,

            'speaker': parts[7],

        })

    return segments
|
|
|
def diarize_audio(audio_path: str, num_speakers: int = 0, min_speakers: int = 0, max_speakers: int = 0, device: str = "cpu") -> list:

    """

    Run speaker diarization on an audio file and return parsed segments.



    Convenience wrapper: invokes ``diarize``, serializes its result to
    RTTM, and parses that into plain segment dicts via ``parse_rttm``.



    Args:

        audio_path (str): Path to the audio file to diarize.

        num_speakers (int): Exact speaker-count hint (0 = unset).

        min_speakers (int): Minimum speaker-count hint (0 = unset).

        max_speakers (int): Maximum speaker-count hint (0 = unset).

        device (str): Torch device spec (e.g. "cpu" or "cuda").



    Returns:

        list: Dicts with 'start', 'duration', 'end', and 'speaker'.

    """



    annotation = diarize(audio_path, num_speakers, min_speakers, max_speakers, device)

    return parse_rttm(annotation.to_rttm())
|
|