# Import all the necessary packages
import torch
import gradio as gr
from torchaudio.sox_effects import apply_effects_file
from transformers import Wav2Vec2FeatureExtractor, UniSpeechSatForAudioFrameClassification

device = "cuda" if torch.cuda.is_available() else "cpu"

# Effects to apply to the audio file before feature extraction
EFFECTS = [
    ["remix", "-"],                                        # merge all the channels
    ["channels", "1"],                                     # convert to mono
    ["rate", "16000"],                                     # resample to 16000 Hz
    ["gain", "-1.0"],                                      # attenuate by 1 dB
    ["silence", "1", "0.1", "0.1%", "-1", "0.1", "0.1%"],  # trim leading/trailing silence
    # ["pad", "0", "1.5"],                                 # add 1.5 seconds of silence at the end
    ["trim", "0", "10"],                                   # keep only the first 10 seconds
]

THRESHOLD = 0.5  # frame-level decision threshold; depends on the dataset

model_name = "microsoft/unispeech-sat-base-sd"
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
model = UniSpeechSatForAudioFrameClassification.from_pretrained(model_name).to(device)


def fn(path):
    # Apply the sox effects to the input audio file
    wav, _ = apply_effects_file(path, EFFECTS)

    # Extract input features
    input_values = feature_extractor(
        wav.squeeze(0), return_tensors="pt", sampling_rate=16000
    ).input_values.to(device)

    with torch.no_grad():
        logits = model(input_values).logits

    # Per-frame, per-speaker activity probabilities
    probabilities = torch.sigmoid(logits[0])

    # labels is a binary (multi-hot) array of shape (num_frames, num_speakers)
    labels = (probabilities > THRESHOLD).long()
    return str(labels.tolist())


inputs = [
    gr.inputs.Audio(source="microphone", type="filepath", optional=True, label="Speaker #1"),
]
output = gr.outputs.Textbox(label="Output Text")

gr.Interface(
    fn=fn,
    inputs=inputs,
    outputs=output,
    theme="grass",
    title="Speaker diarization using UniSpeech-SAT",
).launch(enable_queue=True)
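
# --- Usage sketch (assumption, not part of the original demo) ---
# fn() can also be called directly on a saved audio clip for quick local testing,
# bypassing the Gradio UI. "speaker_clip.wav" is a hypothetical placeholder path.
# Left commented out because launch() above blocks until the server is stopped.
#
# frame_labels_str = fn("speaker_clip.wav")
# print(frame_labels_str)  # nested list of 0/1 values, one row per frame, one column per speaker slot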