"""Gradio demo: Whisper large-v3 transcription combined with pyannote speaker diarization."""
import os

import gradio as gr
import numpy as np
import plotly.graph_objects as go
import torch
from dotenv import load_dotenv
from pyannote.audio import Pipeline
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

from utils.diarize_utils import match_segments

load_dotenv()

# Check and set device
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Half precision on GPU, full precision on CPU.
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Model and pipeline setup
model_id = "openai/whisper-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)

processor = AutoProcessor.from_pretrained(model_id)
processor.tokenizer.set_prefix_tokens(predict_timestamps=True)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    torch_dtype=torch_dtype,
    device=device,
)

# The Hugging Face access token is read from HF_KEY (optionally loaded from .env above).
diarization_pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1", use_auth_token=os.getenv("HF_KEY")
)


def _to_mono_float32(data):
    """Convert Gradio audio samples to a 1-D float32 array scaled by 1/32767.

    The 32767 divisor assumes int16 PCM samples. A 2-D input is treated as
    (samples, channels) and averaged down to a single channel so that both the
    ASR and diarization pipelines receive mono audio.
    NOTE(review): (samples, channels) is Gradio's numpy audio layout; confirm
    against the installed Gradio version.
    """
    samples = np.asarray(data, dtype=np.float32) / 32767.0
    if samples.ndim > 1:
        samples = samples.mean(axis=1)
    return samples


def diarization_info(res):
    """Return diarization segment start times, end times and speaker labels.

    Args:
        res: Diarization result from ``diarization_pipeline``.

    Returns:
        tuple[list, list, list]: ``(starts, ends, speakers)`` as parallel lists,
        in the order produced by ``res.itertracks(yield_label=True)``.
    """
    starts = []
    ends = []
    speakers = []
    for segment, _, speaker in res.itertracks(yield_label=True):
        starts.append(segment.start)
        ends.append(segment.end)
        speakers.append(speaker)
    return starts, ends, speakers


def plot_diarization(starts, ends, speakers):
    """Plot each diarization segment as a horizontal bar on its speaker's row.

    Args:
        starts (list[float]): Segment start times in seconds.
        ends (list[float]): Segment end times in seconds.
        speakers (list): Speaker label of each segment.

    Returns:
        plotly.graph_objects.Figure: One coloured row per speaker, labelled with
        the speaker's name.
    """
    fig = go.Figure()

    # First-appearance order is deterministic. Iterating a set of strings is not:
    # its order changes between runs because of hash randomisation.
    speaker_order = list(dict.fromkeys(speakers))
    speaker_index = {speaker: i for i, speaker in enumerate(speaker_order)}
    # endpoint=False: hue 360 is the same colour as hue 0, so including it would
    # give the first and last speakers identical colours.
    hues = np.linspace(0, 360, len(speaker_order), endpoint=False)
    colors = [f"hsl({h},80%,60%)" for h in hues]

    # Plot each segment with its speaker's color
    for start, end, speaker in zip(starts, ends, speakers):
        speaker_id = speaker_index[speaker]
        fig.add_trace(
            go.Scatter(
                x=[start, end],
                y=[speaker_id, speaker_id],
                mode="lines",
                line=dict(color=colors[speaker_id], width=15),
                showlegend=False,
            )
        )

    fig.update_layout(
        title="Speaker Diarization",
        xaxis=dict(title="Time"),
        yaxis=dict(
            title="Speaker",
            tickvals=list(range(len(speaker_order))),
            ticktext=speaker_order,
        ),
        height=600,
        width=800,
    )
    return fig


def transcribe(sr, data):
    """Run Whisper on raw audio and return the pipeline output with chunk timestamps.

    Args:
        sr (int): Sampling rate of ``data`` in Hz.
        data: Audio samples as delivered by ``gr.Audio`` (int16-scaled, mono or
            multi-channel).

    Returns:
        dict: The ASR pipeline output. This module reads its ``"text"`` entry
        (the full transcript) and its ``"chunks"`` entry (a list of
        ``{"timestamp": (start, end), "text": ...}`` items).
    """
    processed_data = _to_mono_float32(data)
    return pipe({"sampling_rate": sr, "raw": processed_data}, return_timestamps=True)


def format_string(timestamp):
    """
    Reformat a timestamp string from (HH:)MM:SS to float seconds.

    The hour column is optional; when it is absent, an hour value of 0 is
    prepended before converting.

    Args:
        timestamp (str): Timestamp in string format, either MM:SS or HH:MM:SS.

    Returns:
        seconds (float): Total seconds corresponding to the input timestamp.
    """
    split_time = timestamp.split(":")
    split_time = [float(sub_time) for sub_time in split_time]

    if len(split_time) == 2:
        split_time.insert(0, 0)

    seconds = split_time[0] * 3600 + split_time[1] * 60 + split_time[2]
    return seconds


# Adapted from https://github.com/openai/whisper/blob/c09a7ae299c4c34c5839a76380ae407e7d785914/whisper/utils.py#L50
def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = "."):
    """
    Reformat a timestamp from a float of seconds to a string (HH:)MM:SS.mmm.

    The hour column is included only if ``always_include_hours`` is set or the
    number of hours is > 0.

    Args:
        seconds (float | None): Total seconds corresponding to the timestamp.
        always_include_hours (bool): Always emit the ``HH:`` column.
        decimal_marker (str): Separator placed before the milliseconds.

    Returns:
        timestamp (str | None): Timestamp in the format MM:SS.mmm or HH:MM:SS.mmm.
        ``None`` input is returned unchanged.
    """
    if seconds is not None:
        milliseconds = round(seconds * 1000.0)

        hours = milliseconds // 3_600_000
        milliseconds -= hours * 3_600_000

        minutes = milliseconds // 60_000
        milliseconds -= minutes * 60_000

        seconds = milliseconds // 1_000
        milliseconds -= seconds * 1_000

        hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
        return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
    else:
        # we have a malformed timestamp so just return it as is
        return seconds


def format_as_transcription(raw_segments):
    """Render speaker-labelled segments as ``SPEAKER [start -> end] text`` paragraphs."""
    return "\n\n".join(
        [
            f"{chunk['speaker']} [{format_timestamp(chunk['timestamp'][0])} -> {format_timestamp(chunk['timestamp'][1])}] {chunk['text']}"
            for chunk in raw_segments
        ]
    )


def merge_segments(segments):
    """Merge adjacent same-speaker segments whose timestamps touch, then format them.

    Two consecutive segments are merged when they have the same ``"speaker"`` and
    the first one's end equals the second one's start. Input dicts are copied, not
    modified.

    Args:
        segments (list[dict]): Items with ``"speaker"``, ``"timestamp"`` and ``"text"``.

    Returns:
        str: The merged transcript (see ``format_as_transcription``), or ``""``
        when ``segments`` is empty.
    """
    if not segments:
        return ""

    merged_segments = []
    current_segment = dict(segments[0])
    for segment in segments[1:]:
        same_speaker = segment["speaker"] == current_segment["speaker"]
        contiguous = current_segment["timestamp"][1] == segment["timestamp"][0]
        if same_speaker and contiguous:
            current_segment["timestamp"] = (current_segment["timestamp"][0], segment["timestamp"][1])
            current_segment["text"] += " " + segment["text"]
        else:
            merged_segments.append(current_segment)
            current_segment = dict(segment)
    merged_segments.append(current_segment)

    return format_as_transcription(merged_segments)


def align_timestamps(segments, dia_seg_last_end):
    """Shift ASR chunk timestamps onto one monotonically increasing timeline.

    When a chunk starts before the previous (aligned) chunk ended, its timestamps
    are treated as having restarted.
    - If it starts at exactly 0.0, it is offset by the previous aligned end.
    - Otherwise it reuses the offset applied to the previous chunk.

    Args:
        segments (list[dict]): ASR chunks with ``"timestamp"`` ``(start, end)``
            and ``"text"``. ``end`` may be ``None``.
        dia_seg_last_end (float): Latest end time seen in the diarization output.
            It replaces a ``None`` end timestamp.

    Returns:
        list[dict]: New dicts with aligned ``"timestamp"`` and the original ``"text"``.
    """
    aligned_segments = []
    previous_end = 0.0
    previous_end_before_align = 0.0
    for segment in segments:
        start, end = segment["timestamp"]
        adjust = 0.0
        if start < previous_end:
            adjust = previous_end if start == 0.0 else previous_end - previous_end_before_align
        start += adjust
        if end is None:
            # Open-ended chunk: close it at the end of the diarized audio. The end
            # must never be placed before the chunk's own start. This check must
            # come before any arithmetic on `end`.
            end = max(start, dia_seg_last_end)
        else:
            end += adjust
        aligned_segments.append({"timestamp": (start, end), "text": segment["text"]})
        previous_end = end
        previous_end_before_align = end - adjust
    return aligned_segments


def transcribe_diarize(audio):
    """Transcribe and diarize one Gradio audio input.

    Args:
        audio (tuple[int, numpy.ndarray] | None): ``(sampling_rate, samples)``
            from ``gr.Audio``. ``None`` means no audio was provided.

    Returns:
        tuple: ``(plain_transcript, diarized_transcript, diarization_figure)``.
            - ``plain_transcript`` (str): the plain text transcript.
            - ``diarized_transcript`` (str): the speaker-labelled transcript.
            - ``diarization_figure``: a plotly figure of the speaker turns.

    Raises:
        gr.Error: If no audio is given, the diarization pipeline failed to load,
            or diarization found no segments.
    """
    if audio is None:
        raise gr.Error("No audio provided. Record from your mic or upload an audio file first.")
    if diarization_pipeline is None:
        # NOTE(review): Pipeline.from_pretrained appears to return None when the
        # model cannot be fetched (e.g. missing/invalid HF_KEY) — confirm.
        raise gr.Error("The diarization pipeline is not loaded. Check the HF_KEY access token.")

    sr, data = audio
    processed_data = _to_mono_float32(data)
    # Shape (1, num_samples): a single channel for pyannote.
    waveform_tensor = torch.tensor(processed_data[np.newaxis, :])

    transcription_res = transcribe(sr, data)

    # results from the diarization pipeline
    diarization_res = diarization_pipeline({"waveform": waveform_tensor, "sample_rate": sr})
    starts, ends, speakers = diarization_info(diarization_res)
    if not starts:
        raise gr.Error(
            "The result from the diarization pipeline is empty. "
            "No segments found from the diarization process."
        )
    dia_seg = np.column_stack([starts, ends])
    # Speaker turns can overlap, so the last track yielded need not be the one
    # that ends last; use the maximum end time.
    dia_seg_last_end = max(ends)

    diarization_plot = plot_diarization(starts, ends, speakers)

    segmented_preds = align_timestamps(transcription_res["chunks"], dia_seg_last_end)
    if not segmented_preds:
        # Nothing was transcribed: there is nothing to label with speakers.
        return transcription_res["text"], "", diarization_plot

    asr_seg = np.array([list(chunk["timestamp"]) for chunk in segmented_preds])
    asr_labels = match_segments(
        dia_seg, speakers, asr_seg, threshold=0.0, no_match_label="NO_SPEAKER"
    )
    for chunk, label in zip(segmented_preds, asr_labels):
        chunk["speaker"] = label

    diarized_transcription = merge_segments(segmented_preds)

    return transcription_res["text"], diarized_transcription, diarization_plot


# creating the gradio interface
demo = gr.Interface(
    fn=transcribe_diarize,
    inputs=gr.Audio(sources=["upload", "microphone"]),
    outputs=[
        gr.Textbox(lines=3, label="Text Transcription"),
        gr.Textbox(label="Diarized Transcription"),
        gr.Plot(label="Visualization"),
    ],
    examples=["sample1.wav"],
    title="Automatic Speech Recognition with Diarization 🗣️",
    description=(
        "Whisper V3 Large & Pyannote Speaker Diarization V3.1 \n"
        "Transcribe your speech to text with Whisper large-v3 and diarization with pyannote. "
        "Get started by recording from your mic or uploading an audio file (.wav) 🎙️"
    ),
)

if __name__ == "__main__":
    demo.launch()