""" A Gradio app to transcribe and diarize a podcast using Whisper and pyannote. Adapted from Dwarkesh Patel's Colab notebook here: https://colab.research.google.com/drive/1V-Bt5Hm2kjaDb4P1RyMSswsDKyrzc2-3?usp=sharing """ import whisper import datetime import subprocess import torch import gradio as gr import pyannote.audio from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding from pyannote.audio import Audio from pyannote.core import Segment import wave import contextlib from sklearn.cluster import AgglomerativeClustering import numpy as np if torch.cuda.is_available(): device_type = "cuda" elif torch.backends.mps.is_available(): device_type = "mps" else: device_type = "cpu" print(f"chosen device: {device_type}") embedding_model = PretrainedSpeakerEmbedding( "speechbrain/spkrec-ecapa-voxceleb", device=torch.device(device_type) ) audio = Audio() def time(secs): return datetime.timedelta(seconds=round(secs)) def segment_embedding(segment, duration, audio, path): start = segment["start"] # Whisper overshoots the end timestamp in the last segment end = min(duration, segment["end"]) clip = Segment(start, end) waveform, sample_rate = audio.crop(path, clip) return embedding_model(waveform[None]) def get_whisper_results(path, model_type): model = whisper.load_model(model_type) result = model.transcribe(path) segments = result["segments"] with contextlib.closing(wave.open(path, "r")) as f: frames = f.getnframes() rate = f.getframerate() duration = frames / float(rate) return result, segments, frames, rate, duration def cluster_embeddings(segments, duration, path, num_speakers): embeddings = np.zeros(shape=(len(segments), 192)) for i, segment in enumerate(segments): embeddings[i] = segment_embedding(segment, duration, audio, path) embeddings = np.nan_to_num(embeddings) clustering = AgglomerativeClustering(num_speakers).fit(embeddings) labels = clustering.labels_ for i in range(len(segments)): segments[i]["speaker"] = "SPEAKER " + str(labels[i] + 1) def transcribe(path, model_type, num_speakers): if path[-3:] != "wav": subprocess.call(["ffmpeg", "-i", path, "audio.wav", "-y"]) path = "audio.wav" ret = "" print("running whisper...") result, segments, frames, rate, duration = get_whisper_results(path, model_type) print("done running whisper. Clustering embeddings...") cluster_embeddings(segments, duration, path, num_speakers) print(f"done clustering embeddings. Time to return...") for i, segment in enumerate(segments): if i == 0 or segments[i - 1]["speaker"] != segment["speaker"]: ret += "\n" + segment["speaker"] + " " + str(time(segment["start"])) + "\n" ret += segment["text"][1:] + " " return ret if __name__ == "__main__": interface = gr.Interface( fn=transcribe, inputs=[ gr.File(file_count="single", label="Upload an audio file"), gr.Radio( choices=["tiny", "base", "small", "medium", "large-v3"], value="large-v3", type="value", label="Model size", ), gr.Number( value=2, label="Number of speakers", ), ], outputs=gr.Textbox(label="Transcript", show_copy_button=True), title="Transcribe a podcast!", description="Upload an audio file and choose a model size and number of speakers on the left, then click submit to transcribe!", theme=gr.themes.Soft(), ) interface.launch(share=True)