| | import torch
|
| | import torchaudio
|
| | import numpy as np
|
| | import os
|
| | import warnings
|
| | from pathlib import Path
|
| | from typing import Dict, List, Tuple
|
| | import argparse
|
| | from concurrent.futures import ThreadPoolExecutor
|
| | import gc
|
| | import logging
|
| |
|
# Global verbosity flag; xprint() checks it and extract_dual_audio() rebinds it.
verbose_output = True


# Silence known-noisy warnings emitted during pyannote.audio / speechbrain
# model loading (TF32 notice matches the TF32 flags enabled below).
warnings.filterwarnings("ignore", category=UserWarning, module="pyannote.audio.models.blocks.pooling")
warnings.filterwarnings("ignore", message=".*TensorFloat-32.*", category=UserWarning)
warnings.filterwarnings("ignore", message=".*std\\(\\): degrees of freedom.*", category=UserWarning)
warnings.filterwarnings("ignore", message=".*speechbrain.pretrained.*was deprecated.*", category=UserWarning)
warnings.filterwarnings("ignore", message=".*Module 'speechbrain.pretrained'.*", category=UserWarning)


# Reduce speechbrain log chatter; must be set before speechbrain is imported.
os.environ["SB_LOG_LEVEL"] = "WARNING"
import speechbrain
|
| |
|
def xprint(t=None):
    """Print *t* only when the module-level ``verbose_output`` flag is set.

    BUG FIX: ``xprint()`` with no argument used to print the literal string
    "None"; it now emits a blank line, matching plain ``print()``.
    """
    if verbose_output:
        print(t if t is not None else "")
|
| |
|
| |
|
# Enable TensorFloat-32 matmul/cuDNN kernels on CUDA for faster inference
# (slightly reduced float32 precision).
if torch.cuda.is_available():
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
|
| |
|
# Optional dependency: pyannote.audio supplies the diarization pipeline.
# PYANNOTE_AVAILABLE lets callers degrade gracefully instead of crashing here.
try:
    from pyannote.audio import Pipeline
    PYANNOTE_AVAILABLE = True
except ImportError:
    PYANNOTE_AVAILABLE = False
    print("Install: pip install pyannote.audio")
|
| |
|
| |
|
| | class OptimizedPyannote31SpeakerSeparator:
|
| | def __init__(self, hf_token: str = None, local_model_path: str = None,
|
| | vad_onset: float = 0.2, vad_offset: float = 0.8):
|
| | """
|
| | Initialize with Pyannote 3.1 pipeline with tunable VAD sensitivity.
|
| | """
|
| | embedding_path = "ckpts/pyannote/pyannote_model_wespeaker-voxceleb-resnet34-LM.bin"
|
| | segmentation_path = "ckpts/pyannote/pytorch_model_segmentation-3.0.bin"
|
| |
|
| |
|
| | xprint(f"Loading segmentation model from: {segmentation_path}")
|
| | xprint(f"Loading embedding model from: {embedding_path}")
|
| |
|
| | try:
|
| | from pyannote.audio import Model
|
| | from pyannote.audio.pipelines import SpeakerDiarization
|
| |
|
| |
|
| | segmentation_model = Model.from_pretrained(segmentation_path)
|
| | embedding_model = Model.from_pretrained(embedding_path)
|
| | xprint("Models loaded successfully!")
|
| |
|
| |
|
| | self.pipeline = SpeakerDiarization(
|
| | segmentation=segmentation_model,
|
| | embedding=embedding_model,
|
| | clustering='AgglomerativeClustering'
|
| | )
|
| |
|
| |
|
| | self.pipeline.instantiate({
|
| | 'clustering': {
|
| | 'method': 'centroid',
|
| | 'min_cluster_size': 12,
|
| | 'threshold': 0.7045654963945799
|
| | },
|
| | 'segmentation': {
|
| | 'min_duration_off': 0.0
|
| | }
|
| | })
|
| | xprint("Pipeline instantiated successfully!")
|
| |
|
| |
|
| | if torch.cuda.is_available():
|
| | xprint("CUDA available, moving pipeline to GPU...")
|
| | self.pipeline.to(torch.device("cuda"))
|
| | else:
|
| | xprint("CUDA not available, using CPU...")
|
| |
|
| | except Exception as e:
|
| | xprint(f"Error loading pipeline: {e}")
|
| | xprint(f"Error type: {type(e)}")
|
| | import traceback
|
| | traceback.print_exc()
|
| | raise
|
| |
|
| |
|
| | self.hf_token = hf_token
|
| | self._overlap_pipeline = None
|
| |
|
    def separate_audio(self, audio_path: str, output1, output2, audio_original_path: str = None) -> Dict[str, str]:
        """Optimized main separation function with memory management.

        Diarizes *audio_path*, builds per-speaker masks, and writes the two
        masked waveforms to *output1*/*output2*.

        Args:
            audio_path: file used for diarization (and as output source unless
                audio_original_path is given).
            output1: destination path for the first speaker track.
            output2: destination path for the second speaker track.
            audio_original_path: optional alternative audio to mask instead of
                *audio_path*; masks are still computed from *audio_path*.
                NOTE(review): assumes both files are index-aligned (same length
                and sample rate) — verify at call sites.

        Returns:
            Mapping of speaker label -> written output path.
        """
        xprint("Starting optimized audio separation...")
        # Remember the absolute path so the overlap pipeline can re-read the file.
        self._current_audio_path = os.path.abspath(audio_path)

        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=UserWarning)

            waveform, sample_rate = self.load_audio(audio_path)

            diarization = self.perform_optimized_diarization(audio_path)

            masks = self.create_optimized_speaker_masks(diarization, waveform.shape[1], sample_rate)

            final_masks = self.apply_optimized_background_preservation(masks, waveform.shape[1])

            # Release intermediate masks (and cached GPU memory) before the
            # potentially large original-audio load below.
            del masks
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()

            if audio_original_path is None:
                waveform_original = waveform
            else:
                waveform_original, sample_rate = self.load_audio(audio_original_path)
            output_paths = self._save_outputs_optimized(waveform_original, final_masks, sample_rate, audio_path, output1, output2)

        return output_paths
|
| |
|
| | def _extract_both_speaking_regions(
|
| | self,
|
| | diarization,
|
| | audio_length: int,
|
| | sample_rate: int
|
| | ) -> np.ndarray:
|
| | """
|
| | Detect regions where ≥2 speakers talk simultaneously
|
| | using pyannote/overlapped-speech-detection.
|
| | Falls back to manual pair-wise detection if the model
|
| | is unavailable.
|
| | """
|
| | xprint("Extracting overlap with dedicated pipeline…")
|
| | both_speaking_mask = np.zeros(audio_length, dtype=bool)
|
| |
|
| |
|
| |
|
| | overlap_pipeline = None
|
| |
|
| |
|
| |
|
| | audio_uri = getattr(self, "_current_audio_path", None) \
|
| | or getattr(diarization, "uri", None)
|
| | if overlap_pipeline and audio_uri:
|
| | try:
|
| | with warnings.catch_warnings():
|
| | warnings.filterwarnings("ignore", category=UserWarning)
|
| | overlap_annotation = overlap_pipeline(audio_uri)
|
| |
|
| | for seg in overlap_annotation.get_timeline().support():
|
| | s = max(0, int(seg.start * sample_rate))
|
| | e = min(audio_length, int(seg.end * sample_rate))
|
| | if s < e:
|
| | both_speaking_mask[s:e] = True
|
| | t = np.sum(both_speaking_mask) / sample_rate
|
| | xprint(f" Found {t:.1f}s of overlapped speech (model) ")
|
| | return both_speaking_mask
|
| | except Exception as e:
|
| | xprint(f" ⚠ Overlap model failed: {e}")
|
| |
|
| |
|
| | xprint(" Falling back to manual overlap detection…")
|
| | timeline_tracks = list(diarization.itertracks(yield_label=True))
|
| | for i, (turn1, _, spk1) in enumerate(timeline_tracks):
|
| | for j, (turn2, _, spk2) in enumerate(timeline_tracks):
|
| | if i >= j or spk1 == spk2:
|
| | continue
|
| | o_start, o_end = max(turn1.start, turn2.start), min(turn1.end, turn2.end)
|
| | if o_start < o_end:
|
| | s = max(0, int(o_start * sample_rate))
|
| | e = min(audio_length, int(o_end * sample_rate))
|
| | if s < e:
|
| | both_speaking_mask[s:e] = True
|
| | t = np.sum(both_speaking_mask) / sample_rate
|
| | xprint(f" Found {t:.1f}s of overlapped speech (manual) ")
|
| | return both_speaking_mask
|
| |
|
    def _configure_vad(self, vad_onset: float, vad_offset: float):
        """Configure VAD parameters efficiently (best-effort).

        Re-instantiates the pipeline's private ``_vad`` component with the
        given onset/offset thresholds. Any failure is logged and swallowed so
        the caller can proceed with default VAD behaviour.
        """
        xprint("Applying more sensitive VAD parameters...")
        try:
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=UserWarning)

                # `_vad` is a private pyannote attribute that may not exist in
                # every version — hence the hasattr guard instead of EAFP.
                if hasattr(self.pipeline, '_vad'):
                    self.pipeline._vad.instantiate({
                        "onset": vad_onset,
                        "offset": vad_offset,
                        "min_duration_on": 0.1,
                        "min_duration_off": 0.1,
                        "pad_onset": 0.1,
                        "pad_offset": 0.1,
                    })
                    xprint(f"✓ VAD parameters updated: onset={vad_onset}, offset={vad_offset}")
                else:
                    xprint("⚠ Could not access VAD component directly")
        except Exception as e:
            xprint(f"⚠ Could not modify VAD parameters: {e}")
|
| |
|
    def _get_overlap_pipeline(self):
        """
        Build a pyannote-3-native OverlappedSpeechDetection pipeline.

        • uses the open-licence `pyannote/segmentation-3.0` checkpoint
        • only `min_duration_on/off` can be tuned (API 3.x)

        The result is memoised on ``self._overlap_pipeline``: the pipeline
        object on success, ``False`` after a failed attempt (so we never retry).

        Returns:
            The pipeline, or None when it cannot be built.
        """
        # Cached result from a previous call (False marks a failed attempt).
        if self._overlap_pipeline is not None:
            return None if self._overlap_pipeline is False else self._overlap_pipeline

        try:
            from pyannote.audio.pipelines import OverlappedSpeechDetection

            xprint("Building OverlappedSpeechDetection with segmentation-3.0…")

            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=UserWarning)

                ods = OverlappedSpeechDetection(
                    segmentation="pyannote/segmentation-3.0"
                )

                # Only the duration knobs are exposed by the 3.x API.
                ods.instantiate({
                    "min_duration_on": 0.06,
                    "min_duration_off": 0.10,
                })

            if torch.cuda.is_available():
                ods.to(torch.device("cuda"))

            self._overlap_pipeline = ods
            xprint("✓ Overlap pipeline ready (segmentation-3.0)")
            return ods

        except Exception as e:
            xprint(f"⚠ Could not build overlap pipeline ({e}). "
                   "Falling back to manual pair-wise detection.")
            self._overlap_pipeline = False
            return None
|
| |
|
| | def _xprint_setup_instructions(self):
|
| | """xprint setup instructions."""
|
| | xprint("\nTo use Pyannote 3.1:")
|
| | xprint("1. Get token: https://huggingface.co/settings/tokens")
|
| | xprint("2. Accept terms: https://huggingface.co/pyannote/speaker-diarization-3.1")
|
| | xprint("3. Run with: --token YOUR_TOKEN")
|
| |
|
    def load_audio(self, audio_path: str) -> Tuple[torch.Tensor, int]:
        """Load and preprocess audio efficiently.

        Returns:
            (waveform, sample_rate) where waveform is mono with shape
            (1, num_samples); multi-channel input is averaged down.
        """
        xprint(f"Loading audio: {audio_path}")

        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=UserWarning)
            waveform, sample_rate = torchaudio.load(audio_path)

        # Downmix to mono by averaging the channel dimension.
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        xprint(f"Audio: {waveform.shape[1]} samples at {sample_rate}Hz")
        return waveform, sample_rate
|
| |
|
| | def perform_optimized_diarization(self, audio_path: str) -> object:
|
| | """
|
| | Optimized diarization with efficient parameter testing.
|
| | """
|
| | xprint("Running optimized Pyannote 3.1 diarization...")
|
| |
|
| |
|
| | strategies = [
|
| | {"min_speakers": 2, "max_speakers": 2},
|
| | {"num_speakers": 2},
|
| | {"min_speakers": 2, "max_speakers": 3},
|
| | {"min_speakers": 1, "max_speakers": 2},
|
| | {"min_speakers": 2, "max_speakers": 4},
|
| | {}
|
| | ]
|
| |
|
| | for i, params in enumerate(strategies):
|
| | try:
|
| | xprint(f"Strategy {i+1}: {params}")
|
| |
|
| |
|
| | if torch.cuda.is_available():
|
| | torch.cuda.empty_cache()
|
| |
|
| | with warnings.catch_warnings():
|
| | warnings.filterwarnings("ignore", category=UserWarning)
|
| | diarization = self.pipeline(audio_path, **params)
|
| |
|
| | speakers = list(diarization.labels())
|
| | speaker_count = len(speakers)
|
| |
|
| | xprint(f" → Detected {speaker_count} speakers: {speakers}")
|
| |
|
| |
|
| | if speaker_count >= 2:
|
| | xprint(f"✓ Success with strategy {i+1}! Using {speaker_count} speakers")
|
| | return diarization
|
| | elif speaker_count == 1 and i == 0:
|
| |
|
| | fallback_diarization = diarization
|
| |
|
| | except Exception as e:
|
| | xprint(f" Strategy {i+1} failed: {e}")
|
| | continue
|
| |
|
| |
|
| | if 'fallback_diarization' in locals():
|
| | xprint("Attempting aggressive clustering for single speaker...")
|
| | try:
|
| | aggressive_diarization = self._try_aggressive_clustering(audio_path)
|
| | if aggressive_diarization and len(list(aggressive_diarization.labels())) >= 2:
|
| | return aggressive_diarization
|
| | except Exception as e:
|
| | xprint(f"Aggressive clustering failed: {e}")
|
| |
|
| | xprint("Using single speaker result")
|
| | return fallback_diarization
|
| |
|
| |
|
| | xprint("Last resort: running without constraints...")
|
| | with warnings.catch_warnings():
|
| | warnings.filterwarnings("ignore", category=UserWarning)
|
| | return self.pipeline(audio_path)
|
| |
|
    def _try_aggressive_clustering(self, audio_path: str) -> object:
        """Try aggressive clustering parameters.

        Builds a temporary pipeline sharing this instance's loaded models but
        with a very low clustering threshold, to force apart clusters that the
        default threshold merged into one speaker.

        Returns:
            The diarization result, or None if setup fails.
        """
        try:
            from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization

            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=UserWarning)

                # Reuse the already-loaded segmentation/embedding models.
                temp_pipeline = SpeakerDiarization(
                    segmentation=self.pipeline.segmentation,
                    embedding=self.pipeline.embedding,
                    clustering="AgglomerativeClustering"
                )

                # threshold=0.1 merges almost nothing -> maximises splits.
                temp_pipeline.instantiate({
                    "clustering": {
                        "method": "centroid",
                        "min_cluster_size": 1,
                        "threshold": 0.1,
                    },
                    "segmentation": {
                        "min_duration_off": 0.0,
                        "min_duration_on": 0.1,
                    }
                })

                return temp_pipeline(audio_path, min_speakers=2)

        except Exception as e:
            xprint(f"Aggressive clustering setup failed: {e}")
            return None
|
| |
|
    def create_optimized_speaker_masks(self, diarization, audio_length: int, sample_rate: int) -> Dict[str, np.ndarray]:
        """Optimized mask creation using vectorized operations.

        Builds one float32 0/1 mask per detected speaker from the diarization
        timeline. Degenerate cases fall back: 0 speakers -> dummy half-gain
        masks, 1 speaker -> temporal split. Overlap regions are cached on
        ``self._both_speaking_regions`` for the later smoothing pass.
        """
        xprint("Creating optimized speaker masks...")

        speakers = list(diarization.labels())
        xprint(f"Processing speakers: {speakers}")

        if len(speakers) == 0:
            xprint("⚠ No speakers detected, creating dummy masks")
            return self._create_dummy_masks(audio_length)

        if len(speakers) == 1:
            xprint("⚠ Only 1 speaker detected, creating temporal split")
            return self._create_optimized_temporal_split(diarization, audio_length, sample_rate)

        both_speaking_regions = self._extract_both_speaking_regions(diarization, audio_length, sample_rate)

        masks = {}

        for speaker in speakers:
            # Collect this speaker's turns as clamped sample index ranges.
            segments = []
            speaker_timeline = diarization.label_timeline(speaker)
            for segment in speaker_timeline:
                start_sample = max(0, int(segment.start * sample_rate))
                end_sample = min(audio_length, int(segment.end * sample_rate))
                if start_sample < end_sample:
                    segments.append((start_sample, end_sample))

            if segments:
                mask = self._create_mask_vectorized(segments, audio_length)
                masks[speaker] = mask
                speaking_time = np.sum(mask) / sample_rate
                xprint(f" {speaker}: {speaking_time:.1f}s speaking time")
            else:
                masks[speaker] = np.zeros(audio_length, dtype=np.float32)

        # Stash overlap info for _apply_minimum_duration_smoothing.
        self._both_speaking_regions = both_speaking_regions

        return masks
|
| |
|
| | def _create_mask_vectorized(self, segments: List[Tuple[int, int]], audio_length: int) -> np.ndarray:
|
| | """Create mask using vectorized operations."""
|
| | mask = np.zeros(audio_length, dtype=np.float32)
|
| |
|
| | if not segments:
|
| | return mask
|
| |
|
| |
|
| | segments_array = np.array(segments)
|
| | starts = segments_array[:, 0]
|
| | ends = segments_array[:, 1]
|
| |
|
| |
|
| | for start, end in zip(starts, ends):
|
| | mask[start:end] = 1.0
|
| |
|
| | return mask
|
| |
|
| | def _create_dummy_masks(self, audio_length: int) -> Dict[str, np.ndarray]:
|
| | """Create dummy masks for edge cases."""
|
| | return {
|
| | "SPEAKER_00": np.ones(audio_length, dtype=np.float32) * 0.5,
|
| | "SPEAKER_01": np.ones(audio_length, dtype=np.float32) * 0.5
|
| | }
|
| |
|
    def _create_optimized_temporal_split(self, diarization, audio_length: int, sample_rate: int) -> Dict[str, np.ndarray]:
        """Optimized temporal split with vectorized operations.

        Used when diarization found only one speaker: splits that speaker's
        turns into two pseudo-speakers, preferring the longest silence gap
        (> 1 s) as the boundary, otherwise alternating segment assignment.
        """
        xprint("Creating optimized temporal split...")

        segments = []
        for turn, _, speaker in diarization.itertracks(yield_label=True):
            segments.append((turn.start, turn.end))

        segments.sort()
        xprint(f"Found {len(segments)} speech segments")

        if len(segments) <= 1:
            # Nothing to split on — just halve the audio in time.
            return self._create_simple_split(audio_length)

        # Gap between consecutive segments: next start minus previous end (s).
        segment_array = np.array(segments)
        gaps = segment_array[1:, 0] - segment_array[:-1, 1]

        if len(gaps) > 0:
            longest_gap_idx = np.argmax(gaps)
            longest_gap_duration = gaps[longest_gap_idx]

            xprint(f"Longest gap: {longest_gap_duration:.1f}s after segment {longest_gap_idx+1}")

            if longest_gap_duration > 1.0:
                # Natural pause > 1s: everything before it is speaker 0.
                split_point = longest_gap_idx + 1
                xprint(f"Splitting at natural break: segments 1-{split_point} vs {split_point+1}-{len(segments)}")

                return self._create_split_masks(segments, split_point, audio_length, sample_rate)

        xprint("Using alternating assignment...")
        return self._create_alternating_masks(segments, audio_length, sample_rate)
|
| |
|
| | def _create_simple_split(self, audio_length: int) -> Dict[str, np.ndarray]:
|
| | """Simple temporal split in half."""
|
| | mid_point = audio_length // 2
|
| | masks = {
|
| | "SPEAKER_00": np.zeros(audio_length, dtype=np.float32),
|
| | "SPEAKER_01": np.zeros(audio_length, dtype=np.float32)
|
| | }
|
| | masks["SPEAKER_00"][:mid_point] = 1.0
|
| | masks["SPEAKER_01"][mid_point:] = 1.0
|
| | return masks
|
| |
|
| | def _create_split_masks(self, segments: List[Tuple[float, float]], split_point: int,
|
| | audio_length: int, sample_rate: int) -> Dict[str, np.ndarray]:
|
| | """Create masks with split at specific point."""
|
| | masks = {
|
| | "SPEAKER_00": np.zeros(audio_length, dtype=np.float32),
|
| | "SPEAKER_01": np.zeros(audio_length, dtype=np.float32)
|
| | }
|
| |
|
| |
|
| | for i, (start_time, end_time) in enumerate(segments):
|
| | start_sample = max(0, int(start_time * sample_rate))
|
| | end_sample = min(audio_length, int(end_time * sample_rate))
|
| |
|
| | if start_sample < end_sample:
|
| | speaker_key = "SPEAKER_00" if i < split_point else "SPEAKER_01"
|
| | masks[speaker_key][start_sample:end_sample] = 1.0
|
| |
|
| | return masks
|
| |
|
| | def _create_alternating_masks(self, segments: List[Tuple[float, float]],
|
| | audio_length: int, sample_rate: int) -> Dict[str, np.ndarray]:
|
| | """Create masks with alternating assignment."""
|
| | masks = {
|
| | "SPEAKER_00": np.zeros(audio_length, dtype=np.float32),
|
| | "SPEAKER_01": np.zeros(audio_length, dtype=np.float32)
|
| | }
|
| |
|
| | for i, (start_time, end_time) in enumerate(segments):
|
| | start_sample = max(0, int(start_time * sample_rate))
|
| | end_sample = min(audio_length, int(end_time * sample_rate))
|
| |
|
| | if start_sample < end_sample:
|
| | speaker_key = f"SPEAKER_0{i % 2}"
|
| | masks[speaker_key][start_sample:end_sample] = 1.0
|
| |
|
| | return masks
|
| |
|
    def apply_optimized_background_preservation(self, masks: Dict[str, np.ndarray],
                                                audio_length: int) -> Dict[str, np.ndarray]:
        """
        Heavily optimized background preservation using pure vectorized operations.

        Resolves raw per-speaker masks into exclusive/shared regions: samples
        where both are active keep gain 1.0 on both channels, exclusive samples
        go to their speaker, and silent samples are assigned to the nearest
        speaker at 0.5 gain (preserving ambience on one channel). Finishes
        with minimum-duration smoothing.
        """
        xprint("Applying optimized voice separation logic...")

        # Keep at most the two most-talkative speakers.
        speaker_keys = self._get_top_speakers(masks, audio_length)

        final_masks = {
            speaker: np.zeros(audio_length, dtype=np.float32)
            for speaker in speaker_keys
        }

        # Boolean activity per speaker (threshold 0.5 on the soft masks;
        # a missing key — e.g. the SPEAKER_SILENT placeholder — is all-zero).
        active_0 = masks.get(speaker_keys[0], np.zeros(audio_length)) > 0.5
        active_1 = masks.get(speaker_keys[1], np.zeros(audio_length)) > 0.5

        both_active = active_0 & active_1
        only_0 = active_0 & ~active_1
        only_1 = ~active_0 & active_1
        neither = ~active_0 & ~active_1

        # Overlap: both channels keep the full signal.
        final_masks[speaker_keys[0]][both_active] = 1.0
        final_masks[speaker_keys[1]][both_active] = 1.0

        final_masks[speaker_keys[0]][only_0] = 1.0
        final_masks[speaker_keys[1]][only_0] = 0.0

        final_masks[speaker_keys[0]][only_1] = 0.0
        final_masks[speaker_keys[1]][only_1] = 1.0

        # Silence: give each sample to the nearest speaker at half gain.
        if np.any(neither):
            ambiguous_assignments = self._compute_ambiguous_assignments_vectorized(
                masks, speaker_keys, neither, audio_length
            )

            final_masks[speaker_keys[0]][neither] = (ambiguous_assignments == 0).astype(np.float32) * 0.5
            final_masks[speaker_keys[1]][neither] = (ambiguous_assignments == 1).astype(np.float32) * 0.5

        # NOTE(review): sample rate is hard-coded to 16 kHz here and is used
        # both for the log lines and for the smoothing window below — the
        # reported durations/window are wrong for other rates. Consider
        # passing the actual rate in from the caller.
        sample_rate = 16000
        xprint(f" Both speaking clearly: {np.sum(both_active)/sample_rate:.1f}s")
        xprint(f" {speaker_keys[0]} only: {np.sum(only_0)/sample_rate:.1f}s")
        xprint(f" {speaker_keys[1]} only: {np.sum(only_1)/sample_rate:.1f}s")
        xprint(f" Ambiguous (assigned): {np.sum(neither)/sample_rate:.1f}s")

        final_masks = self._apply_minimum_duration_smoothing(final_masks, sample_rate)

        return final_masks
|
| |
|
| | def _get_top_speakers(self, masks: Dict[str, np.ndarray], audio_length: int) -> List[str]:
|
| | """Get top 2 speakers by speaking time."""
|
| | speaker_keys = list(masks.keys())
|
| |
|
| | if len(speaker_keys) > 2:
|
| |
|
| | speaking_times = {k: np.sum(v) for k, v in masks.items()}
|
| | speaker_keys = sorted(speaking_times.keys(), key=lambda x: speaking_times[x], reverse=True)[:2]
|
| | xprint(f"Keeping top 2 speakers: {speaker_keys}")
|
| | elif len(speaker_keys) == 1:
|
| | speaker_keys.append("SPEAKER_SILENT")
|
| |
|
| | return speaker_keys
|
| |
|
    def _compute_ambiguous_assignments_vectorized(self, masks: Dict[str, np.ndarray],
                                                  speaker_keys: List[str],
                                                  ambiguous_mask: np.ndarray,
                                                  audio_length: int) -> np.ndarray:
        """Compute speaker assignments for ambiguous regions using vectorized operations.

        For every sample flagged in *ambiguous_mask*, measures the distance to
        each speaker's nearest speech segment and assigns the closer speaker
        (0 = first key, 1 = second key); a late-audio bias is applied inside
        ``_assign_based_on_distance``.
        """
        ambiguous_indices = np.where(ambiguous_mask)[0]

        if len(ambiguous_indices) == 0:
            return np.array([])

        # Extract each speaker's speech segments as (start, end) index pairs
        # from the rising/falling edges of the thresholded mask.
        speaker_segments = {}
        for speaker in speaker_keys:
            if speaker in masks and speaker != "SPEAKER_SILENT":
                mask = masks[speaker] > 0.5

                # Pad with False so edges at position 0 / end are detected.
                diff = np.diff(np.concatenate(([False], mask, [False])).astype(int))
                starts = np.where(diff == 1)[0]
                ends = np.where(diff == -1)[0]
                speaker_segments[speaker] = np.column_stack([starts, ends])
            else:
                # Placeholder speakers have no segments at all.
                speaker_segments[speaker] = np.array([]).reshape(0, 2)

        # Distance from each ambiguous sample to its nearest segment, per speaker
        # (inf when the speaker has no segments, so it can never win).
        distances = {}
        for speaker in speaker_keys:
            segments = speaker_segments[speaker]
            if len(segments) == 0:
                distances[speaker] = np.full(len(ambiguous_indices), np.inf)
            else:
                distances[speaker] = self._compute_distances_to_segments(ambiguous_indices, segments)

        assignments = self._assign_based_on_distance(
            distances, speaker_keys, ambiguous_indices, audio_length
        )

        return assignments
|
| |
|
    def _apply_minimum_duration_smoothing(self, masks: Dict[str, np.ndarray],
                                          sample_rate: int, min_duration_ms: int = 600) -> Dict[str, np.ndarray]:
        """
        Apply minimum duration smoothing with STRICT timer enforcement.
        Uses original both-speaking regions from diarization.

        Walks a per-sample preference signal with a hold-off timer: once a
        state (0 = speaker 0, 1 = speaker 1, 2 = both) is entered it is held
        for at least *min_duration_ms*, suppressing rapid channel flapping.
        """
        xprint(f"Applying STRICT minimum duration smoothing ({min_duration_ms}ms)...")

        min_samples = int(min_duration_ms * sample_rate / 1000)
        speaker_keys = list(masks.keys())

        # Smoothing is only defined for exactly two channels.
        if len(speaker_keys) != 2:
            return masks

        mask0 = masks[speaker_keys[0]]
        mask1 = masks[speaker_keys[1]]

        # Overlap regions cached by create_optimized_speaker_masks
        # (all-False when that step did not run).
        both_speaking_original = getattr(self, '_both_speaking_regions', np.zeros(len(mask0), dtype=bool))

        # Samples where neither channel has meaningful gain and no overlap
        # (0.3 catches the 0.5-gain ambient assignments as "ambiguous" too... 
        # actually 0.5 > 0.3, so only true near-zero samples qualify).
        ambiguous_original = (mask0 < 0.3) & (mask1 < 0.3) & ~both_speaking_original

        remaining_mask = ~both_speaking_original & ~ambiguous_original
        speaker0_dominant = (mask0 > mask1) & remaining_mask
        speaker1_dominant = (mask1 > mask0) & remaining_mask

        # Per-sample preference: -1 none, 0/1 dominant speaker, 2 both speaking.
        preference_signal = np.full(len(mask0), -1, dtype=int)
        preference_signal[speaker0_dominant] = 0
        preference_signal[speaker1_dominant] = 1
        preference_signal[both_speaking_original] = 2

        smoothed_assignment = np.full(len(mask0), -1, dtype=int)
        corrections = 0

        # State machine: current_state is locked while samples_remaining > 0.
        current_state = -1
        samples_remaining = 0

        for i in range(len(preference_signal)):
            preference = preference_signal[i]

            if samples_remaining > 0:
                # Hold timer active: keep the current state regardless.
                smoothed_assignment[i] = current_state
                samples_remaining -= 1

                # Count suppressed switch requests for the diagnostics below.
                if preference >= 0 and preference != current_state:
                    corrections += 1

            else:

                if preference >= 0:
                    # Free to follow the preference; a state CHANGE restarts
                    # the hold timer (this sample counts as the first held one).
                    if current_state != preference:
                        current_state = preference
                        samples_remaining = min_samples - 1

                    smoothed_assignment[i] = current_state

                else:
                    # No preference: extend the last state if there is one,
                    # otherwise stay unassigned (-1).
                    if current_state >= 0:
                        smoothed_assignment[i] = current_state
                    else:
                        smoothed_assignment[i] = -1

        # Rebuild the per-speaker masks from the smoothed assignment.
        smoothed_masks = {}

        for i, speaker in enumerate(speaker_keys):
            new_mask = np.zeros_like(mask0)

            speaker_regions = smoothed_assignment == i
            new_mask[speaker_regions] = 1.0

            # State 2 (both speaking) keeps full gain on BOTH channels.
            both_speaking_regions = smoothed_assignment == 2
            new_mask[both_speaking_regions] = 1.0

            # Samples the state machine never claimed fall back to the
            # original (ambient) values from the input masks.
            unassigned_ambiguous = smoothed_assignment == -1
            if np.any(unassigned_ambiguous):
                original_ambiguous_mask = ambiguous_original & unassigned_ambiguous
                new_mask[original_ambiguous_mask] = masks[speaker][original_ambiguous_mask]

            smoothed_masks[speaker] = new_mask

        # Diagnostics (durations in seconds at the given sample rate).
        both_speaking_time = np.sum(smoothed_assignment == 2) / sample_rate
        speaker0_time = np.sum(smoothed_assignment == 0) / sample_rate
        speaker1_time = np.sum(smoothed_assignment == 1) / sample_rate
        ambiguous_time = np.sum(smoothed_assignment == -1) / sample_rate

        xprint(f" Both speaking clearly: {both_speaking_time:.1f}s")
        xprint(f" {speaker_keys[0]} only: {speaker0_time:.1f}s")
        xprint(f" {speaker_keys[1]} only: {speaker1_time:.1f}s")
        xprint(f" Ambiguous (assigned): {ambiguous_time:.1f}s")
        xprint(f" Enforced minimum duration on {corrections} samples ({corrections/sample_rate:.2f}s)")

        return smoothed_masks
|
| |
|
| | def _compute_distances_to_segments(self, indices: np.ndarray, segments: np.ndarray) -> np.ndarray:
|
| | """Compute minimum distances from indices to segments (vectorized)."""
|
| | if len(segments) == 0:
|
| | return np.full(len(indices), np.inf)
|
| |
|
| |
|
| | indices_expanded = indices[:, np.newaxis]
|
| | starts = segments[:, 0]
|
| | ends = segments[:, 1]
|
| |
|
| |
|
| | dist_to_start = np.maximum(0, starts - indices_expanded)
|
| | dist_from_end = np.maximum(0, indices_expanded - ends)
|
| |
|
| |
|
| | distances = np.minimum(dist_to_start, dist_from_end)
|
| |
|
| |
|
| | return np.min(distances, axis=1)
|
| |
|
| | def _assign_based_on_distance(self, distances: Dict[str, np.ndarray],
|
| | speaker_keys: List[str],
|
| | ambiguous_indices: np.ndarray,
|
| | audio_length: int) -> np.ndarray:
|
| | """Assign speakers based on distance with late-audio bias."""
|
| | speaker_0_distances = distances[speaker_keys[0]]
|
| | speaker_1_distances = distances[speaker_keys[1]]
|
| |
|
| |
|
| | assignments = (speaker_1_distances < speaker_0_distances).astype(int)
|
| |
|
| |
|
| | late_threshold = int(audio_length * 0.6)
|
| | late_indices = ambiguous_indices > late_threshold
|
| |
|
| | if np.any(late_indices) and len(speaker_keys) > 1:
|
| |
|
| | assignments[late_indices] = 1
|
| |
|
| | return assignments
|
| |
|
    def _save_outputs_optimized(self, waveform: torch.Tensor, masks: Dict[str, np.ndarray],
                                sample_rate: int, audio_path: str, output1, output2) -> Dict[str, str]:
        """Optimized output saving with parallel processing.

        Multiplies the waveform by each speaker's mask and writes the two
        results concurrently. *audio_path* is unused but kept for signature
        stability.

        NOTE(review): assumes exactly two masks, paired positionally with
        (output1, output2) in dict iteration order — verify callers uphold this.

        Returns:
            Mapping of speaker label -> output path.
        """
        output_paths = {}

        def save_speaker_audio(speaker_mask_pair, output):
            # Worker: apply one speaker's mask and write the resulting file.
            speaker, mask = speaker_mask_pair

            mask_tensor = torch.from_numpy(mask).unsqueeze(0)

            masked_audio = waveform * mask_tensor

            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=UserWarning)
                torchaudio.save(output, masked_audio, sample_rate)

            xprint(f"✓ Saved {speaker}: {output}")
            return speaker, output

        # One thread per speaker track; torchaudio.save releases the GIL during I/O.
        with ThreadPoolExecutor(max_workers=2) as executor:
            results = list(executor.map(save_speaker_audio, masks.items(), [output1, output2]))

        output_paths = dict(results)
        return output_paths
|
| |
|
| | def print_summary(self, audio_path: str):
|
| | """xprint diarization summary."""
|
| | with warnings.catch_warnings():
|
| | warnings.filterwarnings("ignore", category=UserWarning)
|
| | diarization = self.perform_optimized_diarization(audio_path)
|
| |
|
| | xprint("\n=== Diarization Summary ===")
|
| | for turn, _, speaker in diarization.itertracks(yield_label=True):
|
| | xprint(f"{speaker}: {turn.start:.1f}s - {turn.end:.1f}s")
|
| |
|
def extract_dual_audio(audio, output1, output2, verbose=False, audio_original=None):
    """Module-level convenience wrapper: split *audio* into two speaker files.

    Args:
        audio: input audio path used for diarization.
        output1: destination path for the first speaker track.
        output2: destination path for the second speaker track.
        verbose: sets the module-wide ``verbose_output`` flag (side effect —
            affects all subsequent xprint() calls in this process).
        audio_original: optional alternative audio to mask instead of *audio*.
    """
    global verbose_output
    verbose_output = verbose
    # Token and local model path are unused by the current constructor.
    separator = OptimizedPyannote31SpeakerSeparator(
        None,
        None,
        vad_onset=0.2,
        vad_offset=0.8
    )

    import time
    start_time = time.time()

    outputs = separator.separate_audio(audio, output1, output2, audio_original)

    elapsed_time = time.time() - start_time
    xprint(f"\n=== SUCCESS (completed in {elapsed_time:.2f}s) ===")
    for speaker, path in outputs.items():
        xprint(f"{speaker}: {path}")
|
| |
|
def main():
    """CLI entry point: parse arguments, run separation, report timing.

    Returns:
        1 on failure (missing file returns None as well); None on success,
        which exit() maps to status 0.
    """
    parser = argparse.ArgumentParser(description="Optimized Pyannote 3.1 Speaker Separator")
    parser.add_argument("--audio", required=True, help="Input audio file")
    parser.add_argument("--output", required=True, help="Output directory")
    parser.add_argument("--token", help="Hugging Face token")
    parser.add_argument("--local-model", help="Path to local 3.1 model")
    parser.add_argument("--summary", action="store_true", help="xprint summary")

    # VAD sensitivity knobs forwarded to the separator constructor.
    parser.add_argument("--vad-onset", type=float, default=0.2,
                        help="VAD onset threshold (lower = more sensitive to speech start, default: 0.2)")
    parser.add_argument("--vad-offset", type=float, default=0.8,
                        help="VAD offset threshold (higher = keeps speech longer, default: 0.8)")

    args = parser.parse_args()

    xprint("=== Optimized Pyannote 3.1 Speaker Separator ===")
    xprint("Performance optimizations: vectorized operations, memory management, parallel processing")
    xprint(f"Audio: {args.audio}")
    xprint(f"Output: {args.output}")
    xprint(f"VAD onset: {args.vad_onset}")
    xprint(f"VAD offset: {args.vad_offset}")
    xprint()

    if not os.path.exists(args.audio):
        xprint(f"ERROR: Audio file not found: {args.audio}")
        return

    try:
        separator = OptimizedPyannote31SpeakerSeparator(
            args.token,
            args.local_model,
            vad_onset=args.vad_onset,
            vad_offset=args.vad_offset
        )

        if args.summary:
            separator.print_summary(args.audio)

        import time
        start_time = time.time()

        # Output files are named <input stem>_speaker0.wav / _speaker1.wav
        # inside the requested output directory.
        audio_name = Path(args.audio).stem
        output_filename = f"{audio_name}_speaker0.wav"
        output_filename1 = f"{audio_name}_speaker1.wav"
        output_path = os.path.join(args.output, output_filename)
        output_path1 = os.path.join(args.output, output_filename1)

        outputs = separator.separate_audio(args.audio, output_path, output_path1)

        elapsed_time = time.time() - start_time
        xprint(f"\n=== SUCCESS (completed in {elapsed_time:.2f}s) ===")
        for speaker, path in outputs.items():
            xprint(f"{speaker}: {path}")

    except Exception as e:
        xprint(f"ERROR: {e}")
        return 1
|
| |
|
| |
|
| | if __name__ == "__main__":
|
| | exit(main()) |