"""Gradio demo: Whisper (distil-small.en) transcription plus pyannote speaker diarization."""

import os

import gradio as gr
import numpy as np
import plotly.graph_objects as go
import torch
from dotenv import load_dotenv
from pyannote.audio import Pipeline
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

load_dotenv()

# Check and set device
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Model and pipeline setup
model_id = "distil-whisper/distil-small.en"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    torch_dtype=torch_dtype,
    device=device,
)

# HF_KEY is read from the environment (optionally populated from .env above).
diarization_pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1", use_auth_token=os.getenv("HF_KEY")
)


def _to_mono_float32(data):
    """Return *data* as a 1-D float32 array.

    Integer samples are scaled by their dtype's maximum (32767 for int16),
    so int16 input lands in roughly [-1, 1]. Float samples are only cast,
    not rescaled.

    Multi-channel input is averaged into one channel. Time is assumed to be
    axis 0, i.e. shape ``(n_samples, n_channels)``; this matches how the
    segments are sliced in ``transcribe_diarize``.
    """
    samples = np.asarray(data)
    if np.issubdtype(samples.dtype, np.signedinteger):
        samples = samples.astype(np.float32) / np.iinfo(samples.dtype).max
    else:
        samples = samples.astype(np.float32)
    if samples.ndim > 1:
        # Downmix: both Whisper and the (channel, time) waveform built below
        # need a single channel.
        samples = samples.mean(axis=1, dtype=np.float32)
    return samples


def diarization_info(res):
    """Flatten a pyannote diarization result into parallel lists.

    Returns ``(starts, ends, speakers)``: segment start times, segment end
    times, and the speaker label of each track, in ``itertracks`` order.
    """
    starts = []
    ends = []
    speakers = []
    for segment, _, speaker in res.itertracks(yield_label=True):
        starts.append(segment.start)
        ends.append(segment.end)
        speakers.append(speaker)
    return starts, ends, speakers


def plot_diarization(starts, ends, speakers):
    """Draw one horizontal bar per segment, with one row and colour per speaker."""
    fig = go.Figure()

    # Sorted labels give every speaker the same row/colour on every run
    # (set() iteration order for str is hash-randomized).
    unique_speakers = sorted(set(speakers))
    speaker_row = {speaker: row for row, speaker in enumerate(unique_speakers)}

    # endpoint=False: hue 360 is the same colour as hue 0, so including it
    # would paint the first and last speakers identically.
    hues = np.linspace(0, 360, len(unique_speakers), endpoint=False)
    colors = [f"hsl({h},80%,60%)" for h in hues]

    # Plot each segment with its speaker's color
    for start, end, speaker in zip(starts, ends, speakers):
        speaker_id = speaker_row[speaker]
        fig.add_trace(
            go.Scatter(
                x=[start, end],
                y=[speaker_id, speaker_id],
                mode="lines",
                line=dict(color=colors[speaker_id], width=15),
                showlegend=False,
            )
        )

    fig.update_layout(
        title="Speaker Diarization",
        xaxis=dict(title="Time"),
        # Label rows with the speaker names instead of bare integers.
        yaxis=dict(
            title="Speaker",
            tickmode="array",
            tickvals=list(range(len(unique_speakers))),
            ticktext=unique_speakers,
        ),
        height=600,
        width=800,
    )
    return fig


def transcribe(sr, data):
    """Transcribe *data* sampled at *sr* Hz with the Whisper pipeline.

    Returns the transcribed text. An empty clip returns ``""`` instead of
    being sent to the model.
    """
    processed_data = _to_mono_float32(data)
    if processed_data.size == 0:
        return ""
    # results from the pipeline
    return pipe({"sampling_rate": sr, "raw": processed_data})["text"]


def transcribe_diarize(audio):
    """Gradio callback: full transcript, per-speaker transcript, and plot.

    *audio* is the ``(sample_rate, samples)`` tuple produced by ``gr.Audio``.
    It is ``None`` when the user submits without recording or uploading.
    """
    if audio is None:
        raise gr.Error("Please record or upload an audio clip first.")
    sr, data = audio
    mono = _to_mono_float32(data)

    transcription_res = transcribe(sr, mono)

    # pyannote takes an in-memory waveform shaped (channel, time).
    waveform_tensor = torch.from_numpy(mono).unsqueeze(0)
    diarization_res = diarization_pipeline(
        {"waveform": waveform_tensor, "sample_rate": sr}
    )
    starts, ends, speakers = diarization_info(diarization_res)

    # Transcribe each speaker segment separately. Times are multiplied by sr
    # to get sample indices, i.e. they are treated as seconds.
    lines = []
    for start_time, end_time, speaker_id in zip(starts, ends, speakers):
        segment = mono[int(start_time * sr) : int(end_time * sr)]
        lines.append(
            f"{speaker_id} {round(start_time, 2)}:{round(end_time, 2)} \t "
            f"{transcribe(sr, segment)}\n"
        )
    diarized_transcription = "".join(lines)

    diarization_plot = plot_diarization(starts, ends, speakers)
    return transcription_res, diarized_transcription, diarization_plot


# creating the gradio interface
demo = gr.Interface(
    fn=transcribe_diarize,
    inputs=gr.Audio(sources=["upload", "microphone"]),
    outputs=[
        gr.Textbox(lines=3, label="Text Transcription"),
        gr.Textbox(label="Diarized Transcription"),
        gr.Plot(label="Visualization"),
    ],
    examples=["sample1.wav"],
    title="Automatic Speech Recognition with Diarization 🗣️",
    description="Transcribe your speech to text with distilled whisper and diarization with pyannote. Get started by recording from your mic or uploading an audio file (.wav) 🎙️",
)

if __name__ == "__main__":
    demo.launch()