import numpy as np
import torch
import time
import threading
import os
import queue
import torchaudio
from scipy.spatial.distance import cosine
from scipy.signal import resample
import logging
import urllib.request
# Import RealtimeSTT for transcription
from RealtimeSTT import AudioToTextRecorder
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Simplified configuration parameters
SILENCE_THRESHS = [0, 0.4]
FINAL_TRANSCRIPTION_MODEL = "distil-large-v3"
FINAL_BEAM_SIZE = 5
REALTIME_TRANSCRIPTION_MODEL = "distil-small.en"
REALTIME_BEAM_SIZE = 5
TRANSCRIPTION_LANGUAGE = "en"
SILERO_SENSITIVITY = 0.4
WEBRTC_SENSITIVITY = 3
MIN_LENGTH_OF_RECORDING = 0.7
PRE_RECORDING_BUFFER_DURATION = 0.35
# Speaker change detection parameters
DEFAULT_CHANGE_THRESHOLD = 0.65
EMBEDDING_HISTORY_SIZE = 5
MIN_SEGMENT_DURATION = 1.5
DEFAULT_MAX_SPEAKERS = 4
ABSOLUTE_MAX_SPEAKERS = 8
# Global variables
SAMPLE_RATE = 16000
BUFFER_SIZE = 1024
CHANNELS = 1
# Speaker colors (chosen to be easily distinguishable)
SPEAKER_COLORS = [
"#FF6B6B", # Red
"#4ECDC4", # Teal
"#45B7D1", # Blue
"#96CEB4", # Green
"#FFEAA7", # Yellow
"#DDA0DD", # Plum
"#98D8C8", # Mint
"#F7DC6F", # Gold
]
SPEAKER_COLOR_NAMES = [
"Red", "Teal", "Blue", "Green", "Yellow", "Plum", "Mint", "Gold"
]
class SpeechBrainEncoder:
"""ECAPA-TDNN encoder from SpeechBrain for speaker embeddings"""
def __init__(self, device="cpu"):
self.device = device
self.model = None
self.embedding_dim = 192
self.model_loaded = False
self.cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "speechbrain")
os.makedirs(self.cache_dir, exist_ok=True)
def _download_model(self):
"""Download pre-trained SpeechBrain ECAPA-TDNN model if not present"""
model_url = "https://huggingface.co/speechbrain/spkrec-ecapa-voxceleb/resolve/main/embedding_model.ckpt"
model_path = os.path.join(self.cache_dir, "embedding_model.ckpt")
if not os.path.exists(model_path):
print(f"Downloading ECAPA-TDNN model to {model_path}...")
urllib.request.urlretrieve(model_url, model_path)
return model_path
def load_model(self):
"""Load the ECAPA-TDNN model"""
try:
# Import SpeechBrain (EncoderClassifier moved to speechbrain.inference in
# SpeechBrain 1.0; fall back to the legacy path for older releases)
try:
    from speechbrain.inference import EncoderClassifier
except ImportError:
    from speechbrain.pretrained import EncoderClassifier
# Get model path
model_path = self._download_model()
# Load the pre-trained model
self.model = EncoderClassifier.from_hparams(
source="speechbrain/spkrec-ecapa-voxceleb",
savedir=self.cache_dir,
run_opts={"device": self.device}
)
self.model_loaded = True
return True
except Exception as e:
print(f"Error loading ECAPA-TDNN model: {e}")
return False
def embed_utterance(self, audio, sr=16000):
"""Extract speaker embedding from audio"""
if not self.model_loaded:
raise ValueError("Model not loaded. Call load_model() first.")
try:
if isinstance(audio, np.ndarray):
# Ensure audio is float32 and properly normalized
audio = audio.astype(np.float32)
if np.max(np.abs(audio)) > 1.0:
audio = audio / np.max(np.abs(audio))
waveform = torch.tensor(audio).unsqueeze(0)
else:
waveform = audio.unsqueeze(0)
# Resample if necessary
if sr != 16000:
waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
with torch.no_grad():
embedding = self.model.encode_batch(waveform)
return embedding.squeeze().cpu().numpy()
except Exception as e:
logger.error(f"Error extracting embedding: {e}")
return np.zeros(self.embedding_dim)
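# Illustrative usage of SpeechBrainEncoder (not executed by this module; assumes
# speechbrain is installed and roughly 1 s of 16 kHz mono audio is available):
#   encoder = SpeechBrainEncoder(device="cpu")
#   if encoder.load_model():
#       emb = encoder.embed_utterance(np.zeros(16000, dtype=np.float32))  # -> (192,) vector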
class AudioProcessor:
"""Processes audio data to extract speaker embeddings"""
def __init__(self, encoder):
self.encoder = encoder
self.audio_buffer = []
self.min_audio_length = int(SAMPLE_RATE * 1.0) # Minimum 1 second of audio
def add_audio_chunk(self, audio_chunk):
"""Add audio chunk to buffer"""
self.audio_buffer.extend(audio_chunk)
# Keep buffer from getting too large
max_buffer_size = int(SAMPLE_RATE * 10) # 10 seconds max
if len(self.audio_buffer) > max_buffer_size:
self.audio_buffer = self.audio_buffer[-max_buffer_size:]
def extract_embedding_from_buffer(self):
"""Extract embedding from current audio buffer"""
if len(self.audio_buffer) < self.min_audio_length:
return None
try:
# Use the last portion of the buffer for embedding
audio_segment = np.array(self.audio_buffer[-self.min_audio_length:], dtype=np.float32)
# Normalize audio
if np.max(np.abs(audio_segment)) > 0:
audio_segment = audio_segment / np.max(np.abs(audio_segment))
else:
return None
embedding = self.encoder.embed_utterance(audio_segment)
return embedding
except Exception as e:
logger.error(f"Embedding extraction error: {e}")
return None
class SpeakerChangeDetector:
"""Improved speaker change detector"""
def __init__(self, embedding_dim=192, change_threshold=DEFAULT_CHANGE_THRESHOLD, max_speakers=DEFAULT_MAX_SPEAKERS):
self.embedding_dim = embedding_dim
self.change_threshold = change_threshold
self.max_speakers = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)
self.current_speaker = 0
self.speaker_embeddings = [[] for _ in range(self.max_speakers)]
self.speaker_centroids = [None] * self.max_speakers
self.last_change_time = time.time()
self.last_similarity = 1.0
self.active_speakers = set([0])
self.segment_counter = 0
def set_max_speakers(self, max_speakers):
"""Update the maximum number of speakers"""
new_max = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)
if new_max < self.max_speakers:
# Remove speakers beyond the new limit
for speaker_id in list(self.active_speakers):
if speaker_id >= new_max:
self.active_speakers.discard(speaker_id)
if self.current_speaker >= new_max:
self.current_speaker = 0
# Resize arrays
if new_max > self.max_speakers:
self.speaker_embeddings.extend([[] for _ in range(new_max - self.max_speakers)])
self.speaker_centroids.extend([None] * (new_max - self.max_speakers))
else:
self.speaker_embeddings = self.speaker_embeddings[:new_max]
self.speaker_centroids = self.speaker_centroids[:new_max]
self.max_speakers = new_max
def set_change_threshold(self, threshold):
"""Update the threshold for detecting speaker changes"""
self.change_threshold = max(0.1, min(threshold, 0.95))
def add_embedding(self, embedding, timestamp=None):
"""Add a new embedding and detect speaker changes"""
current_time = timestamp if timestamp is not None else time.time()
self.segment_counter += 1
# Initialize first speaker
if not self.speaker_embeddings[0]:
self.speaker_embeddings[0].append(embedding)
self.speaker_centroids[0] = embedding.copy()
self.active_speakers.add(0)
return 0, 1.0
# Calculate similarity with current speaker
current_centroid = self.speaker_centroids[self.current_speaker]
if current_centroid is not None:
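# scipy's cosine() returns a distance, so 1 - distance gives cosine similarity
# (values near 1.0 suggest the same speaker)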
similarity = 1.0 - cosine(embedding, current_centroid)
else:
similarity = 0.5
self.last_similarity = similarity
# Check for speaker change
time_since_last_change = current_time - self.last_change_time
speaker_changed = False
if time_since_last_change >= MIN_SEGMENT_DURATION and similarity < self.change_threshold:
# Find best matching speaker
best_speaker = self.current_speaker
best_similarity = similarity
for speaker_id in self.active_speakers:
if speaker_id == self.current_speaker:
continue
centroid = self.speaker_centroids[speaker_id]
if centroid is not None:
speaker_similarity = 1.0 - cosine(embedding, centroid)
if speaker_similarity > best_similarity and speaker_similarity > self.change_threshold:
best_similarity = speaker_similarity
best_speaker = speaker_id
# If no good match found and we can add a new speaker
if best_speaker == self.current_speaker and len(self.active_speakers) < self.max_speakers:
for new_id in range(self.max_speakers):
if new_id not in self.active_speakers:
best_speaker = new_id
self.active_speakers.add(new_id)
break
if best_speaker != self.current_speaker:
self.current_speaker = best_speaker
self.last_change_time = current_time
speaker_changed = True
# Update speaker embeddings and centroids
self.speaker_embeddings[self.current_speaker].append(embedding)
# Keep only recent embeddings (sliding window)
max_embeddings = 20
if len(self.speaker_embeddings[self.current_speaker]) > max_embeddings:
self.speaker_embeddings[self.current_speaker] = self.speaker_embeddings[self.current_speaker][-max_embeddings:]
# Update centroid
if self.speaker_embeddings[self.current_speaker]:
self.speaker_centroids[self.current_speaker] = np.mean(
self.speaker_embeddings[self.current_speaker], axis=0
)
return self.current_speaker, similarity
def get_color_for_speaker(self, speaker_id):
"""Return color for speaker ID"""
if 0 <= speaker_id < len(SPEAKER_COLORS):
return SPEAKER_COLORS[speaker_id]
return "#FFFFFF"
def get_status_info(self):
"""Return status information"""
speaker_counts = [len(self.speaker_embeddings[i]) for i in range(self.max_speakers)]
return {
"current_speaker": self.current_speaker,
"speaker_counts": speaker_counts,
"active_speakers": len(self.active_speakers),
"max_speakers": self.max_speakers,
"last_similarity": self.last_similarity,
"threshold": self.change_threshold,
"segment_counter": self.segment_counter
}
class RealtimeSpeakerDiarization:
def __init__(self):
self.encoder = None
self.audio_processor = None
self.speaker_detector = None
self.recorder = None # RealtimeSTT recorder
self.sentence_queue = queue.Queue()
self.full_sentences = []
self.sentence_speakers = []
self.pending_sentences = []
self.current_conversation = ""
self.is_running = False
self.change_threshold = DEFAULT_CHANGE_THRESHOLD
self.max_speakers = DEFAULT_MAX_SPEAKERS
self.last_transcription = ""
self.transcription_lock = threading.Lock()
def initialize_models(self):
"""Initialize the speaker encoder model"""
try:
device_str = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {device_str}")
self.encoder = SpeechBrainEncoder(device=device_str)
success = self.encoder.load_model()
if success:
self.audio_processor = AudioProcessor(self.encoder)
self.speaker_detector = SpeakerChangeDetector(
embedding_dim=self.encoder.embedding_dim,
change_threshold=self.change_threshold,
max_speakers=self.max_speakers
)
# Initialize RealtimeSTT transcription model
self.recorder = AudioToTextRecorder(
spinner=False,
use_microphone=False,
model=FINAL_TRANSCRIPTION_MODEL,
language=TRANSCRIPTION_LANGUAGE,
silero_sensitivity=SILERO_SENSITIVITY,
webrtc_sensitivity=WEBRTC_SENSITIVITY,
post_speech_silence_duration=0.7,
min_length_of_recording=MIN_LENGTH_OF_RECORDING,
pre_recording_buffer_duration=PRE_RECORDING_BUFFER_DURATION,
enable_realtime_transcription=True,
realtime_processing_pause=0.2,
realtime_model_type=REALTIME_TRANSCRIPTION_MODEL,
on_realtime_transcription_update=self.live_text_detected,
on_recording_stop=self.process_final_text,
level=logging.WARNING,
# Gracefully handle audio buffer overflow if audio is fed faster than real time
handle_buffer_overflow=True
)
logger.info("Models initialized successfully!")
return True
else:
logger.error("Failed to load models")
return False
except Exception as e:
logger.error(f"Model initialization error: {e}")
return False
def live_text_detected(self, text):
"""Callback for real-time transcription updates"""
with self.transcription_lock:
self.last_transcription = text.strip()
def process_final_text(self, text):
"""Process final transcribed text with speaker embedding"""
text = text.strip()
if text:
try:
# Get audio data for this transcription
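# (assumes the recorder exposes the last utterance as 16-bit PCM bytes at 16 kHz;
# if the attribute is missing we fall back to the current speaker below)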
audio_bytes = getattr(self.recorder, 'last_transcription_bytes', None)
if audio_bytes:
self.sentence_queue.put((text, audio_bytes))
else:
# If no audio bytes, use current speaker
self.sentence_queue.put((text, None))
except Exception as e:
logger.error(f"Error processing final text: {e}")
def process_sentence_queue(self):
"""Process sentences in the queue for speaker detection"""
while self.is_running:
try:
text, audio_bytes = self.sentence_queue.get(timeout=1)
current_speaker = self.speaker_detector.current_speaker
if audio_bytes:
# Convert audio data and extract embedding
audio_int16 = np.frombuffer(audio_bytes, dtype=np.int16)
audio_float = audio_int16.astype(np.float32) / 32768.0
# Extract embedding
embedding = self.audio_processor.encoder.embed_utterance(audio_float)
if embedding is not None:
current_speaker, similarity = self.speaker_detector.add_embedding(embedding)
# Store sentence with speaker
with self.transcription_lock:
self.full_sentences.append((text, current_speaker))
self.update_conversation_display()
except queue.Empty:
continue
except Exception as e:
logger.error(f"Error processing sentence: {e}")
def update_conversation_display(self):
"""Update the conversation display"""
try:
sentences_with_style = []
for sentence_text, speaker_id in self.full_sentences:
color = self.speaker_detector.get_color_for_speaker(speaker_id)
speaker_name = f"Speaker {speaker_id + 1}"
sentences_with_style.append(
f'<span style="color:{color};">{speaker_name}:</span> '
f'{sentence_text}'
)
# Add current transcription if available
if self.last_transcription:
current_color = self.speaker_detector.get_color_for_speaker(self.speaker_detector.current_speaker)
current_speaker = f"Speaker {self.speaker_detector.current_speaker + 1}"
sentences_with_style.append(
f'<span style="color:{current_color};">{current_speaker}:</span> '
f'{self.last_transcription}...'
)
if sentences_with_style:
self.current_conversation = "
".join(sentences_with_style)
else:
self.current_conversation = "Waiting for speech input..."
except Exception as e:
logger.error(f"Error updating conversation display: {e}")
self.current_conversation = f"Error: {str(e)}"
def start_recording(self):
"""Start the recording and transcription process"""
if self.encoder is None:
return "Please initialize models first!"
try:
# Setup audio processor for speaker embeddings
self.is_running = True
# Start processing threads
self.sentence_thread = threading.Thread(target=self.process_sentence_queue, daemon=True)
self.sentence_thread.start()
# Start the RealtimeSTT recorder explicitly
if self.recorder:
# First make sure it's stopped if it was running
try:
if getattr(self.recorder, '_is_running', False):
self.recorder.stop()
except Exception:
pass
# Then start it fresh
self.recorder.start()
logger.info("RealtimeSTT recorder started")
return "Recording started successfully!"
except Exception as e:
logger.error(f"Error starting recording: {e}")
return f"Error starting recording: {e}"
def stop_recording(self):
"""Stop the recording process"""
self.is_running = False
# Stop the RealtimeSTT recorder
if self.recorder:
try:
self.recorder.stop()
logger.info("RealtimeSTT recorder stopped")
# Reset the last transcription
with self.transcription_lock:
self.last_transcription = ""
except Exception as e:
logger.error(f"Error stopping recorder: {e}")
return "Recording stopped!"
def clear_conversation(self):
"""Clear all conversation data"""
with self.transcription_lock:
self.full_sentences = []
self.last_transcription = ""
self.current_conversation = "Conversation cleared!"
if self.speaker_detector:
self.speaker_detector = SpeakerChangeDetector(
embedding_dim=self.encoder.embedding_dim,
change_threshold=self.change_threshold,
max_speakers=self.max_speakers
)
return "Conversation cleared!"
def update_settings(self, threshold, max_speakers):
"""Update speaker detection settings"""
self.change_threshold = threshold
self.max_speakers = max_speakers
if self.speaker_detector:
self.speaker_detector.set_change_threshold(threshold)
self.speaker_detector.set_max_speakers(max_speakers)
return f"Settings updated: Threshold={threshold:.2f}, Max Speakers={max_speakers}"
def get_formatted_conversation(self):
"""Get the formatted conversation with structured data"""
try:
# HTML-formatted conversation string (same content shown in the live display)
html_content = self.current_conversation
# Create structured data
structured_data = {
"html_content": html_content,
"sentences": [],
"current_transcript": self.last_transcription,
"current_speaker": self.speaker_detector.current_speaker if self.speaker_detector else 0
}
# Add sentence data
for sentence_text, speaker_id in self.full_sentences:
color = self.speaker_detector.get_color_for_speaker(speaker_id) if self.speaker_detector else "#FFFFFF"
structured_data["sentences"].append({
"text": sentence_text,
"speaker_id": speaker_id,
"speaker_name": f"Speaker {speaker_id + 1}",
"color": color
})
return html_content
except Exception as e:
logger.error(f"Error formatting conversation: {e}")
return f"Error formatting conversation: {str(e)}"
def get_status_info(self):
"""Get current status information as structured data"""
if not self.speaker_detector:
return {"error": "Speaker detector not initialized"}
try:
speaker_status = self.speaker_detector.get_status_info()
# Format speaker activity
speaker_activity = []
for i in range(speaker_status['max_speakers']):
color_name = SPEAKER_COLOR_NAMES[i] if i < len(SPEAKER_COLOR_NAMES) else f"Speaker {i+1}"
count = speaker_status['speaker_counts'][i]
active = count > 0
speaker_activity.append({
"id": i,
"name": f"Speaker {i+1}",
"color": SPEAKER_COLORS[i] if i < len(SPEAKER_COLORS) else "#FFFFFF",
"color_name": color_name,
"segment_count": count,
"active": active
})
# Create structured status object
status = {
"current_speaker": speaker_status['current_speaker'],
"current_speaker_name": f"Speaker {speaker_status['current_speaker'] + 1}",
"active_speakers_count": speaker_status['active_speakers'],
"max_speakers": speaker_status['max_speakers'],
"last_similarity": speaker_status['last_similarity'],
"change_threshold": speaker_status['threshold'],
"total_sentences": len(self.full_sentences),
"segments_processed": speaker_status['segment_counter'],
"speaker_activity": speaker_activity,
"timestamp": time.time()
}
# Also create a formatted text version for UI display
status_lines = [
f"**Current Speaker:** {status['current_speaker'] + 1}",
f"**Active Speakers:** {status['active_speakers_count']} of {status['max_speakers']}",
f"**Last Similarity:** {status['last_similarity']:.3f}",
f"**Change Threshold:** {status['change_threshold']:.2f}",
f"**Total Sentences:** {status['total_sentences']}",
f"**Segments Processed:** {status['segments_processed']}",
"",
"**Speaker Activity:**"
]
for speaker in status["speaker_activity"]:
active = "🟢" if speaker["active"] else "⚫"
status_lines.append(f"{active} Speaker {speaker['id']+1} ({speaker['color_name']}): {speaker['segment_count']} segments")
status["formatted_text"] = "\n".join(status_lines)
return status
except Exception as e:
error_msg = f"Error getting status: {e}"
logger.error(error_msg)
return {"error": error_msg, "formatted_text": error_msg}
def process_audio_chunk(self, audio_data, sample_rate=16000):
"""Process audio chunk from WebSocket input"""
if not self.is_running or self.audio_processor is None:
return {"status": "not_running"}
try:
# Convert bytes to numpy array if needed
if isinstance(audio_data, bytes):
audio_data = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0
# Ensure audio is float32
if isinstance(audio_data, np.ndarray):
if audio_data.dtype != np.float32:
audio_data = audio_data.astype(np.float32)
else:
audio_data = np.array(audio_data, dtype=np.float32)
# Ensure mono
if len(audio_data.shape) > 1:
audio_data = np.mean(audio_data, axis=1) if audio_data.shape[1] > 1 else audio_data.flatten()
# Check if audio has meaningful content (not just silence)
audio_level = np.abs(audio_data).mean()
is_silence = audio_level < 0.01 # Threshold for silence
# Skip processing for silent audio
if is_silence:
return {
"status": "silent",
"buffer_size": len(self.audio_processor.audio_buffer),
"speaker_id": self.speaker_detector.current_speaker,
"conversation_html": self.current_conversation
}
# Normalize if needed
if np.max(np.abs(audio_data)) > 1.0:
audio_data = audio_data / np.max(np.abs(audio_data))
# Add to audio processor buffer for speaker detection
self.audio_processor.add_audio_chunk(audio_data)
# Feed to RealtimeSTT for transcription
if self.recorder:
# Convert to int16 for RealtimeSTT
audio_int16 = (np.clip(audio_data, -1.0, 1.0) * 32767).astype(np.int16)  # clip to avoid int16 overflow at full scale
self.recorder.feed_audio(audio_int16.tobytes())
# Periodically extract embeddings for speaker detection
embedding = None
speaker_id = self.speaker_detector.current_speaker
similarity = 1.0
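# Attempt an embedding roughly every half second of buffered audio; note that this
# modulo check only fires if chunk sizes divide SAMPLE_RATE // 2 evenly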
if len(self.audio_processor.audio_buffer) >= SAMPLE_RATE and (len(self.audio_processor.audio_buffer) - SAMPLE_RATE) % (SAMPLE_RATE // 2) == 0:
embedding = self.audio_processor.extract_embedding_from_buffer()
if embedding is not None:
speaker_id, similarity = self.speaker_detector.add_embedding(embedding)
# Return processing result
return {
"status": "processed",
"buffer_size": len(self.audio_processor.audio_buffer),
"speaker_id": int(speaker_id) if not isinstance(speaker_id, int) else speaker_id,
"similarity": float(similarity) if embedding is not None and not isinstance(similarity, float) else similarity,
"conversation_html": self.current_conversation
}
except Exception as e:
logger.error(f"Error processing audio chunk: {e}")
return {"status": "error", "message": str(e)}
def resample_audio(self, audio_bytes, from_rate, to_rate):
"""Resample audio to target sample rate"""
try:
audio_np = np.frombuffer(audio_bytes, dtype=np.int16)
num_samples = len(audio_np)
num_target_samples = int(num_samples * to_rate / from_rate)
resampled = resample(audio_np, num_target_samples)
return resampled.astype(np.int16).tobytes()
except Exception as e:
logger.error(f"Error resampling audio: {e}")
return audio_bytes
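# Minimal end-to-end sketch (illustrative only, not part of the module's API):
# wires the pipeline together with audio pushed from an external source such as a
# WebSocket handler. Assumes 16 kHz mono int16 PCM chunks; the silent placeholder
# chunk and the polling interval below are stand-ins for a real audio feed.
if __name__ == "__main__":
    diarizer = RealtimeSpeakerDiarization()
    if diarizer.initialize_models():
        print(diarizer.start_recording())
        try:
            while True:
                # Replace with real audio, e.g. bytes received from a client socket
                chunk = np.zeros(BUFFER_SIZE, dtype=np.int16).tobytes()
                result = diarizer.process_audio_chunk(chunk, sample_rate=SAMPLE_RATE)
                print(result["status"], "-", diarizer.current_conversation)
                time.sleep(BUFFER_SIZE / SAMPLE_RATE)
        except KeyboardInterrupt:
            print(diarizer.stop_recording())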