Add Whisper transcription to speaker diarization (6df750f)
import gradio as gr
from pyannote.audio import Pipeline
import torch
import whisper
from huggingface_hub import login
import os
import traceback
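# Typical dependencies for this Space (usually pinned in requirements.txt):
#   gradio, torch, pyannote.audio (>=3.1 for this pipeline), openai-whisper, huggingface_hub
# openai-whisper also needs the ffmpeg binary available on the system.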
# Login to Hugging Face if token is available
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    print("WARNING: HF_TOKEN environment variable not found. Please set it in the Space settings.")
    diarization_pipeline = None
else:
    try:
        login(token=hf_token)
        print("Successfully logged in to Hugging Face")

        # Initialize the diarization pipeline.
        # Note: pyannote/speaker-diarization-3.1 is a gated model; the account behind
        # HF_TOKEN must have accepted the model's user conditions on the Hugging Face Hub.
        print("Loading pyannote/speaker-diarization-3.1 pipeline...")
        diarization_pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1",
            use_auth_token=hf_token
        )
        print("Diarization pipeline loaded successfully!")

        # Send pipeline to GPU if available
        if torch.cuda.is_available():
            print("GPU detected, moving pipeline to GPU")
            diarization_pipeline.to(torch.device("cuda"))
        else:
            print("No GPU detected, using CPU")
    except Exception as e:
        print(f"Error loading diarization pipeline: {e}")
        print(f"Error type: {type(e).__name__}")
        print("Traceback:")
        traceback.print_exc()
        diarization_pipeline = None
# Load Whisper model
try:
    print("Loading Whisper model...")
    # "base" is a small, fast checkpoint; larger ones ("small", "medium", "large")
    # trade speed and memory for accuracy.
    whisper_model = whisper.load_model("base")
    print("Whisper model loaded successfully!")
except Exception as e:
    print(f"Error loading Whisper model: {e}")
    whisper_model = None
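# Both models are loaded once at startup (module import time), so individual
# requests do not pay the model-loading cost.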
def transcribe_with_diarization(audio_file):
"""Process audio file for both diarization and transcription"""
if diarization_pipeline is None:
return "❌ Diarization pipeline not loaded. Please ensure HF_TOKEN is set and you have access to pyannote/speaker-diarization-3.1."
if whisper_model is None:
return "❌ Whisper model not loaded."
if audio_file is None:
return "Please upload an audio file."
    try:
        print(f"Processing audio file: {audio_file}")

        # Step 1: Transcribe with Whisper
        # Note: the language is hardcoded to Portuguese ("pt"); drop the argument
        # to let Whisper auto-detect the language.
        print("Transcribing audio with Whisper...")
        transcription_result = whisper_model.transcribe(audio_file, language="pt")
        segments = transcription_result["segments"]
        print(f"Transcription complete. Found {len(segments)} segments")

        # Step 2: Diarize with pyannote
        print("Performing speaker diarization...")
        diarization = diarization_pipeline(audio_file)
        print("Diarization complete")
        # Step 3: Match transcription segments with speaker labels
        results = []
        for segment in segments:
            start_time = segment['start']
            end_time = segment['end']
            text = segment['text'].strip()

            # Find the speaker whose turn overlaps this segment (take the first
            # overlapping turn; segments spanning several speakers keep the
            # earliest matching label).
            speaker = None
            for turn, _, label in diarization.itertracks(yield_label=True):
                # Two intervals overlap if each starts before the other ends;
                # this also covers turns fully contained inside the segment.
                if turn.start < end_time and start_time < turn.end:
                    speaker = label
                    break

            if speaker:
                results.append(f"[{speaker}] ({start_time:.1f}s - {end_time:.1f}s): {text}")
            else:
                results.append(f"[Unknown] ({start_time:.1f}s - {end_time:.1f}s): {text}")
        if not results:
            return "No transcription available."

        # Add summary
        speakers = set()
        for turn, _, speaker in diarization.itertracks(yield_label=True):
            speakers.add(speaker)
        summary = f"Found {len(speakers)} speakers in the conversation.\n\n"

        return summary + "\n".join(results)
    except Exception as e:
        error_msg = f"Error processing audio: {str(e)}"
        print(error_msg)
        traceback.print_exc()
        return error_msg
# Create Gradio interface
demo = gr.Interface(
    fn=transcribe_with_diarization,
    inputs=gr.Audio(type="filepath", label="Upload Audio File"),
    outputs=gr.Textbox(label="Transcription with Speaker Identification", lines=20),
    title="Speaker Diarization + Transcription",
    description="Upload an audio file to identify different speakers and transcribe what they said. Uses pyannote for speaker identification and Whisper for transcription.",
    examples=[],
    cache_examples=False
)
if __name__ == "__main__":
    demo.launch()
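# To try this locally (a sketch, assuming this file is saved as app.py, ffmpeg is
# installed, and your account has been granted access to the gated pyannote model):
#   export HF_TOKEN=hf_xxx   # your Hugging Face access token
#   python app.py            # then open the printed local Gradio URL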