# Diarization / app.py
# Hugging Face Space by DrishtiSharma (commit 7c97c61: "Update app.py")
#importing all the necessary packages
import torch
import transformers
import gradio as gr
from torchaudio.sox_effects import apply_effects_file
from termcolor import colored
from transformers import Wav2Vec2FeatureExtractor, UniSpeechSatForAudioFrameClassification
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# SoX effect chain handed to torchaudio's apply_effects_file; effects run in
# list order on every incoming clip.
EFFECTS = [
    ["remix", "-"],  # mix every input channel down into one
    ["channels", "1"],  # force a single (mono) output channel
    ["rate", "16000"],  # resample to 16 kHz
    ["gain", "-1.0"],  # attenuate by 1 dB
    # SoX `silence`: strip leading silence and, because the below-periods
    # count is -1, silent stretches throughout the clip (see SoX manual).
    ["silence", "1", "0.1", "0.1%", "-1", "0.1", "0.1%"],
    ["trim", "0", "10"],  # keep at most the first 10 seconds
]

# Dataset-dependent score threshold.
# NOTE(review): not referenced by fn() in this file, which binarizes frame
# probabilities at 0.5 instead — confirm whether this is still needed.
THRESHOLD = 0.85
# Hugging Face checkpoint used by this app (UniSpeech-SAT, speaker
# diarization per the UI title).
model_name = "microsoft/unispeech-sat-base-sd"

# Preprocessor that turns a raw waveform into the model's `input_values`.
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)

# Frame-level classifier producing per-frame, per-speaker logits; loaded once
# at import time and placed on the selected device.
model = UniSpeechSatForAudioFrameClassification.from_pretrained(model_name)
model = model.to(device)
def fn(path, *, threshold=0.5):
    """Run frame-level speaker diarization on one recorded audio file.

    The clip is preprocessed with the module-level ``EFFECTS`` SoX chain,
    converted to model inputs by ``feature_extractor`` and scored by the
    UniSpeech-SAT frame classifier ``model``.

    Args:
        path: Filesystem path of the audio clip, as delivered by the Gradio
            ``Audio(type="filepath")`` input. May be ``None`` because that
            input is declared optional.
        threshold: Probability above which a speaker is considered active in
            a frame. Defaults to 0.5, the value previously hard-coded here.

    Returns:
        A ``torch.long`` tensor of shape ``(num_frames, num_speakers)``, on
        the CPU, holding 0/1 speaker-activity flags.

    Raises:
        ValueError: If no audio was supplied, or if nothing remains after
            the silence-removal effect.
    """
    if path is None:
        # The optional Audio input yields None when nothing was recorded;
        # apply_effects_file would otherwise fail with an obscure error.
        raise ValueError("No audio provided: please record a clip first.")

    # Apply the preprocessing chain. Use the sample rate SoX reports instead
    # of repeating the "16000" from EFFECTS, so the two cannot drift apart.
    wav, sample_rate = apply_effects_file(path, EFFECTS)
    if wav.numel() == 0:
        # The "silence" effect can strip an all-quiet recording to nothing,
        # which the model cannot process.
        raise ValueError(
            "The recording contains no audible audio after silence removal."
        )

    # Drop the leading channel axis (mono after EFFECTS) before extraction.
    input_values = feature_extractor(
        wav.squeeze(0), return_tensors="pt", sampling_rate=sample_rate
    ).input_values.to(device)

    with torch.no_grad():
        logits = model(input_values).logits

    # Sigmoid (not softmax): each speaker is scored independently, so
    # several speakers, or none, can be active in the same frame.
    probabilities = torch.sigmoid(logits[0])
    # Multi-hot (not one-hot) activity matrix: (num_frames, num_speakers).
    labels = (probabilities > threshold).long()
    # Return on CPU: the result only goes to a text display.
    return labels.cpu()
# Gradio UI: one optional microphone recording in, the diarization result
# from fn() rendered as text out.
speaker_audio = gr.inputs.Audio(
    source="microphone",
    type="filepath",
    optional=True,
    label="Speaker #1",
)
inputs = [speaker_audio]
output = gr.outputs.Textbox(label="Output Text")

# NOTE(review): the title mentions X-Vectors, but this file only runs the
# UniSpeech-SAT frame classifier — consider rewording.
demo = gr.Interface(
    fn=fn,
    inputs=inputs,
    outputs=output,
    theme="grass",
    title="Speaker diarization using UniSpeech-SAT and X-Vectors",
)
# Start the server at import time with Gradio's request queue enabled.
demo.launch(enable_queue=True)