Spaces:
Runtime error
Runtime error
File size: 4,783 Bytes
3a3a6e5 5a1ff9c 3a3a6e5 daf6cbd 3a3a6e5 5a1ff9c 3a3a6e5 5a1ff9c 3a3a6e5 5a1ff9c 3a3a6e5 5a1ff9c 3a3a6e5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
import gradio as gr
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
import numpy as np
from pyannote.audio import Pipeline
import os
from dotenv import load_dotenv
import plotly.graph_objects as go
from .utils.diarize_utils import match_segments
# Load environment variables (e.g. HF_KEY, read below) via python-dotenv.
load_dotenv()

# Check and set device: use the first CUDA GPU when available, else the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Half precision is only selected together with CUDA; CPU stays in float32.
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Model and pipeline setup
model_id = "openai/whisper-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    torch_dtype=torch_dtype,
    device=device,
)

# Speaker diarization pipeline; the Hugging Face token comes from HF_KEY.
diarization_pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1", use_auth_token=os.getenv("HF_KEY")
)
# pyannote reports a failed download of a gated checkpoint by returning None
# rather than raising. Fail here with an actionable message instead of
# crashing on the first request with "'NoneType' object is not callable".
if diarization_pipeline is None:
    raise RuntimeError(
        "Could not load 'pyannote/speaker-diarization-3.1'. Check that the "
        "HF_KEY environment variable holds a valid Hugging Face token and "
        "that the model's user conditions have been accepted."
    )
# Run diarization on the same device as the Whisper model (pyannote
# pipelines expect a torch.device instance here, not a string).
diarization_pipeline.to(torch.device(device))
def diarization_info(res):
    """Return diarization info as three parallel lists.

    Walks ``res.itertracks(yield_label=True)`` once and, for every track,
    records the segment's ``start`` and ``end`` attributes together with the
    yielded speaker label.

    Args:
        res: Diarization result exposing ``itertracks(yield_label=True)``,
            which yields ``(segment, track, label)`` triples.

    Returns:
        A tuple ``(starts, ends, speakers)`` of equally long lists, in the
        order the tracks were yielded.
    """
    tracks = [
        (turn.start, turn.end, label)
        for turn, _, label in res.itertracks(yield_label=True)
    ]
    starts = [begin for begin, _, _ in tracks]
    ends = [finish for _, finish, _ in tracks]
    speakers = [label for _, _, label in tracks]
    return starts, ends, speakers
def plot_diarization(starts, ends, speakers):
    """Plot diarization results as one horizontal bar per speech segment.

    Each segment is drawn as a thick line from its start to its end on the
    x axis, placed on the row (y value) assigned to its speaker.

    Args:
        starts: Segment start values, one per segment.
        ends: Segment end values, parallel to ``starts``.
        speakers: Speaker label of each segment, parallel to ``starts``.
            Labels must be hashable.

    Returns:
        A ``plotly.graph_objects.Figure`` with one trace per segment.
    """
    fig = go.Figure()

    # Assign each speaker a row once, in order of first appearance. This is
    # deterministic across runs (set ordering of strings is not) and avoids
    # rebuilding a list from a set for every segment.
    speaker_rows = {
        label: row for row, label in enumerate(dict.fromkeys(speakers))
    }

    # Hue is circular (0 and 360 are the same colour), so exclude the
    # endpoint; otherwise the first and last speaker share a colour, which
    # makes a two-speaker plot single-coloured.
    hues = np.linspace(0, 360, len(speaker_rows), endpoint=False)
    colors = [f"hsl({h},80%,60%)" for h in hues]

    # Plot each segment with its speaker's color
    for start, end, speaker in zip(starts, ends, speakers):
        row = speaker_rows[speaker]
        fig.add_trace(
            go.Scatter(
                x=[start, end],
                y=[row, row],
                mode="lines",
                line=dict(color=colors[row], width=15),
                showlegend=False,
            )
        )

    fig.update_layout(
        title="Speaker Diarization",
        xaxis=dict(title="Time"),
        yaxis=dict(
            title="Speaker",
            # Show the speaker labels instead of bare row numbers.
            tickmode="array",
            tickvals=list(speaker_rows.values()),
            ticktext=[str(label) for label in speaker_rows],
        ),
        height=600,
        width=800,
    )
    return fig
def transcribe(sr, data):
    """Run the Whisper ASR pipeline on raw audio samples.

    Args:
        sr: Sampling rate of ``data``.
        data: Audio samples as an array-like. Multi-channel input is averaged
            down to mono; this assumes a ``(samples, channels)`` layout as
            delivered by ``gr.Audio`` -- TODO confirm for other callers.

    Returns:
        The result of ``pipe(..., return_timestamps=True)``.
    """
    samples = np.asarray(data)
    # The ASR pipeline expects single-channel raw audio, so stereo uploads
    # must be downmixed before they are handed over.
    if samples.ndim > 1:
        samples = samples.mean(axis=1)
    # Scale to float32; the 32767 divisor presumably targets 16-bit PCM input.
    processed_data = samples.astype(np.float32) / 32767.0
    # results from the pipeline
    return pipe(
        {"sampling_rate": sr, "raw": processed_data}, return_timestamps=True
    )
def transcribe_diarize(audio):
    """Transcribe audio, attach a speaker label to each chunk and plot it.

    Args:
        audio: ``(sampling_rate, samples)`` tuple, or ``None`` when no audio
            was supplied. Multi-channel samples are averaged to mono; this
            assumes a ``(samples, channels)`` layout as delivered by
            ``gr.Audio`` -- TODO confirm for other callers.

    Returns:
        A tuple ``(transcription_res, diarized_transcription,
        diarization_plot)``: the ASR pipeline result, its ``"chunks"`` list
        with a ``"speaker"`` key added to every chunk, and the figure from
        ``plot_diarization``.

    Raises:
        gr.Error: If no audio was supplied or diarization found no segments.
    """
    if audio is None:
        raise gr.Error(
            "No audio provided. Record from the microphone or upload a file first."
        )
    sr, data = audio

    # Downmix once so both pipelines receive single-channel audio.
    samples = np.asarray(data)
    if samples.ndim > 1:
        samples = samples.mean(axis=1)
    processed_data = samples.astype(np.float32) / 32767.0
    # Add a leading axis of size 1 in front of the sample axis.
    waveform_tensor = torch.tensor(processed_data[np.newaxis, :])

    # `transcribe` applies its own scaling, so it gets the unscaled samples.
    transcription_res = transcribe(sr, samples)

    # results from the diarization pipeline
    diarization_res = diarization_pipeline(
        {"waveform": waveform_tensor, "sample_rate": sr}
    )

    # Get diarization information (single pass over the diarization result).
    starts, ends, speakers = diarization_info(diarization_res)
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if not starts:
        raise gr.Error(
            "The result from the diarization pipeline: `diarization_segments` is empty. No segments found from the diarization process."
        )
    dia_seg = np.array([[start, end] for start, end in zip(starts, ends)])
    dia_label = speakers

    segmented_preds = transcription_res["chunks"]
    # The final chunk can come back with an end timestamp of None; substitute
    # the clip duration so the array stays numeric instead of object-typed.
    audio_duration = len(processed_data) / sr
    asr_spans = []
    for chunk in segmented_preds:
        chunk_start, chunk_end = chunk["timestamp"]
        asr_spans.append(
            [chunk_start, audio_duration if chunk_end is None else chunk_end]
        )
    asr_seg = np.array(asr_spans)

    asr_labels = match_segments(
        dia_seg, dia_label, asr_seg, threshold=0.0, no_match_label="NO_SPEAKER"
    )
    for i, label in enumerate(asr_labels):
        segmented_preds[i]["speaker"] = label

    # NOTE: this is the same list object as transcription_res["chunks"], so
    # the speaker labels are visible through both returned values.
    diarized_transcription = segmented_preds

    # Plot diarization
    diarization_plot = plot_diarization(starts, ends, speakers)
    return transcription_res, diarized_transcription, diarization_plot
# creating the gradio interface: one audio input (upload or microphone) wired
# to transcribe_diarize, whose three return values feed the three outputs.
demo = gr.Interface(
    fn=transcribe_diarize,
    inputs=gr.Audio(sources=["upload", "microphone"]),
    outputs=[
        gr.Textbox(lines=3, label="Text Transcription"),
        gr.Textbox(label="Diarized Transcription"),
        gr.Plot(label="Visualization"),
    ],
    examples=["sample1.wav"],
    title="Automatic Speech Recognition with Diarization 🗣️",
    # The description names the model this app actually loads
    # (openai/whisper-large-v3), not a distilled variant.
    description="Whisper V3 Large & Pyannote Speaker Diarization V3.1 \nTranscribe your speech to text with Whisper large-v3 and diarization with pyannote. Get started by recording from your mic or uploading an audio file (.wav) 🎙️",
)

# Only start the web server when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
|