# SpeakerDiarization / app_test.py
import streamlit as st
import numpy as np
import os
import whisper
from sklearn.cluster import AgglomerativeClustering, SpectralClustering
import torch
import librosa
from torch.utils.data import DataLoader
from mix_sae import MoESparseAutoencodersCL
from load_dataset import AutoEncoderDataset
import argparse
UPLOAD_FOLDER = "./uploads"
parser = argparse.ArgumentParser(description='Deep Clustering Network')
parser.add_argument('--input_dim', type=int, default=384,
help='input dimension')
# Model parameters
parser.add_argument('--lr', type=float, default=1e-3,
                    help='learning rate (default: 1e-3)')
parser.add_argument('--wd', type=float, default=1e-4,
                    help='weight decay (default: 1e-4)')
parser.add_argument('--batch-size', type=int, default=16,
help='input batch size for training')
parser.add_argument('--lamda', type=float, default=1,
help='coefficient of the reconstruction loss')
parser.add_argument('--beta', type=float, default=1,
help=('coefficient of the regularization term on '
'clustering'))
parser.add_argument('--hidden-dims', type=int, nargs='+', default=[256, 128, 64, 32],
                    help='autoencoder hidden layer dimensions')
parser.add_argument('--latent_dim', type=int, default=2,
help='latent space dimension')
parser.add_argument('--n-clusters', type=int, default=2,
help='number of clusters in the latent space')
parser.add_argument('--n-classes', type=int, default=2,
help='output dimension')
parser.add_argument('--pretrain_epochs', type=int, default=50,
help='pretraining step epochs')
parser.add_argument('--pretrain_epochs_main', type=int, default=30,
                    help='pretraining epochs for the main (mixture) model')
# NOTE: argparse treats any non-empty string as True when type=bool; keep the
# default unless boolean parsing is handled explicitly.
parser.add_argument('--pretrain', type=bool, default=True,
                    help='whether to use pre-training')
parser.add_argument('--main_train_epochs', type=int, default=5,
help='main_train epochs')
parser.add_argument('--rho', type=float, default=0.2,
                    help='sparsity target (rho) for the sparse autoencoders')
parser.add_argument('--sparsity_param', type=float, default=0.2,
                    help='sparsity constraint param')
parser.add_argument('--cl_loss_param', type=float, default=0.05,
                    help='classification loss param')
args = parser.parse_args()
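# NOTE (assumption): when this script is launched via `streamlit run app_test.py`,
# any extra command-line flags forwarded to the script would make parse_args() exit.
# An optional workaround is:
#
#   args, _ = parser.parse_known_args()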
def allowed_file(filename):
return '.' in filename and filename.rsplit('.', 1)[1].lower() in ['wav']
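# Hypothetical helper (not called below): a minimal sketch of how allowed_file()
# and UPLOAD_FOLDER could be used to persist a Streamlit upload to disk. The app
# currently decodes uploads in memory with librosa instead.
def save_upload(uploaded_file):
    if uploaded_file is None or not allowed_file(uploaded_file.name):
        return None
    os.makedirs(UPLOAD_FOLDER, exist_ok=True)
    save_path = os.path.join(UPLOAD_FOLDER, uploaded_file.name)
    with open(save_path, "wb") as f:
        f.write(uploaded_file.getbuffer())
    return save_path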
def process_wav(audio_file, speaker_number, model_type, run_device='cpu', sr=16000):
    """Transcribe a 16 kHz waveform with Whisper and attach a speaker label to each segment."""
    # Whisper encoder embedding dimension for each model size
    embedding_dims = {"tiny": 384, "base": 512, "small": 768, "medium": 1024}
    # ---- get transcription results from the Whisper model
    whisper_model = whisper.load_model(model_type, run_device)
    wp_results = whisper_model.transcribe(audio_file)
for ide in range(len(wp_results['segments'])):
del wp_results['segments'][ide]['seek']
del wp_results['segments'][ide]['tokens']
del wp_results['segments'][ide]['compression_ratio']
del wp_results['segments'][ide]['temperature']
del wp_results['segments'][ide]['avg_logprob']
del wp_results['segments'][ide]['no_speech_prob']
    # ---- process each segment
    segments = wp_results["segments"]
    # two or more segments: embed each one and cluster by speaker
    if len(segments) > 1:
embeddings = np.zeros(shape=(len(segments), embedding_dims[model_type]))
for i, segment in enumerate(segments):
start = int(segment["start"] * sr)
end = int(segment["end"] * sr)
            # Extract the raw audio for this segment
            audio = audio_file[start: end]
            mel = whisper.log_mel_spectrogram(audio).to(whisper_model.device)
            # ---- pad/trim the mel spectrogram to Whisper's expected 3000 frames
            while True:
                nF, nT = np.shape(mel)
                if nT > 3000:
                    mel = mel[:, 0:3000]
                    break
                else:
                    # repeat the (short) segment until it exceeds 3000 frames, then truncate
                    mel = torch.cat((mel, mel), -1)
            mel = torch.unsqueeze(mel, 0)
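            # NOTE (assumption): an equivalent, simpler alternative to the loop above
            # would be to pad/trim the raw audio to 30 s before computing the mel, e.g.
            #
            #   audio = whisper.pad_or_trim(audio)
            #   mel = whisper.log_mel_spectrogram(audio).to(whisper_model.device)
            #
            # which yields exactly 3000 mel frames for every segment.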
            # encode the mel segment with the Whisper encoder and mean-pool over the
            # batch and time axes to obtain one embedding vector per segment
            wp_emb = whisper_model.embed_audio(mel)
            emb_1d = np.mean(wp_emb.cpu().detach().numpy(), axis=0)
            emb_1d = np.mean(emb_1d, axis=0)
            embeddings[i] = emb_1d
        embeddings = np.array(embeddings, dtype="f")
        # ---- cluster the segment embeddings with the MoE sparse-autoencoder model
        train_dataset = AutoEncoderDataset(embeddings)
        train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=False)
        moe_cl = MoESparseAutoencodersCL(args=args)
        mixture_moe_cl, full_latent_X = moe_cl.pretraining(train_loader)
        pre_label = moe_cl.psedo_label  # pseudo-labels produced by the pretraining stage
        mixture_moe_cl = moe_cl.main_training(train_loader)
        # final cluster assignments after main training (currently unused; the
        # pretraining pseudo-labels are what get reported below)
        moe_cl_pred = moe_cl.get_final_cluster(train_loader)
        # --- alternative: cluster the speaker embeddings directly
        # clustering = AgglomerativeClustering(speaker_number).fit(embeddings)
        # labels = clustering.labels_
        for i in range(len(segments)):
            wp_results['segments'][i]["speaker"] = 'SPEAKER ' + str(pre_label[i] + 1)
    # only a single segment: assign it to one speaker
else:
wp_results['segments'][0]["speaker"] = 'SPEAKER 1'
return wp_results
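# Example usage sketch (assumption: a local 16 kHz mono file named "example_call.wav"):
#
#   audio, _ = librosa.load("example_call.wav", sr=16000)
#   results = process_wav(audio, speaker_number=2, model_type="tiny")
#   for seg in results["segments"]:
#       print(seg["speaker"], seg["start"], seg["end"], seg["text"])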
def main():
title_style = """
<style>
.title {
text-align: center;
font-size: 40px;
}
</style>
"""
st.markdown(
title_style,
unsafe_allow_html=True
)
title = """
<h1 class = "title" >Telephone Calls Speaker Diarization</h1>
</div>
"""
st.markdown(title,
unsafe_allow_html=True)
# st.title("Speaker Diarization")
# Get user inputs
file = st.file_uploader("Upload a WAV file:", type=["wav"])
num_speakers = st.number_input("Number of speakers:", min_value=2, max_value=2)
model_list = ['tiny', 'small']
model_type = st.selectbox("Select model type: ", model_list)
    # Preview the uploaded file
    if file is not None:
        st.write("Your uploaded wav file: ")
        st.audio(file, format='audio/wav')
if st.button("Submit"):
if file is not None:
            # Read the audio with librosa at 16 kHz
            audio_file, _ = librosa.load(file, sr=16000)
# Process the uploaded file using the AI model
wp_results = process_wav(audio_file, num_speakers, model_type)
# Write result:
st.write("Segments:" )
for seg in wp_results['segments']:
seg['start'] = np.round(seg['start'], 1)
seg['end'] = np.round(seg['end'], 1)
st.write(seg)
st.write("Language: ", wp_results['language'])
st.write("Full text:")
st.write(wp_results['text'])
        else:
            st.error("Please upload a WAV file before submitting.")
st.write("\n\n---\n\n")
st.write("Built with Docker and Streamlit")
st.link_button("Paper link: https://arxiv.org/abs/2407.01963", "https://arxiv.org/abs/2407.01963")
return
if __name__ == "__main__":
main()