import os

import gradio as gr
import numpy as np
import plotly.graph_objects as go
import torch
from dotenv import load_dotenv
from pyannote.audio import Pipeline
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

from utils.diarize_utils import match_segments

load_dotenv()
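
# Assumes a .env file next to this script supplies the Hugging Face token read below, e.g.:
#   HF_KEY=hf_xxxxxxxxxxxxxxxx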

# Check and set device
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Model and pipeline setup
model_id = "openai/whisper-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    torch_dtype=torch_dtype,
    device=device,
)

diarization_pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1", use_auth_token=os.getenv("HF_KEY")
)
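
# Note: pyannote/speaker-diarization-3.1 is gated on the Hugging Face Hub, so HF_KEY
# must be a token for an account that has accepted the model's user conditions.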

# returns diarization info such as segment start and end times, and speaker ids
def diarization_info(res):
    starts = []
    ends = []
    speakers = []
    for segment, _, speaker in res.itertracks(yield_label=True):
        starts.append(segment.start)
        ends.append(segment.end)
        speakers.append(speaker)
    return starts, ends, speakers
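
# Illustrative return values (hypothetical):
#   starts -> [0.0, 4.2, ...], ends -> [4.1, 9.8, ...], speakers -> ["SPEAKER_00", "SPEAKER_01", ...]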

# plot diarization results on a graph
def plot_diarization(starts, ends, speakers):
    fig = go.Figure()

    # Define a color map for different speakers; endpoint=False keeps hue 0 and
    # hue 360 (the same color) from both being assigned
    unique_speakers = sorted(set(speakers))
    colors = [
        f"hsl({h},80%,60%)"
        for h in np.linspace(0, 360, len(unique_speakers), endpoint=False)
    ]

    # Plot each segment as a horizontal line in its speaker's color
    for start, end, speaker in zip(starts, ends, speakers):
        speaker_id = unique_speakers.index(speaker)
        fig.add_trace(
            go.Scatter(
                x=[start, end],
                y=[speaker_id, speaker_id],
                mode="lines",
                line=dict(color=colors[speaker_id], width=15),
                showlegend=False,
            )
        )

    fig.update_layout(
        title="Speaker Diarization",
        xaxis=dict(title="Time (s)"),
        yaxis=dict(title="Speaker"),
        height=600,
        width=800,
    )
    return fig

def transcribe(sr, data):
    # scale int16 PCM from Gradio to float32 in [-1, 1]
    processed_data = np.array(data).astype(np.float32) / 32767.0

    # chunk-level timestamps from the ASR pipeline (the pipeline resamples internally)
    transcription_res = pipe(
        {"sampling_rate": sr, "raw": processed_data}, return_timestamps=True
    )

    # token-level timestamps via a direct generate call, printed for debugging only;
    # unlike the pipeline, this path assumes 16 kHz audio and does not resample
    inputs = processor(processed_data, sampling_rate=sr, return_tensors="pt")
    input_features = inputs.input_features.to(device, dtype=torch_dtype)
    generated = model.generate(
        input_features, return_token_timestamps=True, return_timestamps=True
    )
    transcription = processor.batch_decode(generated["sequences"], skip_special_tokens=False)
    print(transcription)
    return transcription_res
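
# The returned pipeline dict is shaped like (text values hypothetical):
#   {"text": "...", "chunks": [{"timestamp": (0.0, 5.4), "text": " Hello there."}, ...]}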

def format_string(timestamp):
    """
    Convert a timestamp string in (HH:)MM:SS format to float seconds. The hour
    field is optional; a zero hour field is prepended inside the function if absent.
    Args:
        timestamp (str):
            Timestamp in string format, either MM:SS or HH:MM:SS.
    Returns:
        seconds (float):
            Total seconds corresponding to the input timestamp.
    """
    split_time = timestamp.split(":")
    split_time = [float(sub_time) for sub_time in split_time]
    if len(split_time) == 2:
        split_time.insert(0, 0.0)
    seconds = split_time[0] * 3600 + split_time[1] * 60 + split_time[2]
    return seconds
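
# Examples: format_string("02:03") -> 123.0, format_string("01:02:03") -> 3723.0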

# Adapted from https://github.com/openai/whisper/blob/c09a7ae299c4c34c5839a76380ae407e7d785914/whisper/utils.py#L50
def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = "."):
    """
    Reformat a timestamp from a float of seconds to a string in (HH:)MM:SS.mmm format.
    The hour field is optional, and is included only when the number of hours > 0 or
    always_include_hours is set.
    Args:
        seconds (float):
            Total seconds corresponding to the input timestamp.
    Returns:
        timestamp (str):
            Timestamp in string format, either MM:SS.mmm or HH:MM:SS.mmm.
    """
    if seconds is not None:
        milliseconds = round(seconds * 1000.0)

        hours = milliseconds // 3_600_000
        milliseconds -= hours * 3_600_000

        minutes = milliseconds // 60_000
        milliseconds -= minutes * 60_000

        seconds = milliseconds // 1_000
        milliseconds -= seconds * 1_000

        hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
        return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
    else:
        # a malformed (None) timestamp; return it unchanged
        return seconds
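
# Examples: format_timestamp(3723.5) -> "01:02:03.500", format_timestamp(63.2) -> "01:03.200"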

def format_as_transcription(raw_segments):
    return "\n\n".join(
        [
            f"{chunk['speaker']} [{format_timestamp(chunk['timestamp'][0])} -> {format_timestamp(chunk['timestamp'][1])}] {chunk['text']}"
            for chunk in raw_segments
        ]
    )
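
# Produces blocks such as (speaker label and text hypothetical):
#   SPEAKER_00 [00:00.000 -> 00:04.100] Hello there.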

def merge_segments(segments):
    # Merge consecutive segments that share a speaker and touch exactly at the
    # boundary, then format the result as a readable transcription
    merged_segments = []
    current_segment = segments[0]
    for segment in segments[1:]:
        if (
            segment["speaker"] == current_segment["speaker"]
            and current_segment["timestamp"][1] == segment["timestamp"][0]
        ):
            current_segment["timestamp"] = (current_segment["timestamp"][0], segment["timestamp"][1])
            current_segment["text"] += " " + segment["text"]
        else:
            merged_segments.append(current_segment)
            current_segment = segment
    merged_segments.append(current_segment)
    return format_as_transcription(merged_segments)
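
# Example (hypothetical): two touching SPEAKER_00 chunks (0.0, 2.0) and (2.0, 5.0)
# merge into a single SPEAKER_00 segment spanning (0.0, 5.0) with their texts joined.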

def align_timestamps(segments, dia_seg_last_end):
    # Whisper chunk timestamps can restart from zero on each 30 s window; shift
    # them back into absolute time so they line up with the diarization segments
    aligned_segments = []
    previous_end = 0.0
    for segment in segments:
        start, end = segment["timestamp"]
        if start < previous_end:
            # guard against an open-ended (None) timestamp before doing arithmetic
            if end is not None:
                end = end - start + previous_end
            start = previous_end
        if end is None:
            # the final chunk can be open-ended; close it at the last diarized end
            end = dia_seg_last_end
        aligned_segments.append({"timestamp": (start, end), "text": segment["text"]})
        previous_end = end
    return aligned_segments
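
# Example: chunks timestamped (0.0, 29.5) then (0.0, 4.2) come out as
# (0.0, 29.5) and (29.5, 33.7) in absolute time.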

def transcribe_diarize(audio):
    sr, data = audio
    processed_data = np.array(data).astype(np.float32) / 32767.0
    waveform_tensor = torch.tensor(processed_data[np.newaxis, :])

    # results from the transcription pipeline
    transcription_res = transcribe(sr, data)

    # results from the diarization pipeline
    diarization_res = diarization_pipeline(
        {"waveform": waveform_tensor, "sample_rate": sr}
    )

    dia_seg, dia_label = [], []
    for segment, _, label in diarization_res.itertracks(yield_label=True):
        dia_seg.append([segment.start, segment.end])
        dia_label.append(label)
        dia_seg_last_end = segment.end

    assert (
        dia_seg
    ), "The result from the diarization pipeline is empty. No segments found from the diarization process."

    # shift ASR chunk timestamps into absolute time, then label each chunk with
    # the speaker of its best-matching diarization segment
    segmented_preds = align_timestamps(transcription_res["chunks"], dia_seg_last_end)
    dia_seg = np.array(dia_seg)
    asr_seg = np.array([[*chunk["timestamp"]] for chunk in segmented_preds])
    asr_labels = match_segments(dia_seg, dia_label, asr_seg, threshold=0.0, no_match_label="NO_SPEAKER")
    for i, label in enumerate(asr_labels):
        segmented_preds[i]["speaker"] = label

    # Get diarization information
    starts, ends, speakers = diarization_info(diarization_res)

    # merge same-speaker segments into the final diarized transcription
    diarized_transcription = merge_segments(segmented_preds)

    # Plot diarization
    diarization_plot = plot_diarization(starts, ends, speakers)
    return transcription_res, diarized_transcription, diarization_plot

# creating the gradio interface
demo = gr.Interface(
    fn=transcribe_diarize,
    inputs=gr.Audio(sources=["upload", "microphone"]),
    outputs=[
        gr.Textbox(lines=3, label="Text Transcription"),
        gr.Textbox(label="Diarized Transcription"),
        gr.Plot(label="Visualization"),
    ],
    examples=["sample1.wav"],
    title="Automatic Speech Recognition with Diarization 🗣️",
    description="Whisper Large V3 & Pyannote Speaker Diarization 3.1\nTranscribe your speech to text with Whisper and attribute speakers with pyannote. Get started by recording from your mic or uploading an audio file (.wav) 🎙️",
)

if __name__ == "__main__":
    demo.launch()