import numpy as np
import torch
import time
import threading
import os
import queue
import torchaudio
from scipy.spatial.distance import cosine
from scipy.signal import resample
import logging
import urllib.request

# Import RealtimeSTT for transcription
from RealtimeSTT import AudioToTextRecorder

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Simplified configuration parameters
SILENCE_THRESHS = [0, 0.4]
FINAL_TRANSCRIPTION_MODEL = "distil-large-v3"
FINAL_BEAM_SIZE = 5
REALTIME_TRANSCRIPTION_MODEL = "distil-small.en"
REALTIME_BEAM_SIZE = 5
TRANSCRIPTION_LANGUAGE = "en"
SILERO_SENSITIVITY = 0.4
WEBRTC_SENSITIVITY = 3
MIN_LENGTH_OF_RECORDING = 0.7
PRE_RECORDING_BUFFER_DURATION = 0.35

# Speaker change detection parameters
DEFAULT_CHANGE_THRESHOLD = 0.65
EMBEDDING_HISTORY_SIZE = 5
MIN_SEGMENT_DURATION = 1.5
DEFAULT_MAX_SPEAKERS = 4
ABSOLUTE_MAX_SPEAKERS = 8

# Global variables
SAMPLE_RATE = 16000
BUFFER_SIZE = 1024
CHANNELS = 1

# Speaker colors - chosen to be easily distinguishable
SPEAKER_COLORS = [
    "#FF6B6B",  # Red
    "#4ECDC4",  # Teal
    "#45B7D1",  # Blue
    "#96CEB4",  # Green
    "#FFEAA7",  # Yellow
    "#DDA0DD",  # Plum
    "#98D8C8",  # Mint
    "#F7DC6F",  # Gold
]

SPEAKER_COLOR_NAMES = [
    "Red", "Teal", "Blue", "Green",
    "Yellow", "Plum", "Mint", "Gold"
]


class SpeechBrainEncoder:
    """ECAPA-TDNN encoder from SpeechBrain for speaker embeddings"""

    def __init__(self, device="cpu"):
        self.device = device
        self.model = None
        self.embedding_dim = 192
        self.model_loaded = False
        self.cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "speechbrain")
        os.makedirs(self.cache_dir, exist_ok=True)

    def _download_model(self):
        """Download the pre-trained SpeechBrain ECAPA-TDNN checkpoint if not present"""
        model_url = "https://huggingface.co/speechbrain/spkrec-ecapa-voxceleb/resolve/main/embedding_model.ckpt"
        model_path = os.path.join(self.cache_dir, "embedding_model.ckpt")

        if not os.path.exists(model_path):
            print(f"Downloading ECAPA-TDNN model to {model_path}...")
            urllib.request.urlretrieve(model_url, model_path)

        return model_path

    def load_model(self):
        """Load the ECAPA-TDNN model"""
        try:
            # Import SpeechBrain
            from speechbrain.pretrained import EncoderClassifier

            # Ensure the checkpoint is cached locally
            self._download_model()

            # Load the pre-trained model
            self.model = EncoderClassifier.from_hparams(
                source="speechbrain/spkrec-ecapa-voxceleb",
                savedir=self.cache_dir,
                run_opts={"device": self.device}
            )

            self.model_loaded = True
            return True
        except Exception as e:
            print(f"Error loading ECAPA-TDNN model: {e}")
            return False

    def embed_utterance(self, audio, sr=16000):
        """Extract a speaker embedding from audio"""
        if not self.model_loaded:
            raise ValueError("Model not loaded. Call load_model() first.")

        try:
            if isinstance(audio, np.ndarray):
                # Ensure audio is float32 and properly normalized
                audio = audio.astype(np.float32)
                if np.max(np.abs(audio)) > 1.0:
                    audio = audio / np.max(np.abs(audio))
                waveform = torch.tensor(audio).unsqueeze(0)
            else:
                waveform = audio.unsqueeze(0)

            # Resample if necessary
            if sr != 16000:
                waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)

            with torch.no_grad():
                embedding = self.model.encode_batch(waveform)

            return embedding.squeeze().cpu().numpy()
        except Exception as e:
            logger.error(f"Error extracting embedding: {e}")
            return np.zeros(self.embedding_dim)


class AudioProcessor:
    """Processes audio data to extract speaker embeddings"""

    def __init__(self, encoder):
        self.encoder = encoder
        self.audio_buffer = []
        self.min_audio_length = int(SAMPLE_RATE * 1.0)  # Minimum 1 second of audio

    def add_audio_chunk(self, audio_chunk):
        """Add an audio chunk to the buffer"""
        self.audio_buffer.extend(audio_chunk)

        # Keep the buffer from growing unbounded
        max_buffer_size = int(SAMPLE_RATE * 10)  # 10 seconds max
        if len(self.audio_buffer) > max_buffer_size:
            self.audio_buffer = self.audio_buffer[-max_buffer_size:]

    def extract_embedding_from_buffer(self):
        """Extract an embedding from the current audio buffer"""
        if len(self.audio_buffer) < self.min_audio_length:
            return None

        try:
            # Use the last portion of the buffer for embedding
            audio_segment = np.array(self.audio_buffer[-self.min_audio_length:], dtype=np.float32)

            # Normalize audio
            if np.max(np.abs(audio_segment)) > 0:
                audio_segment = audio_segment / np.max(np.abs(audio_segment))
            else:
                return None

            embedding = self.encoder.embed_utterance(audio_segment)
            return embedding
        except Exception as e:
            logger.error(f"Embedding extraction error: {e}")
            return None


class SpeakerChangeDetector:
    """Improved speaker change detector"""

    def __init__(self, embedding_dim=192, change_threshold=DEFAULT_CHANGE_THRESHOLD, max_speakers=DEFAULT_MAX_SPEAKERS):
        self.embedding_dim = embedding_dim
        self.change_threshold = change_threshold
        self.max_speakers = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)
        self.current_speaker = 0
        self.speaker_embeddings = [[] for _ in range(self.max_speakers)]
        self.speaker_centroids = [None] * self.max_speakers
        self.last_change_time = time.time()
        self.last_similarity = 1.0
        self.active_speakers = set([0])
        self.segment_counter = 0

    def set_max_speakers(self, max_speakers):
        """Update the maximum number of speakers"""
        new_max = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)

        if new_max < self.max_speakers:
            # Remove speakers beyond the new limit
            for speaker_id in list(self.active_speakers):
                if speaker_id >= new_max:
                    self.active_speakers.discard(speaker_id)
            if self.current_speaker >= new_max:
                self.current_speaker = 0

        # Resize embedding and centroid lists
        if new_max > self.max_speakers:
            self.speaker_embeddings.extend([[] for _ in range(new_max - self.max_speakers)])
            self.speaker_centroids.extend([None] * (new_max - self.max_speakers))
        else:
            self.speaker_embeddings = self.speaker_embeddings[:new_max]
            self.speaker_centroids = self.speaker_centroids[:new_max]

        self.max_speakers = new_max

    def set_change_threshold(self, threshold):
        """Update the threshold for detecting speaker changes"""
        self.change_threshold = max(0.1, min(threshold, 0.95))

    def add_embedding(self, embedding, timestamp=None):
        """Add a new embedding and detect speaker changes"""
        current_time = timestamp or time.time()
        self.segment_counter += 1

        # Initialize the first speaker
        if not self.speaker_embeddings[0]:
            self.speaker_embeddings[0].append(embedding)
            self.speaker_centroids[0] = embedding.copy()
            self.active_speakers.add(0)
            return 0, 1.0

        # Calculate similarity with the current speaker's centroid
        current_centroid = self.speaker_centroids[self.current_speaker]
        if current_centroid is not None:
            similarity = 1.0 - cosine(embedding, current_centroid)
        else:
            similarity = 0.5

        self.last_similarity = similarity

        # Check for a speaker change
        time_since_last_change = current_time - self.last_change_time
        speaker_changed = False

        if time_since_last_change >= MIN_SEGMENT_DURATION and similarity < self.change_threshold:
            # Find the best matching existing speaker
            best_speaker = self.current_speaker
            best_similarity = similarity

            for speaker_id in self.active_speakers:
                if speaker_id == self.current_speaker:
                    continue
                centroid = self.speaker_centroids[speaker_id]
                if centroid is not None:
                    speaker_similarity = 1.0 - cosine(embedding, centroid)
                    if speaker_similarity > best_similarity and speaker_similarity > self.change_threshold:
                        best_similarity = speaker_similarity
                        best_speaker = speaker_id

            # If no good match was found and we can still add a new speaker
            if best_speaker == self.current_speaker and len(self.active_speakers) < self.max_speakers:
                for new_id in range(self.max_speakers):
                    if new_id not in self.active_speakers:
                        best_speaker = new_id
                        self.active_speakers.add(new_id)
                        break

            if best_speaker != self.current_speaker:
                self.current_speaker = best_speaker
                self.last_change_time = current_time
                speaker_changed = True

        # Update speaker embeddings and centroids
        self.speaker_embeddings[self.current_speaker].append(embedding)

        # Keep only recent embeddings (sliding window)
        max_embeddings = 20
        if len(self.speaker_embeddings[self.current_speaker]) > max_embeddings:
            self.speaker_embeddings[self.current_speaker] = self.speaker_embeddings[self.current_speaker][-max_embeddings:]

        # Update the centroid
        if self.speaker_embeddings[self.current_speaker]:
            self.speaker_centroids[self.current_speaker] = np.mean(
                self.speaker_embeddings[self.current_speaker], axis=0
            )

        return self.current_speaker, similarity

    def get_color_for_speaker(self, speaker_id):
        """Return the color for a speaker ID"""
        if 0 <= speaker_id < len(SPEAKER_COLORS):
            return SPEAKER_COLORS[speaker_id]
        return "#FFFFFF"

    def get_status_info(self):
        """Return status information"""
        speaker_counts = [len(self.speaker_embeddings[i]) for i in range(self.max_speakers)]
        return {
            "current_speaker": self.current_speaker,
            "speaker_counts": speaker_counts,
            "active_speakers": len(self.active_speakers),
            "max_speakers": self.max_speakers,
            "last_similarity": self.last_similarity,
            "threshold": self.change_threshold,
            "segment_counter": self.segment_counter
        }


class RealtimeSpeakerDiarization:
    def __init__(self):
        self.encoder = None
        self.audio_processor = None
        self.speaker_detector = None
        self.recorder = None  # RealtimeSTT recorder
        self.sentence_queue = queue.Queue()
        self.full_sentences = []
        self.sentence_speakers = []
        self.pending_sentences = []
        self.current_conversation = ""
        self.is_running = False
        self.change_threshold = DEFAULT_CHANGE_THRESHOLD
        self.max_speakers = DEFAULT_MAX_SPEAKERS
        self.last_transcription = ""
        self.transcription_lock = threading.Lock()

    def initialize_models(self):
        """Initialize the speaker encoder and transcription models"""
        try:
            device_str = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info(f"Using device: {device_str}")

            self.encoder = SpeechBrainEncoder(device=device_str)
            success = self.encoder.load_model()

            if success:
                self.audio_processor = AudioProcessor(self.encoder)
                self.speaker_detector = SpeakerChangeDetector(
                    embedding_dim=self.encoder.embedding_dim,
                    change_threshold=self.change_threshold,
                    max_speakers=self.max_speakers
                )

                # Initialize the RealtimeSTT transcription recorder
                self.recorder = AudioToTextRecorder(
                    spinner=False,
                    use_microphone=False,
                    model=FINAL_TRANSCRIPTION_MODEL,
                    language=TRANSCRIPTION_LANGUAGE,
                    silero_sensitivity=SILERO_SENSITIVITY,
                    webrtc_sensitivity=WEBRTC_SENSITIVITY,
                    post_speech_silence_duration=0.7,
                    min_length_of_recording=MIN_LENGTH_OF_RECORDING,
                    pre_recording_buffer_duration=PRE_RECORDING_BUFFER_DURATION,
                    enable_realtime_transcription=True,
                    realtime_processing_pause=0.2,
                    realtime_model_type=REALTIME_TRANSCRIPTION_MODEL,
                    on_realtime_transcription_update=self.live_text_detected,
                    on_recording_stop=self.process_final_text,
                    level=logging.WARNING,
                    # Don't start processing immediately
                    handle_buffer_overflow=True
                )

                logger.info("Models initialized successfully!")
                return True
            else:
                logger.error("Failed to load models")
                return False
        except Exception as e:
            logger.error(f"Model initialization error: {e}")
            return False

    def live_text_detected(self, text):
        """Callback for real-time transcription updates"""
        with self.transcription_lock:
            self.last_transcription = text.strip()

    def process_final_text(self, text):
        """Process final transcribed text with speaker embedding"""
        text = text.strip()
        if text:
            try:
                # Get the audio data for this transcription, if available
                audio_bytes = getattr(self.recorder, 'last_transcription_bytes', None)
                if audio_bytes:
                    self.sentence_queue.put((text, audio_bytes))
                else:
                    # If no audio bytes are available, keep the current speaker
                    self.sentence_queue.put((text, None))
            except Exception as e:
                logger.error(f"Error processing final text: {e}")

    def process_sentence_queue(self):
        """Process sentences in the queue for speaker detection"""
        while self.is_running:
            try:
                text, audio_bytes = self.sentence_queue.get(timeout=1)

                current_speaker = self.speaker_detector.current_speaker

                if audio_bytes:
                    # Convert audio data and extract an embedding
                    audio_int16 = np.frombuffer(audio_bytes, dtype=np.int16)
                    audio_float = audio_int16.astype(np.float32) / 32768.0

                    embedding = self.audio_processor.encoder.embed_utterance(audio_float)
                    if embedding is not None:
                        current_speaker, similarity = self.speaker_detector.add_embedding(embedding)

                # Store the sentence with its speaker
                with self.transcription_lock:
                    self.full_sentences.append((text, current_speaker))
                    self.update_conversation_display()
            except queue.Empty:
                continue
            except Exception as e:
                logger.error(f"Error processing sentence: {e}")

    def update_conversation_display(self):
        """Update the conversation display (HTML, one colored label per speaker)"""
        try:
            sentences_with_style = []

            for sentence_text, speaker_id in self.full_sentences:
                color = self.speaker_detector.get_color_for_speaker(speaker_id)
                speaker_name = f"Speaker {speaker_id + 1}"
                sentences_with_style.append(
                    f'<span style="color:{color};"><b>{speaker_name}:</b></span> '
                    f'{sentence_text}'
                )

            # Add the current (partial) transcription if available
            if self.last_transcription:
                current_color = self.speaker_detector.get_color_for_speaker(self.speaker_detector.current_speaker)
                current_speaker = f"Speaker {self.speaker_detector.current_speaker + 1}"
                sentences_with_style.append(
                    f'<span style="color:{current_color};"><b>{current_speaker}:</b></span> '
                    f'{self.last_transcription}...'
                )

            if sentences_with_style:
                self.current_conversation = "<br><br>".join(sentences_with_style)
            else:
                self.current_conversation = "Waiting for speech input..."
        except Exception as e:
            logger.error(f"Error updating conversation display: {e}")
            self.current_conversation = f"Error: {str(e)}"

    def start_recording(self):
        """Start the recording and transcription process"""
        if self.encoder is None:
            return "Please initialize models first!"

        try:
            self.is_running = True

            # Start the sentence-processing thread
            self.sentence_thread = threading.Thread(target=self.process_sentence_queue, daemon=True)
            self.sentence_thread.start()

            # Start the RealtimeSTT recorder explicitly
            if self.recorder:
                # First make sure it is stopped if it was already running
                try:
                    if getattr(self.recorder, '_is_running', False):
                        self.recorder.stop()
                except Exception:
                    pass

                # Then start it fresh
                self.recorder.start()
                logger.info("RealtimeSTT recorder started")

            return "Recording started successfully!"
        except Exception as e:
            logger.error(f"Error starting recording: {e}")
            return f"Error starting recording: {e}"

    def stop_recording(self):
        """Stop the recording process"""
        self.is_running = False

        # Stop the RealtimeSTT recorder
        if self.recorder:
            try:
                self.recorder.stop()
                logger.info("RealtimeSTT recorder stopped")

                # Reset the last transcription
                with self.transcription_lock:
                    self.last_transcription = ""
            except Exception as e:
                logger.error(f"Error stopping recorder: {e}")

        return "Recording stopped!"

    def clear_conversation(self):
        """Clear all conversation data"""
        with self.transcription_lock:
            self.full_sentences = []
            self.last_transcription = ""
            self.current_conversation = "Conversation cleared!"

        if self.speaker_detector:
            self.speaker_detector = SpeakerChangeDetector(
                embedding_dim=self.encoder.embedding_dim,
                change_threshold=self.change_threshold,
                max_speakers=self.max_speakers
            )

        return "Conversation cleared!"
    def update_settings(self, threshold, max_speakers):
        """Update speaker detection settings"""
        self.change_threshold = threshold
        self.max_speakers = max_speakers

        if self.speaker_detector:
            self.speaker_detector.set_change_threshold(threshold)
            self.speaker_detector.set_max_speakers(max_speakers)

        return f"Settings updated: Threshold={threshold:.2f}, Max Speakers={max_speakers}"

    def get_formatted_conversation(self):
        """Get the formatted conversation with structured data"""
        try:
            # Conversation HTML, as built by update_conversation_display()
            html_content = self.current_conversation

            # Create structured data
            structured_data = {
                "html_content": html_content,
                "sentences": [],
                "current_transcript": self.last_transcription,
                "current_speaker": self.speaker_detector.current_speaker if self.speaker_detector else 0
            }

            # Add per-sentence data
            for sentence_text, speaker_id in self.full_sentences:
                color = self.speaker_detector.get_color_for_speaker(speaker_id) if self.speaker_detector else "#FFFFFF"
                structured_data["sentences"].append({
                    "text": sentence_text,
                    "speaker_id": speaker_id,
                    "speaker_name": f"Speaker {speaker_id + 1}",
                    "color": color
                })

            return html_content
        except Exception as e:
            logger.error(f"Error formatting conversation: {e}")
            return f"Error formatting conversation: {str(e)}"

    def get_status_info(self):
        """Get current status information as structured data"""
        if not self.speaker_detector:
            return {"error": "Speaker detector not initialized"}

        try:
            speaker_status = self.speaker_detector.get_status_info()

            # Format per-speaker activity
            speaker_activity = []
            for i in range(speaker_status['max_speakers']):
                color_name = SPEAKER_COLOR_NAMES[i] if i < len(SPEAKER_COLOR_NAMES) else f"Speaker {i+1}"
                count = speaker_status['speaker_counts'][i]
                active = count > 0
                speaker_activity.append({
                    "id": i,
                    "name": f"Speaker {i+1}",
                    "color": SPEAKER_COLORS[i] if i < len(SPEAKER_COLORS) else "#FFFFFF",
                    "color_name": color_name,
                    "segment_count": count,
                    "active": active
                })

            # Create a structured status object
            status = {
                "current_speaker": speaker_status['current_speaker'],
                "current_speaker_name": f"Speaker {speaker_status['current_speaker'] + 1}",
                "active_speakers_count": speaker_status['active_speakers'],
                "max_speakers": speaker_status['max_speakers'],
                "last_similarity": speaker_status['last_similarity'],
                "change_threshold": speaker_status['threshold'],
                "total_sentences": len(self.full_sentences),
                "segments_processed": speaker_status['segment_counter'],
                "speaker_activity": speaker_activity,
                "timestamp": time.time()
            }

            # Also create a formatted text version for UI display
            status_lines = [
                f"**Current Speaker:** {status['current_speaker'] + 1}",
                f"**Active Speakers:** {status['active_speakers_count']} of {status['max_speakers']}",
                f"**Last Similarity:** {status['last_similarity']:.3f}",
                f"**Change Threshold:** {status['change_threshold']:.2f}",
                f"**Total Sentences:** {status['total_sentences']}",
                f"**Segments Processed:** {status['segments_processed']}",
                "",
                "**Speaker Activity:**"
            ]
            for speaker in status["speaker_activity"]:
                active = "🟢" if speaker["active"] else "⚫"
                status_lines.append(
                    f"{active} Speaker {speaker['id']+1} ({speaker['color_name']}): {speaker['segment_count']} segments"
                )

            status["formatted_text"] = "\n".join(status_lines)
            return status
        except Exception as e:
            error_msg = f"Error getting status: {e}"
            logger.error(error_msg)
            return {"error": error_msg, "formatted_text": error_msg}

    def process_audio_chunk(self, audio_data, sample_rate=16000):
        """Process an audio chunk from WebSocket input"""
        if not self.is_running or self.audio_processor is None:
            return {"status": "not_running"}

        try:
            # Convert bytes to a numpy array if needed
            if isinstance(audio_data, bytes):
                audio_data = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0

            # Ensure audio is float32
            if isinstance(audio_data, np.ndarray):
                if audio_data.dtype != np.float32:
                    audio_data = audio_data.astype(np.float32)
            else:
                audio_data = np.array(audio_data, dtype=np.float32)

            # Ensure mono
            if len(audio_data.shape) > 1:
                audio_data = np.mean(audio_data, axis=1) if audio_data.shape[1] > 1 else audio_data.flatten()

            # Check whether the audio has meaningful content (not just silence)
            audio_level = np.abs(audio_data).mean()
            is_silence = audio_level < 0.01  # Threshold for silence

            # Skip processing for silent audio
            if is_silence:
                return {
                    "status": "silent",
                    "buffer_size": len(self.audio_processor.audio_buffer),
                    "speaker_id": self.speaker_detector.current_speaker,
                    "conversation_html": self.current_conversation
                }

            # Normalize if needed
            if np.max(np.abs(audio_data)) > 1.0:
                audio_data = audio_data / np.max(np.abs(audio_data))

            # Add to the audio processor buffer for speaker detection
            self.audio_processor.add_audio_chunk(audio_data)

            # Feed to RealtimeSTT for transcription
            if self.recorder:
                # Convert to int16 for RealtimeSTT (scale by 32767 to avoid overflow at +1.0)
                audio_int16 = (audio_data * 32767).astype(np.int16)
                self.recorder.feed_audio(audio_int16.tobytes())

            # Periodically extract embeddings for speaker detection
            embedding = None
            speaker_id = self.speaker_detector.current_speaker
            similarity = 1.0
            buffer_len = len(self.audio_processor.audio_buffer)
            if buffer_len >= SAMPLE_RATE and (buffer_len - SAMPLE_RATE) % (SAMPLE_RATE // 2) == 0:
                embedding = self.audio_processor.extract_embedding_from_buffer()
                if embedding is not None:
                    speaker_id, similarity = self.speaker_detector.add_embedding(embedding)

            # Return the processing result
            return {
                "status": "processed",
                "buffer_size": buffer_len,
                "speaker_id": int(speaker_id),
                "similarity": float(similarity),
                "conversation_html": self.current_conversation
            }
        except Exception as e:
            logger.error(f"Error processing audio chunk: {e}")
            return {"status": "error", "message": str(e)}

    def resample_audio(self, audio_bytes, from_rate, to_rate):
        """Resample audio to the target sample rate"""
        try:
            audio_np = np.frombuffer(audio_bytes, dtype=np.int16)
            num_samples = len(audio_np)
            num_target_samples = int(num_samples * to_rate / from_rate)
            resampled = resample(audio_np, num_target_samples)
            return resampled.astype(np.int16).tobytes()
        except Exception as e:
            logger.error(f"Error resampling audio: {e}")
            return audio_bytes