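"""Gradio demo: transcribe speech with Whisper large-v3, attribute segments to
speakers with pyannote speaker-diarization-3.1, and plot a timeline of who
spoke when."""
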
import gradio as gr
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
import numpy as np
from pyannote.audio import Pipeline
import os
from dotenv import load_dotenv
import plotly.graph_objects as go
from utils.diarize_utils import match_segments

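# Read variables from a local .env file; HF_KEY (a Hugging Face access token)
# is needed below to download the gated pyannote pipeline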
load_dotenv()

# Check and set device
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Model and pipeline setup
model_id = "openai/whisper-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)

processor = AutoProcessor.from_pretrained(model_id)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    torch_dtype=torch_dtype,
    device=device,
)
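# (Optionally, chunk_length_s and batch_size can be passed to the pipeline
# above to speed up long recordings; the defaults are fine for short clips.)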

diarization_pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1", use_auth_token=os.getenv("HF_KEY")
)
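# The diarization pipeline runs on CPU by default; recent pyannote.audio
# releases also allow moving it to the GPU selected above, e.g.:
#   diarization_pipeline.to(torch.device(device))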


# returns diarization info such as segment start and end times, and speaker id
def diarization_info(res):
    starts = []
    ends = []
    speakers = []

    for segment, _, speaker in res.itertracks(yield_label=True):
        starts.append(segment.start)
        ends.append(segment.end)
        speakers.append(speaker)

    return starts, ends, speakers


# plot diarization results on a graph
def plot_diarization(starts, ends, speakers):
    fig = go.Figure()

    # Assign each speaker a distinct hue (endpoint=False keeps the first and
    # last speaker from sharing the same hue at 0/360)
    unique_speakers = sorted(set(speakers))
    num_speakers = len(unique_speakers)
    colors = [f"hsl({h},80%,60%)" for h in np.linspace(0, 360, num_speakers, endpoint=False)]

    # Plot each segment with its speaker's color
    for start, end, speaker in zip(starts, ends, speakers):
        speaker_id = unique_speakers.index(speaker)
        fig.add_trace(
            go.Scatter(
                x=[start, end],
                y=[speaker_id, speaker_id],
                mode="lines",
                line=dict(color=colors[speaker_id], width=15),
                showlegend=False,
            )
        )

    fig.update_layout(
        title="Speaker Diarization",
        xaxis=dict(title="Time (s)"),
        yaxis=dict(title="Speaker"),
        height=600,
        width=800,
    )

    return fig


def transcribe(sr, data):
    # Gradio delivers audio as 16-bit integer PCM; scale it to float32 in [-1, 1]
    processed_data = np.array(data).astype(np.float32) / 32767.0

    # Run the ASR pipeline; return_timestamps=True yields per-chunk timestamps
    # that are matched against the diarization segments later
    transcription_res = pipe(
        {"sampling_rate": sr, "raw": processed_data}, return_timestamps=True
    )

    return transcription_res


def transcribe_diarize(audio):
    sr, data = audio
    processed_data = np.array(data).astype(np.float32) / 32767.0
    waveform_tensor = torch.tensor(processed_data[np.newaxis, :])

    transcription_res = transcribe(sr, data)

    # results from the diarization pipeline
    diarization_res = diarization_pipeline(
        {"waveform": waveform_tensor, "sample_rate": sr}
    )
    dia_seg, dia_label = [], []
    for segment, _, label in diarization_res.itertracks(yield_label=True):
        dia_seg.append([segment.start, segment.end])
        dia_label.append(label)
    assert dia_seg, (
        "The diarization pipeline returned no segments;"
        " there is nothing to attribute speakers to."
    )
    segmented_preds = transcription_res["chunks"]
    dia_seg = np.array(dia_seg)
    asr_seg = np.array([[*chunk["timestamp"]] for chunk in segmented_preds])

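    # match_segments (imported above) is assumed to label each ASR chunk with
    # the speaker of the diarization segment it overlaps most; threshold=0.0
    # accepts any overlap, and chunks with no overlap get "NO_SPEAKER"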
    asr_labels = match_segments(
        dia_seg, dia_label, asr_seg, threshold=0.0, no_match_label="NO_SPEAKER"
    )

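    # Attach each matched speaker label to its ASR chunk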
    for i, label in enumerate(asr_labels):
        segmented_preds[i]["speaker"] = label
    # Get diarization information
    starts, ends, speakers = diarization_info(diarization_res)

    # ASR chunks, now carrying their speaker labels
    diarized_transcription = segmented_preds

    # Plot diarization
    diarization_plot = plot_diarization(starts, ends, speakers)

    return transcription_res, diarized_transcription, diarization_plot


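# Optional helper (a sketch, not wired into the interface below): format the
# speaker-labeled chunks as readable lines instead of the raw dict repr that
# gr.Textbox would otherwise display. Field names ("timestamp", "text",
# "speaker") follow the structure built in transcribe_diarize().
def format_diarized_transcription(chunks):
    lines = []
    for chunk in chunks:
        start, end = chunk["timestamp"]
        lines.append(f'{chunk["speaker"]} [{start}s-{end}s]: {chunk["text"].strip()}')
    return "\n".join(lines)

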
# creating the gradio interface
demo = gr.Interface(
    fn=transcribe_diarize,
    inputs=gr.Audio(sources=["upload", "microphone"]),
    outputs=[
        gr.Textbox(lines=3, label="Text Transcription"),
        gr.Textbox(label="Diarized Transcription"),
        gr.Plot(label="Visualization"),
    ],
    examples=["sample1.wav"],
    title="Automatic Speech Recognition with Diarization 🗣️",
    description="Whisper V3 Large & Pyannote Speaker Diarization V3.1 \nTranscribe your speech to text with distilled whisper and diarization with pyannote. Get started by recording from your mic or uploading an audio file (.wav) 🎙️",
)

if __name__ == "__main__":
    demo.launch()