import gradio as gr
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
import numpy as np
from pyannote.audio import Pipeline
import os
from dotenv import load_dotenv
import plotly.graph_objects as go
# Load environment variables (e.g. HF_KEY) from a local .env file, if present
load_dotenv()
# Check and set device
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
# Model and pipeline setup
model_id = "distil-whisper/distil-small.en"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    torch_dtype=torch_dtype,
    device=device,
)
# Speaker diarization pipeline (named to avoid shadowing transformers' `pipeline`)
diarization_pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.0", use_auth_token=os.getenv("HF_KEY")
)
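# Note: pyannote/speaker-diarization-3.0 is a gated model on the Hugging Face Hub,
# so HF_KEY must be an access token for an account that has accepted the model's
# user conditions; otherwise the pipeline will fail to load.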
def diarization_info(res):
    """Extract start times, end times, and speaker labels from a pyannote annotation."""
    starts = []
    ends = []
    speakers = []

    for segment, _, speaker in res.itertracks(yield_label=True):
        starts.append(segment.start)
        ends.append(segment.end)
        speakers.append(speaker)

    return starts, ends, speakers
def plot_diarization(starts, ends, speakers):
    fig = go.Figure()

    # Define a color map for the distinct speakers
    speaker_ids = list(set(speakers))
    num_speakers = len(speaker_ids)
    colors = [f"hsl({h},80%,60%)" for h in np.linspace(0, 360, num_speakers)]

    # Plot each segment as a horizontal line in its speaker's color
    for start, end, speaker in zip(starts, ends, speakers):
        speaker_id = speaker_ids.index(speaker)
        fig.add_trace(
            go.Scatter(
                x=[start, end],
                y=[speaker_id, speaker_id],
                mode="lines",
                line=dict(color=colors[speaker_id], width=15),
                showlegend=False,
            )
        )

    fig.update_layout(
        title="Speaker Diarization",
        xaxis=dict(title="Time (s)"),
        yaxis=dict(title="Speaker"),
        height=600,
        width=800,
    )
    return fig
def transcribe_diarize(audio):
    sr, data = audio

    # Gradio returns int16 PCM samples; scale to float32 in [-1, 1]
    processed_data = np.array(data).astype(np.float32) / 32767.0
    # pyannote expects a (channel, time) waveform tensor
    waveform_tensor = torch.tensor(processed_data[np.newaxis, :])

    # Results from the two pipelines
    transcription_res = pipe({"sampling_rate": sr, "raw": processed_data})["text"]
    diarization_res = diarization_pipeline(
        {"waveform": waveform_tensor, "sample_rate": sr}
    )

    # Get diarization information
    starts, ends, speakers = diarization_info(diarization_res)

    # Plot diarization
    diarization_plot = plot_diarization(starts, ends, speakers)

    return transcription_res, diarization_res, diarization_plot
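# In the interface below, the raw diarization result is shown in a Textbox via its
# string representation (a per-segment listing of speaker turns), while the Plot
# component displays the timeline figure.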
# Creating the Gradio interface
demo = gr.Interface(
    fn=transcribe_diarize,
    inputs=gr.Audio(sources=["upload", "microphone"]),
    outputs=[
        gr.Textbox(lines=3, label="Text Transcription"),
        gr.Textbox(label="Speaker Diarization"),
        gr.Plot(),
    ],
    title="Automatic Speech Recognition with Diarization 🗣️",
    description="Transcribe your speech to text with Distil-Whisper and perform speaker diarization with pyannote. Get started by recording from your mic or uploading an audio file 🎙️",
)
if __name__ == "__main__":
    demo.launch()
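# Example local usage (a sketch, assuming this file is saved as app.py and a .env
# file sits next to it):
#
#   echo "HF_KEY=hf_xxx" > .env   # hf_xxx is a placeholder Hugging Face token
#   python app.py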