"""
A Gradio app to transcribe and diarize a podcast using Whisper and pyannote. Adapted from Dwarkesh Patel's Colab notebook here:
https://colab.research.google.com/drive/1V-Bt5Hm2kjaDb4P1RyMSswsDKyrzc2-3?usp=sharing
"""
import whisper
import datetime

import subprocess
import torch
import gradio as gr
import pyannote.audio
from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding
from pyannote.audio import Audio
from pyannote.core import Segment
import wave
import contextlib

from sklearn.cluster import AgglomerativeClustering
import numpy as np

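# Pick the best available device: CUDA GPU, then Apple Silicon (MPS), then CPU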
if torch.cuda.is_available():
    device_type = "cuda"
elif torch.backends.mps.is_available():
    device_type = "mps"
else:
    device_type = "cpu"

print(f"chosen device: {device_type}")


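# Pretrained ECAPA-TDNN speaker-embedding model from SpeechBrain; it maps a
# mono audio clip to a 192-dimensional vector that is clustered per speaker
# below.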
embedding_model = PretrainedSpeakerEmbedding(
    "speechbrain/spkrec-ecapa-voxceleb", device=torch.device(device_type)
)

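# pyannote I/O helper; audio.crop() returns a (channels, samples) waveform
# tensor for a given time window of the file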
audio = Audio()


def format_time(secs):
    """Round to whole seconds and wrap in a timedelta so str() renders H:MM:SS."""
    return datetime.timedelta(seconds=round(secs))


def segment_embedding(segment, duration, audio, path):
    """Compute a speaker embedding for a single Whisper segment."""
    start = segment["start"]
    # Whisper overshoots the end timestamp in the last segment
    end = min(duration, segment["end"])
    clip = Segment(start, end)
    waveform, sample_rate = audio.crop(path, clip)
    # Downmix stereo to mono; the embedding model expects single-channel audio
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    return embedding_model(waveform[None])


def get_whisper_results(path, model_type):
    """Transcribe the file with Whisper and measure its duration."""
    model = whisper.load_model(model_type)
    result = model.transcribe(path)
    segments = result["segments"]

    # The wave module only reads PCM WAV files, hence the conversion step
    # in transcribe() below
    with contextlib.closing(wave.open(path, "r")) as f:
        frames = f.getnframes()
        rate = f.getframerate()
        duration = frames / float(rate)

    return segments, duration


def cluster_embeddings(segments, duration, path, num_speakers):
    """Assign a speaker label to each segment via agglomerative clustering."""
    # The ECAPA-TDNN model produces 192-dimensional embeddings
    embeddings = np.zeros(shape=(len(segments), 192))
    for i, segment in enumerate(segments):
        embeddings[i] = segment_embedding(segment, duration, audio, path)

    embeddings = np.nan_to_num(embeddings)

    # Gradio's Number component returns a float; sklearn needs an int
    clustering = AgglomerativeClustering(n_clusters=int(num_speakers)).fit(embeddings)
    labels = clustering.labels_
    for i in range(len(segments)):
        segments[i]["speaker"] = f"SPEAKER {labels[i] + 1}"


def transcribe(path, model_type, num_speakers):
    """Transcribe an audio file, label speakers, and return the transcript."""
    # Convert non-WAV input to mono WAV: the wave module needs PCM data and
    # the embedding model expects mono audio. Note that -y (overwrite) must
    # come before the output filename.
    if not path.lower().endswith(".wav"):
        subprocess.call(["ffmpeg", "-y", "-i", path, "-ac", "1", "audio.wav"])
        path = "audio.wav"

    ret = ""
    print("running whisper...")
    segments, duration = get_whisper_results(path, model_type)
    print("done running whisper. Clustering embeddings...")
    cluster_embeddings(segments, duration, path, num_speakers)
    print("done clustering embeddings. Formatting transcript...")

    # Start a new header line whenever the speaker changes, then append each
    # segment's text (stripped of Whisper's leading space)
    for i, segment in enumerate(segments):
        if i == 0 or segments[i - 1]["speaker"] != segment["speaker"]:
            ret += "\n" + segment["speaker"] + " " + str(format_time(segment["start"])) + "\n"
        ret += segment["text"].strip() + " "

    return ret


if __name__ == "__main__":
    interface = gr.Interface(
        fn=transcribe,
        inputs=[
            gr.File(file_count="single", type="filepath", label="Upload an audio file"),
            gr.Radio(
                choices=["tiny", "base", "small", "medium", "large-v3"],
                value="large-v3",
                type="value",
                label="Model size",
            ),
            gr.Number(
                value=2,
                precision=0,
                label="Number of speakers",
            ),
        ],
        outputs=gr.Textbox(label="Transcript", show_copy_button=True),
        title="Transcribe a podcast!",
        description="Upload an audio file and choose a model size and number of speakers on the left, then click submit to transcribe!",
        theme=gr.themes.Soft(),
    )
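    # share=True also serves a temporary public gradio.live URL alongside the
    # local server; set it to False to keep the app local-only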
    interface.launch(share=True)