import gradio as gr
from pyannote.audio import Pipeline
import torch
import whisper
from huggingface_hub import login
import os
import traceback

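# Hugging Face authentication and pyannote speaker-diarization pipeline.
# HF_TOKEN must be set in the environment (e.g. as a secret in the Space settings),
# and the account must have accepted the access terms for pyannote/speaker-diarization-3.1.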
|
hf_token = os.environ.get("HF_TOKEN")

if not hf_token:
    print("WARNING: HF_TOKEN environment variable not found. Please set it in the Space settings.")
    diarization_pipeline = None
else:
    try:
        login(token=hf_token)
        print("Successfully logged in to Hugging Face")

        print("Loading pyannote/speaker-diarization-3.1 pipeline...")
        diarization_pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1",
            use_auth_token=hf_token
        )
        print("Diarization pipeline loaded successfully!")

        if torch.cuda.is_available():
            print("GPU detected, moving pipeline to GPU")
            diarization_pipeline.to(torch.device("cuda"))
        else:
            print("No GPU detected, using CPU")

    except Exception as e:
        print(f"Error loading diarization pipeline: {e}")
        print(f"Error type: {type(e).__name__}")
        print("Traceback:")
        traceback.print_exc()
        diarization_pipeline = None

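# Load the Whisper speech-recognition model. "base" is a small multilingual checkpoint;
# a larger checkpoint (e.g. "small" or "medium") could be swapped in for better accuracy
# at the cost of speed and memory.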
|
try:
    print("Loading Whisper model...")
    whisper_model = whisper.load_model("base")
    print("Whisper model loaded successfully!")
except Exception as e:
    print(f"Error loading Whisper model: {e}")
    whisper_model = None

|
def transcribe_with_diarization(audio_file):
    """Process an audio file for both diarization and transcription."""
    if diarization_pipeline is None:
        return "❌ Diarization pipeline not loaded. Please ensure HF_TOKEN is set and you have access to pyannote/speaker-diarization-3.1."

    if whisper_model is None:
        return "❌ Whisper model not loaded."

    if audio_file is None:
        return "Please upload an audio file."
|

    try:
        print(f"Processing audio file: {audio_file}")

        # Transcribe with Whisper. The language is fixed to Portuguese ("pt");
        # drop the language argument to let Whisper auto-detect it.
        print("Transcribing audio with Whisper...")
        transcription_result = whisper_model.transcribe(audio_file, language="pt")
        segments = transcription_result["segments"]
        print(f"Transcription complete. Found {len(segments)} segments")

        # Run speaker diarization with pyannote.
        print("Performing speaker diarization...")
        diarization = diarization_pipeline(audio_file)
        print("Diarization complete")
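
        # Attribute each Whisper segment to the first diarization turn that contains the
        # segment's start or end time; segments that overlap no turn are labelled "Unknown".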
|
        results = []

        for segment in segments:
            start_time = segment['start']
            end_time = segment['end']
            text = segment['text'].strip()

            speaker = None
            for turn, _, label in diarization.itertracks(yield_label=True):
                if turn.start <= start_time <= turn.end or turn.start <= end_time <= turn.end:
                    speaker = label
                    break

            if speaker:
                results.append(f"[{speaker}] ({start_time:.1f}s - {end_time:.1f}s): {text}")
            else:
                results.append(f"[Unknown] ({start_time:.1f}s - {end_time:.1f}s): {text}")

        if not results:
            return "No transcription available."

        # Count the distinct speakers found by the diarization pipeline.
        speakers = set()
        for turn, _, speaker in diarization.itertracks(yield_label=True):
            speakers.add(speaker)

        summary = f"Found {len(speakers)} speakers in the conversation.\n\n"
        return summary + "\n".join(results)
|

    except Exception as e:
        error_msg = f"Error processing audio: {str(e)}"
        print(error_msg)
        traceback.print_exc()
        return error_msg

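# Gradio interface: one audio-file input, one text box for the speaker-labelled transcript.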
|
demo = gr.Interface(
    fn=transcribe_with_diarization,
    inputs=gr.Audio(type="filepath", label="Upload Audio File"),
    outputs=gr.Textbox(label="Transcription with Speaker Identification", lines=20),
    title="Speaker Diarization + Transcription",
    description="Upload an audio file to identify different speakers and transcribe what they said. Uses pyannote for speaker identification and Whisper for transcription.",
    examples=[],
    cache_examples=False
)

|
if __name__ == "__main__":
    demo.launch()