| import torch |
| import numpy as np |
| import base64 |
| from torchaudio import functional as F |
| from transformers.pipelines.audio_utils import ffmpeg_read |
| from starlette.exceptions import HTTPException |
| import sys |
|
|
| import logging |
| logger = logging.getLogger(__name__) |
|
|
|
|
| def preprocess_inputs(inputs, sampling_rate): |
| inputs = ffmpeg_read(inputs, sampling_rate) |
|
|
| if sampling_rate != 16000: |
| inputs = F.resample( |
| torch.from_numpy(inputs), sampling_rate, 16000 |
| ).numpy() |
|
|
| if len(inputs.shape) != 1: |
| logger.error(f"Diarization pipeline expects single channel audio, received {inputs.shape}") |
| raise HTTPException( |
| status_code=400, |
| detail=f"Diarization pipeline expects single channel audio, received {inputs.shape}" |
| ) |
|
|
| |
| diarizer_inputs = torch.from_numpy(inputs).float() |
| diarizer_inputs = diarizer_inputs.unsqueeze(0) |
|
|
| return inputs, diarizer_inputs |
|
|
|
|
| def diarize_audio(diarizer_inputs, diarization_pipeline, parameters): |
| diarization = diarization_pipeline( |
| {"waveform": diarizer_inputs, "sample_rate": parameters.sampling_rate}, |
| num_speakers=parameters.num_speakers, |
| min_speakers=parameters.min_speakers, |
| max_speakers=parameters.max_speakers, |
| ) |
|
|
| segments = [] |
| for segment, track, label in diarization.itertracks(yield_label=True): |
| segments.append( |
| { |
| "segment": {"start": segment.start, "end": segment.end}, |
| "track": track, |
| "label": label, |
| } |
| ) |
|
|
| |
| new_segments = [] |
| prev_segment = cur_segment = segments[0] |
|
|
| for i in range(1, len(segments)): |
| cur_segment = segments[i] |
|
|
| if cur_segment["label"] != prev_segment["label"] and i < len(segments): |
| new_segments.append( |
| { |
| "segment": { |
| "start": prev_segment["segment"]["start"], |
| "end": cur_segment["segment"]["start"], |
| }, |
| "speaker": prev_segment["label"], |
| } |
| ) |
| prev_segment = segments[i] |
|
|
| new_segments.append( |
| { |
| "segment": { |
| "start": prev_segment["segment"]["start"], |
| "end": cur_segment["segment"]["end"], |
| }, |
| "speaker": prev_segment["label"], |
| } |
| ) |
|
|
| return new_segments, diarization |
|
|
|
|
| def extract_speaker_embeddings(diarization_pipeline, diarizer_inputs, diarization_result, sampling_rate=16000): |
| """ |
| Extract per-speaker embeddings from pyannote's internal embedding model. |
| |
| pyannote's SpeakerDiarization pipeline has an internal embedding model |
| (wespeaker-based, 512-dim) that we can access directly. We use the |
| diarization result to identify which audio regions belong to each speaker, |
| then extract embeddings for those regions. |
| """ |
| try: |
| |
| embedding_model = diarization_pipeline._embedding |
| device = next(embedding_model.parameters()).device |
| |
| |
| speaker_labels = set() |
| for segment, _, label in diarization_result.itertracks(yield_label=True): |
| speaker_labels.add(label) |
| |
| speaker_embeddings = {} |
| |
| for speaker in speaker_labels: |
| |
| speaker_segments = [] |
| total_seconds = 0.0 |
| for segment, _, label in diarization_result.itertracks(yield_label=True): |
| if label == speaker: |
| speaker_segments.append(segment) |
| total_seconds += segment.duration |
| |
| if total_seconds < 0.5: |
| logger.warning(f"Speaker {speaker} has only {total_seconds:.1f}s of audio, skipping embedding") |
| continue |
| |
| |
| segment_embeddings = [] |
| waveform = diarizer_inputs |
| |
| for seg in speaker_segments: |
| start_sample = int(seg.start * sampling_rate) |
| end_sample = int(seg.end * sampling_rate) |
| |
| if end_sample > waveform.shape[1]: |
| end_sample = waveform.shape[1] |
| if end_sample - start_sample < sampling_rate * 0.3: |
| continue |
| |
| chunk = waveform[:, start_sample:end_sample].to(device) |
| |
| with torch.no_grad(): |
| emb = embedding_model(chunk) |
| |
| |
| if emb.dim() > 1: |
| emb = emb.squeeze() |
| emb = emb / (torch.norm(emb) + 1e-8) |
| segment_embeddings.append(emb.cpu().numpy()) |
| |
| if len(segment_embeddings) == 0: |
| continue |
| |
| |
| centroid = np.mean(segment_embeddings, axis=0).astype(np.float32) |
| centroid = centroid / (np.linalg.norm(centroid) + 1e-8) |
| |
| |
| centroid_b64 = base64.b64encode(centroid.tobytes()).decode("utf-8") |
| |
| speaker_embeddings[speaker] = { |
| "embedding_b64": centroid_b64, |
| "embedding_dim": int(centroid.shape[0]), |
| "total_seconds": round(total_seconds, 2), |
| "num_segments": len(segment_embeddings), |
| } |
| |
| logger.info(f"Speaker {speaker}: {total_seconds:.1f}s, {len(segment_embeddings)} segments, dim={centroid.shape[0]}") |
| |
| return speaker_embeddings |
| |
| except Exception as e: |
| logger.error(f"Error extracting speaker embeddings: {str(e)}") |
| import traceback |
| logger.error(traceback.format_exc()) |
| return {} |
|
|
|
|
| def match_speakers(speaker_embeddings, known_speakers): |
| """ |
| Match diarized speakers against known speaker profiles using cosine similarity. |
| |
| known_speakers: list of dicts with {slug, name, centroid_b64, samples?} |
| speaker_embeddings: dict from extract_speaker_embeddings |
| |
| Returns dict mapping SPEAKER_XX -> {matched_slug, matched_name, confidence, score} |
| """ |
| if not known_speakers or not speaker_embeddings: |
| return {} |
| |
| |
| known_profiles = [] |
| for ks in known_speakers: |
| try: |
| centroid_bytes = base64.b64decode(ks["centroid_b64"]) |
| centroid = np.frombuffer(centroid_bytes, dtype=np.float32) |
| |
| |
| samples = [] |
| if ks.get("samples"): |
| for s in ks["samples"]: |
| if s.get("embedding_b64"): |
| s_bytes = base64.b64decode(s["embedding_b64"]) |
| samples.append(np.frombuffer(s_bytes, dtype=np.float32)) |
| |
| known_profiles.append({ |
| "slug": ks["slug"], |
| "name": ks["name"], |
| "centroid": centroid, |
| "samples": samples, |
| }) |
| except Exception as e: |
| logger.warning(f"Could not decode profile for {ks.get('slug', '?')}: {e}") |
| continue |
| |
| if not known_profiles: |
| return {} |
| |
| matches = {} |
| |
| for spk_label, spk_data in speaker_embeddings.items(): |
| try: |
| query_bytes = base64.b64decode(spk_data["embedding_b64"]) |
| query = np.frombuffer(query_bytes, dtype=np.float32) |
| except Exception: |
| continue |
| |
| best_score = -1.0 |
| best_profile = None |
| |
| for profile in known_profiles: |
| |
| centroid_score = float(np.dot(query, profile["centroid"]) / |
| (np.linalg.norm(query) * np.linalg.norm(profile["centroid"]) + 1e-8)) |
| |
| |
| best_sample_score = centroid_score |
| for sample in profile["samples"]: |
| s_score = float(np.dot(query, sample) / |
| (np.linalg.norm(query) * np.linalg.norm(sample) + 1e-8)) |
| best_sample_score = max(best_sample_score, s_score) |
| |
| |
| final_score = max(centroid_score, best_sample_score) |
| |
| if final_score > best_score: |
| best_score = final_score |
| best_profile = profile |
| |
| if best_profile is None: |
| continue |
| |
| |
| if best_score >= 0.55: |
| confidence = "HIGH" |
| elif best_score >= 0.35: |
| confidence = "MEDIUM" |
| else: |
| confidence = "LOW" |
| |
| matches[spk_label] = { |
| "matched_slug": best_profile["slug"], |
| "matched_name": best_profile["name"], |
| "confidence": confidence, |
| "score": round(best_score, 4), |
| } |
| |
| logger.info(f"Speaker {spk_label} -> {best_profile['name']} ({confidence}, {best_score:.4f})") |
| |
| return matches |
|
|
|
|
| def post_process_segments_and_transcripts(new_segments, transcript, group_by_speaker) -> list: |
| end_timestamps = np.array( |
| [chunk["timestamp"][-1] if chunk["timestamp"][-1] is not None else sys.float_info.max for chunk in transcript]) |
| segmented_preds = [] |
|
|
| for segment in new_segments: |
| end_time = segment["segment"]["end"] |
| upto_idx = np.argmin(np.abs(end_timestamps - end_time)) |
|
|
| if group_by_speaker: |
| segmented_preds.append( |
| { |
| "speaker": segment["speaker"], |
| "text": "".join( |
| [chunk["text"] for chunk in transcript[: upto_idx + 1]] |
| ), |
| "timestamp": ( |
| transcript[0]["timestamp"][0], |
| transcript[upto_idx]["timestamp"][1], |
| ), |
| } |
| ) |
| else: |
| for i in range(upto_idx + 1): |
| segmented_preds.append({"speaker": segment["speaker"], **transcript[i]}) |
|
|
| transcript = transcript[upto_idx + 1:] |
| end_timestamps = end_timestamps[upto_idx + 1:] |
|
|
| if len(end_timestamps) == 0: |
| break |
|
|
| return segmented_preds |
|
|
|
|
| def diarize(diarization_pipeline, file, parameters, asr_outputs): |
| """Original diarize function — backward compatible.""" |
| _, diarizer_inputs = preprocess_inputs(file, parameters.sampling_rate) |
|
|
| segments, _ = diarize_audio( |
| diarizer_inputs, |
| diarization_pipeline, |
| parameters |
| ) |
|
|
| return post_process_segments_and_transcripts( |
| segments, asr_outputs["chunks"], group_by_speaker=False |
| ) |
|
|
|
|
| def diarize_with_embeddings(diarization_pipeline, file, parameters, asr_outputs): |
| """ |
| Extended diarize that also extracts per-speaker embeddings and optionally |
| matches against known speaker profiles. |
| |
| Returns: (transcript, speaker_embeddings, speaker_matches) |
| """ |
| _, diarizer_inputs = preprocess_inputs(file, parameters.sampling_rate) |
|
|
| segments, diarization_result = diarize_audio( |
| diarizer_inputs, |
| diarization_pipeline, |
| parameters |
| ) |
|
|
| transcript = post_process_segments_and_transcripts( |
| segments, asr_outputs["chunks"], group_by_speaker=False |
| ) |
| |
| |
| speaker_embeddings = {} |
| if parameters.return_embeddings: |
| speaker_embeddings = extract_speaker_embeddings( |
| diarization_pipeline, diarizer_inputs, diarization_result, |
| sampling_rate=parameters.sampling_rate |
| ) |
| |
| |
| speaker_matches = {} |
| if parameters.known_speakers and speaker_embeddings: |
| speaker_matches = match_speakers(speaker_embeddings, parameters.known_speakers) |
| |
| return transcript, speaker_embeddings, speaker_matches |
|
|