import streamlit as st import torch import torchaudio from pyannote.audio import Pipeline from pyannote.audio.pipelines.utils.hook import ProgressHook import tempfile import os import matplotlib.pyplot as plt from pyannote.core import notebook from huggingface_hub import HfApi, snapshot_download, hf_hub_download from huggingface_hub.errors import LocalEntryNotFoundError, HfHubHTTPError import requests import pyannote.audio import sys import traceback from speechbrain.pretrained import EncoderClassifier from pydub import AudioSegment import numpy as np # Set page configuration st.set_page_config(page_title="Optimized Speaker Diarization App", layout="wide") st.title("Optimized Speaker Diarization App") # Fetch HF_TOKEN from environment variable HF_TOKEN = os.getenv("HF_TOKEN") if not HF_TOKEN: st.error("HF_TOKEN not found in environment variables. Please set it in your Hugging Face Space secrets.") st.stop() class ProgressHook: def __init__(self, status, progress_bar): self.status = status self.progress_bar = progress_bar self.total = 0 self.completed = 0 self.current_stage = "" def __call__(self, *args, **kwargs): if len(args) == 2 and isinstance(args[0], str): # Handle the case where it's called with (stage, data) self.current_stage = args[0] self.status.update(label=f"Processing: {self.current_stage}", state="running") elif 'completed' in kwargs and 'total' in kwargs: self.completed = kwargs['completed'] self.total = kwargs['total'] self._update_progress() elif len(args) == 2 and all(isinstance(arg, (int, float)) for arg in args): self.completed, self.total = args self._update_progress() def _update_progress(self): if self.total > 0: progress_percentage = min(self.completed / self.total, 1.0) self.status.update(label=f"Processing: {self.current_stage} - {progress_percentage:.1%} complete", state="running") self.progress_bar.progress(progress_percentage) def preprocess_audio(tmp_path): # Load the audio file using pydub audio = AudioSegment.from_file(tmp_path) # Convert to mono if stereo if audio.channels == 2: audio = audio.set_channels(1) # Resample to 16kHz if necessary if audio.frame_rate != 16000: audio = audio.set_frame_rate(16000) st.info("Resampled audio to 16 kHz") # Convert to numpy array samples = np.array(audio.get_array_of_samples()) # Convert to torch tensor waveform = torch.FloatTensor(samples).unsqueeze(0) / 32768.0 # Normalize to [-1, 1] # Determine the segment size (10 seconds at 16 kHz) segment_size = 160000 # Calculate the number of segments num_segments = (waveform.shape[1] + segment_size - 1) // segment_size # Calculate the expected total length expected_length = num_segments * segment_size # Calculate the padding length padding_length = expected_length - waveform.shape[1] if padding_length > 0: # Pad the waveform with zeros pad = torch.zeros((waveform.shape[0], padding_length)) waveform = torch.cat((waveform, pad), dim=1) st.info(f"Padded waveform with {padding_length} zeros") else: st.info("No padding needed") # Save the processed waveform to a temporary WAV file with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as processed_file: processed_path = processed_file.name torchaudio.save(processed_path, waveform, 16000) st.info("Saved processed waveform to temporary WAV file") return waveform, 16000, processed_path def check_versions(): st.info("Checking package versions...") pyannote_version = pyannote.audio.__version__ torch_version = torch.__version__ st.write(f"Pyannote Audio version: {pyannote_version}") st.write(f"PyTorch version: {torch_version}") if pyannote_version < "3.1.0": st.warning("Your pyannote.audio version might be outdated. Consider upgrading to 3.1.0 or later.") if torch_version < "2.0.0": st.warning("Your PyTorch version might be outdated. Consider upgrading to 2.0.0 or later.") check_versions() def verify_token(token): api = HfApi() try: user_info = api.whoami(token=token) st.success(f"Token verified. Logged in as: {user_info['name']}") return True except Exception as e: st.error(f"Token verification failed: {str(e)}") return False def check_hf_api(): st.info("Checking Hugging Face API...") api_url = "https://huggingface.co/api/models/pyannote/speaker-diarization-3.1" headers = {"Authorization": f"Bearer {HF_TOKEN}"} try: response = requests.get(api_url, headers=headers) response.raise_for_status() st.success("Successfully connected to Hugging Face API") with st.expander("API Response"): st.json(response.json()) except requests.exceptions.RequestException as e: st.error(f"Error connecting to Hugging Face API: {str(e)}") if response.status_code == 403: st.error("Access denied. Please check your token permissions.") st.info("Ensure your token has permission to access gated repositories.") st.code(response.text) def verify_model_files(): st.info("Verifying model files...") required_files = [ "config.yaml", "pytorch_model.bin", "pyannote_serialized_object.bin" ] for file in required_files: try: path = hf_hub_download("pyannote/speaker-diarization-3.1", filename=file, use_auth_token=HF_TOKEN) if os.path.exists(path): st.success(f"File {file} found at {path}") else: st.error(f"File {file} not found") except Exception as e: st.error(f"Error downloading {file}: {str(e)}") @st.cache_resource def load_pipeline(): try: st.info("Attempting to load the pipeline...") pipeline = Pipeline.from_pretrained( "pyannote/speaker-diarization-3.1", use_auth_token=HF_TOKEN ) st.success("Pipeline created successfully") if torch.cuda.is_available(): st.info("Moving pipeline to GPU...") pipeline.to(torch.device("cuda")) st.success("Pipeline moved to GPU") return pipeline except Exception as e: st.error(f"Error loading pipeline: {str(e)}") st.error("Error details:") st.code(traceback.format_exc()) raise e @st.cache_resource def load_speechbrain_model(): st.info("Loading SpeechBrain model...") classifier = EncoderClassifier.from_hparams(source="speechbrain/spkrec-ecapa-voxceleb") st.success("SpeechBrain model loaded successfully") return classifier # Sidebar with st.sidebar: st.header("Settings") show_advanced = st.toggle("Show Advanced Options") if show_advanced: num_speakers = st.number_input("Number of speakers (0 for auto)", min_value=0, value=0) min_speakers = st.number_input("Minimum number of speakers", min_value=1, value=1) max_speakers = st.number_input("Maximum number of speakers", min_value=1, value=5) # Main content tab1, tab2, tab3 = st.tabs(["Upload & Process", "Results", "Visualization"]) with tab1: uploaded_file = st.file_uploader("Choose an audio file", type=['wav', 'mp3', 'flac']) if uploaded_file is not None: # Save uploaded file temporarily with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_file.name)[1]) as tmp_file: tmp_file.write(uploaded_file.getvalue()) tmp_path = tmp_file.name try: if verify_token(HF_TOKEN): check_hf_api() verify_model_files() pipeline = load_pipeline() speechbrain_model = load_speechbrain_model() else: st.stop() # Preprocess the audio file waveform, sample_rate, processed_path = preprocess_audio(tmp_path) with st.status("Processing audio...", expanded=True) as status: progress_bar = st.progress(0) progress_hook = ProgressHook(status, progress_bar) # Run the pipeline on the processed audio file diarization_args = { "file": processed_path, "hook": progress_hook } if show_advanced: if num_speakers > 0: diarization_args["num_speakers"] = num_speakers else: diarization_args["min_speakers"] = min_speakers diarization_args["max_speakers"] = max_speakers diarization = pipeline(**diarization_args) status.update(label="Diarization complete!", state="complete") # Generate RTTM content rttm_content = "" for turn, _, speaker in diarization.itertracks(yield_label=True): rttm_line = f"SPEAKER {os.path.basename(tmp_path)} 1 {turn.start:.3f} {turn.duration:.3f} {speaker} \n" rttm_content += rttm_line # Use SpeechBrain for speaker embedding (optional) embeddings = speechbrain_model.encode_batch(waveform) st.success("Speaker embeddings generated successfully") except Exception as e: st.error(f"An error occurred: {str(e)}") st.error("Error details:") st.code(traceback.format_exc()) finally: # Clean up the temporary files os.unlink(tmp_path) if 'processed_path' in locals(): os.unlink(processed_path) with tab2: if 'diarization' in locals(): st.subheader("Diarization Results") st.metric("Number of speakers detected", len(diarization.labels())) with st.expander("RTTM Output"): st.text_area("RTTM Content", rttm_content, height=300) st.download_button( label="Download RTTM file", data=rttm_content, file_name="diarization.rttm", mime="text/plain" ) with tab3: if 'diarization' in locals(): if st.button("Visualize Diarization"): fig, ax = plt.subplots(figsize=(10, 2)) notebook.plot_diarization(diarization, ax=ax) plt.tight_layout() st.pyplot(fig) # Debug Information with st.expander("Debug Information"): st.write(f"Working directory: {os.getcwd()}") st.write(f"Files in working directory: {os.listdir()}") st.write(f"Python version: {sys.version.split()[0]}") st.write(f"PyTorch version: {torch.__version__}") st.write(f"Pyannote Audio version: {pyannote.audio.__version__}") st.write(f"CUDA available: {torch.cuda.is_available()}") st.write(f"Device: {'CUDA' if torch.cuda.is_available() else 'CPU'}") # Token Permissions Instructions with st.expander("Token Permissions"): st.markdown(""" If you're encountering access issues, please ensure your Hugging Face token has the following permissions: 1. Go to [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens) 2. Find your token or create a new one 3. Ensure "Read" access is granted 4. Check the box for "Access to gated repositories" 5. Save the changes and try again """) # Clear Cache Button if st.button("Clear Cache"): import shutil cache_dir = "./model_cache" if os.path.exists(cache_dir): shutil.rmtree(cache_dir) st.success("Cache cleared successfully.") else: st.info("No cache directory found.")