# Diarization / app.py
import gradio as gr
import whisper
import datetime
import torch
import subprocess
import os
from pyannote.audio import Audio
from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding
from pyannote.core import Segment
import wave
import contextlib
from sklearn.cluster import AgglomerativeClustering
import numpy as np
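
# Runtime dependencies (assumed, based on the imports above): openai-whisper,
# pyannote.audio, speechbrain, scikit-learn, gradio, and an ffmpeg binary on PATH.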

# Load Whisper model
model_size = "medium.en"
model = whisper.load_model(model_size)

# Audio helper for cropping segments out of the input file
audio = Audio()

# Speaker-embedding model (ECAPA-TDNN trained on VoxCeleb)
embedding_model = PretrainedSpeakerEmbedding(
    "speechbrain/spkrec-ecapa-voxceleb",
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)
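
# Note: this ECAPA-TDNN checkpoint produces 192-dimensional embeddings, which
# is why the embedding matrix inside transcribe_and_diarize() below is
# allocated with 192 columns.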


def transcribe_and_diarize(audio_file, num_speakers=2):
    try:
        if audio_file is None:
            return "Please upload an audio file."
        # gr.Audio(type="filepath") passes the upload as a plain path string
        path = audio_file

        # Convert to WAV if necessary (the wave module below requires it)
        if not path.endswith('.wav'):
            subprocess.call(['ffmpeg', '-y', '-i', path, 'audio.wav'])
            path = 'audio.wav'

        # Transcribe audio with Whisper
        result = model.transcribe(path)
        segments = result["segments"]

        # Get audio duration from the WAV header
        with contextlib.closing(wave.open(path, 'r')) as f:
            frames = f.getnframes()
            rate = f.getframerate()
            duration = frames / float(rate)

        # Extract a speaker embedding for a single transcript segment
        def segment_embedding(segment):
            start = segment["start"]
            # Whisper can overshoot the end of the file; clamp to the duration
            end = min(duration, segment["end"])
            clip = Segment(start, end)
            waveform, sample_rate = audio.crop(path, clip)
            # The embedding model expects a (batch, channel, sample) tensor
            return embedding_model(waveform[None])

        # Extract embeddings for each segment
        embeddings = np.zeros(shape=(len(segments), 192))
        for i, segment in enumerate(segments):
            embeddings[i] = segment_embedding(segment)
        embeddings = np.nan_to_num(embeddings)

        # Cluster the embeddings into the requested number of speakers
        # (gr.Number delivers a float, so cast to int for scikit-learn)
        clustering = AgglomerativeClustering(int(num_speakers)).fit(embeddings)
        labels = clustering.labels_
        for i in range(len(segments)):
            segments[i]["speaker"] = 'SPEAKER ' + str(labels[i] + 1)

        # Generate the transcript, starting a new block whenever the speaker changes
        transcript = ""
        for i, segment in enumerate(segments):
            if i == 0 or segments[i - 1]["speaker"] != segment["speaker"]:
                transcript += "\n" + segment["speaker"] + ' ' + \
                    str(datetime.timedelta(seconds=round(segment["start"]))) + '\n'
            transcript += segment["text"].strip() + ' '
        transcript += "\n\n"
        return transcript
    except Exception as e:
        return f"An error occurred: {str(e)}"

iface = gr.Interface(
    fn=transcribe_and_diarize,
    inputs=[
        gr.Audio(type="filepath", label="Upload Audio File"),
        gr.Number(value=2, precision=0, label="Number of Speakers"),
    ],
    outputs="text",
    title="Audio Transcription and Speaker Diarization",
    description="Upload an audio file to get a transcription with speaker diarization.",
)
iface.launch()
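
# A minimal sketch of calling the pipeline headlessly instead of through the
# Gradio UI (the filename "meeting.wav" is illustrative, not part of this repo):
#
#     print(transcribe_and_diarize("meeting.wav", num_speakers=2))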