# romsyflux
# Revert test removing waveform_tensor and using data as input
# ba7591a
import gradio as gr
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
import numpy as np
from pyannote.audio import Pipeline
import os
from dotenv import load_dotenv
import plotly.graph_objects as go
from utils.diarize_utils import match_segments
load_dotenv()
# Select the inference device: the first CUDA GPU when one is available, otherwise the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Use half precision (float16) when CUDA is available, full precision (float32) otherwise.
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
# Speech-recognition model: load the checkpoint named by `model_id` and move it to `device`.
model_id = "openai/whisper-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
# Configure the tokenizer's prefix tokens with timestamp prediction enabled.
processor.tokenizer.set_prefix_tokens(predict_timestamps=True)
# Module-level ASR pipeline built from the model, tokenizer and feature extractor loaded above.
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
torch_dtype=torch_dtype,
device=device,
)
# Speaker-diarization pipeline; the access token is read from the HF_KEY environment variable.
# NOTE(review): unlike `model`, this pipeline is never moved to `device` in this section —
# confirm whether running it on the default device is intentional.
diarization_pipeline = Pipeline.from_pretrained(
"pyannote/speaker-diarization-3.1", use_auth_token=os.getenv("HF_KEY")
)
# Extract segment start times, end times and speaker ids from a diarization result.
def diarization_info(res):
    """Collect per-track timing and speaker labels from a diarization result.

    Args:
        res: Object exposing ``itertracks(yield_label=True)``, which yields
            ``(segment, track, label)`` triples where ``segment`` has
            ``start`` and ``end`` attributes.

    Returns:
        tuple: Three parallel lists ``(starts, ends, speakers)`` with one
        entry per yielded track, in iteration order.
    """
    tracks = list(res.itertracks(yield_label=True))
    starts = [turn.start for turn, _, _ in tracks]
    ends = [turn.end for turn, _, _ in tracks]
    speakers = [label for _, _, label in tracks]
    return starts, ends, speakers
# Plot diarization results as one horizontal lane per speaker.
def plot_diarization(starts, ends, speakers):
    """Draw each diarized segment as a thick horizontal line on its speaker's row.

    Args:
        starts: Segment start times, parallel to ``ends`` and ``speakers``.
        ends: Segment end times.
        speakers: Speaker label for each segment.

    Returns:
        plotly.graph_objects.Figure: Figure with one trace per segment; the
        y coordinate is the speaker's row index and the x axis is time.
    """
    fig = go.Figure()

    # Number speakers in order of first appearance so the row and colour given
    # to a speaker are deterministic (iterating a set of strings is not stable
    # across interpreter runs), and so the lookup is O(1) per segment.
    ordered_speakers = list(dict.fromkeys(speakers))
    row_of = {speaker: row for row, speaker in enumerate(ordered_speakers)}

    # endpoint=False: hue 360 is the same colour as hue 0, so including the
    # endpoint would give the first and last speaker identical colours.
    hues = np.linspace(0, 360, len(ordered_speakers), endpoint=False)
    colors = [f"hsl({h},80%,60%)" for h in hues]

    # Plot each segment with its speaker's color
    for start, end, speaker in zip(starts, ends, speakers):
        speaker_id = row_of[speaker]
        fig.add_trace(
            go.Scatter(
                x=[start, end],
                y=[speaker_id, speaker_id],
                mode="lines",
                line=dict(color=colors[speaker_id], width=15),
                showlegend=False,
            )
        )

    fig.update_layout(
        title="Speaker Diarization",
        xaxis=dict(title="Time"),
        yaxis=dict(title="Speaker"),
        height=600,
        width=800,
    )
    return fig
def transcribe(sr, data):
    """Run the module-level ASR pipeline on raw audio samples.

    Args:
        sr: Sampling rate of ``data``.
        data: Audio samples; they are scaled by 1/32767, i.e. treated as
            16-bit PCM values. May be 1-D (mono) or 2-D.
            NOTE(review): 2-D input is assumed to be laid out as
            (samples, channels), as produced by the Gradio audio component —
            confirm against the caller.

    Returns:
        The result of ``pipe(...)`` called with ``return_timestamps=True``.
    """
    samples = np.asarray(data, dtype=np.float32)
    # The ASR pipeline expects single-channel audio; average the channels of
    # multi-channel recordings instead of passing a 2-D array through.
    if samples.ndim > 1:
        samples = samples.mean(axis=1)
    processed_data = samples / 32767.0
    # results from the pipeline
    transcription_res = pipe(
        {"sampling_rate": sr, "raw": processed_data}, return_timestamps=True
    )
    return transcription_res
def format_string(timestamp):
    """
    Convert a clock-style timestamp string into a number of seconds.

    The string is split on ":" and every field is parsed as a float. A
    two-field value is read as MM:SS (hours default to zero); otherwise the
    first three fields are read as HH:MM:SS.

    Args:
        timestamp (str):
            Timestamp in string format, either MM:SS or HH:MM:SS.
    Returns:
        seconds (float):
            Total seconds corresponding to the input timestamp.
    """
    fields = [float(piece) for piece in timestamp.split(":")]
    if len(fields) == 2:
        # No hour column given: prepend a zero so the fields line up as H, M, S.
        fields = [0, *fields]
    hours, minutes, secs = fields[0], fields[1], fields[2]
    return hours * 3600 + minutes * 60 + secs
# Adapted from https://github.com/openai/whisper/blob/c09a7ae299c4c34c5839a76380ae407e7d785914/whisper/utils.py#L50
def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = "."):
    """
    Render a duration in seconds as a ``(HH:)MM:SS<marker>mmm`` string.

    The value is rounded to whole milliseconds first. The hour column is
    emitted only when the duration reaches one hour or when
    ``always_include_hours`` is set.

    Args:
        seconds (float):
            Total seconds corresponding to the input timestamp.
        always_include_hours (bool):
            Emit the ``HH:`` prefix even when the hour count is zero.
        decimal_marker (str):
            Separator placed between the seconds and the milliseconds.
    Returns:
        timestamp (str):
            Formatted timestamp, or ``None`` when ``seconds`` is ``None``.
    """
    if seconds is None:
        # we have a malformed timestamp so just return it as is
        return None

    remaining_ms = round(seconds * 1000.0)
    hours, remaining_ms = divmod(remaining_ms, 3_600_000)
    minutes, remaining_ms = divmod(remaining_ms, 60_000)
    whole_seconds, millis = divmod(remaining_ms, 1_000)

    if always_include_hours or hours > 0:
        hours_marker = f"{hours:02d}:"
    else:
        hours_marker = ""
    return f"{hours_marker}{minutes:02d}:{whole_seconds:02d}{decimal_marker}{millis:03d}"
def format_as_transcription(raw_segments):
    """Render speaker-labelled segments as a human-readable transcript.

    Args:
        raw_segments: Iterable of dicts with keys ``'speaker'``,
            ``'timestamp'`` (indexable, start at 0 and end at 1) and ``'text'``.

    Returns:
        str: One ``"<speaker> [<start> -> <end>] <text>"`` entry per segment,
        with entries separated by a blank line.
    """
    entries = []
    for chunk in raw_segments:
        speaker = chunk['speaker']
        begin = format_timestamp(chunk['timestamp'][0])
        finish = format_timestamp(chunk['timestamp'][1])
        text = chunk['text']
        entries.append(f"{speaker} [{begin} -> {finish}] {text}")
    return "\n\n".join(entries)
def merge_segments(segments):
    """Merge back-to-back segments from the same speaker and format the result.

    Two consecutive segments are merged when they have the same ``'speaker'``
    and the end time of the first equals the start time of the second; the
    merged segment spans both and its text is the two texts joined by a space.

    Args:
        segments: List of dicts with keys ``'speaker'``, ``'timestamp'``
            (``(start, end)``) and ``'text'``. The input dicts are not modified.

    Returns:
        str: The merged segments rendered by ``format_as_transcription``;
        an empty string when ``segments`` is empty.
    """
    # Nothing to merge; this is the same value formatting an empty list yields.
    if not segments:
        return ""

    merged_segments = []
    # Work on shallow copies so the caller's segment dicts are left untouched.
    current_segment = dict(segments[0])
    for segment in segments[1:]:
        if segment['speaker'] == current_segment['speaker'] and current_segment['timestamp'][1] == segment['timestamp'][0]:
            current_segment['timestamp'] = (current_segment['timestamp'][0], segment['timestamp'][1])
            current_segment['text'] += ' ' + segment['text']
        else:
            merged_segments.append(current_segment)
            current_segment = dict(segment)
    merged_segments.append(current_segment)
    return format_as_transcription(merged_segments)
def align_timestamps(segments, dia_seg_last_end):
    """Shift segment timestamps so they increase monotonically across the audio.

    When a segment starts before the previous (aligned) segment ended, its
    timestamps are assumed to have restarted and are shifted forward: by the
    previous aligned end if the segment starts at exactly 0.0, otherwise by
    the offset that was applied to the previous segment.

    Args:
        segments: Iterable of dicts with keys ``'timestamp'``
            (``(start, end)``; ``end`` may be ``None``) and ``'text'``.
        dia_seg_last_end: End time substituted for a segment whose ``end``
            is ``None``.

    Returns:
        list[dict]: New dicts with keys ``'timestamp'`` (aligned
        ``(start, end)``) and ``'text'``; the input is not modified.
    """
    aligned_segments = []
    previous_end = 0.0  # aligned end of the previous segment
    previous_end_before_align = 0.0  # the same end, before any shift was applied
    for segment in segments:
        start, end = segment['timestamp']
        adjust = 0.0
        if start < previous_end:
            if start == 0.0:
                adjust = previous_end
            else:
                adjust = previous_end - previous_end_before_align
            start += adjust
            # An open-ended segment has no end to shift; it is resolved below.
            if end is not None:
                end += adjust
        # Resolve a missing end only after the shift, so no arithmetic is ever
        # attempted on None.
        if end is None:
            end = dia_seg_last_end
        aligned_segments.append({'timestamp': (start, end), 'text': segment['text']})
        previous_end = end
        previous_end_before_align = end - adjust
    return aligned_segments
def transcribe_diarize(audio):
    """Transcribe an audio clip and attribute each transcript chunk to a speaker.

    Args:
        audio: ``(sampling_rate, samples)`` tuple, or ``None`` when no audio
            was provided. Samples are scaled by 1/32767, i.e. treated as
            16-bit PCM values.
            NOTE(review): 2-D samples are assumed to be laid out as
            (samples, channels) — confirm against the Gradio audio component.

    Returns:
        tuple: ``(transcription_res, diarized_transcription, diarization_plot)``:
        the raw ASR pipeline result, the speaker-labelled transcript string,
        and the diarization figure.

    Raises:
        gr.Error: If no audio was provided or the diarization pipeline
            produced no segments.
    """
    if audio is None:
        raise gr.Error("No audio provided. Record or upload an audio file first.")

    sr, data = audio
    print(sr)

    # Both pipelines are fed single-channel audio: average the channels of
    # multi-channel recordings so the waveform tensor is (1, samples).
    samples = np.asarray(data)
    if samples.ndim > 1:
        samples = samples.mean(axis=1)
    processed_data = samples.astype(np.float32) / 32767.0
    waveform_tensor = torch.tensor(processed_data[np.newaxis, :])

    transcription_res = transcribe(sr, samples)

    # results from the diarization pipeline
    diarization_res = diarization_pipeline(
        {"waveform": waveform_tensor, "sample_rate": sr}
    )
    dia_seg, dia_label = [], []
    for segment, _, label in diarization_res.itertracks(yield_label=True):
        dia_seg.append([segment.start, segment.end])
        dia_label.append(label)

    # Validate before touching the last segment: with no tracks there is no
    # loop variable to read, and an assert would be stripped under -O.
    if not dia_seg:
        raise gr.Error(
            "The result from the diarization pipeline: `diarization_segments` is empty. No segments found from the diarization process."
        )
    dia_seg_last_end = dia_seg[-1][1]

    segmented_preds = align_timestamps(transcription_res["chunks"], dia_seg_last_end)
    dia_seg = np.array(dia_seg)
    asr_seg = np.array([[*chunk["timestamp"]] for chunk in segmented_preds])
    asr_labels = match_segments(dia_seg, dia_label, asr_seg, threshold=0.0, no_match_label="NO_SPEAKER")
    for i, label in enumerate(asr_labels):
        segmented_preds[i]["speaker"] = label

    # Get diarization information
    starts, ends, speakers = diarization_info(diarization_res)
    # Merge consecutive same-speaker chunks and render the transcript
    diarized_transcription = merge_segments(segmented_preds)
    # Plot diarization
    diarization_plot = plot_diarization(starts, ends, speakers)
    return transcription_res, diarized_transcription, diarization_plot
# Gradio interface: one audio input (upload or microphone) wired to
# `transcribe_diarize`, whose three return values fill the three outputs in order.
demo = gr.Interface(
    fn=transcribe_diarize,
    inputs=gr.Audio(sources=["upload", "microphone"]),
    outputs=[
        gr.Textbox(lines=3, label="Text Transcription"),
        gr.Textbox(label="Diarized Transcription"),
        gr.Plot(label="Visualization"),
    ],
    examples=["sample1.wav"],
    title="Automatic Speech Recognition with Diarization 🗣️",
    # The description names the model actually loaded by this app
    # (openai/whisper-large-v3), not a distilled variant.
    description="Whisper V3 Large & Pyannote Speaker Diarization V3.1 \nTranscribe your speech to text with Whisper large-v3 and diarization with pyannote. Get started by recording from your mic or uploading an audio file (.wav) 🎙️",
)

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()