|
import streamlit as st |
|
import numpy as np |
|
import os |
|
import whisper |
|
from sklearn.cluster import AgglomerativeClustering, SpectralClustering |
|
import torch |
|
import librosa |
|
from torch.utils.data import DataLoader |
|
from mix_sae import MoESparseAutoencodersCL |
|
from load_dataset import AutoEncoderDataset |
|
import argparse |
|
|
|
|
|
UPLOAD_FOLDER = "./uploads" |
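
# Hyperparameters for the MoE sparse autoencoder deep clustering model (mix_sae)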
|
parser = argparse.ArgumentParser(description='Deep Clustering Network') |
|
parser.add_argument('--input_dim', type=int, default=384, |
|
help='input dimension') |
|
|
|
parser.add_argument('--lr', type=float, default=1e-3,

                    help='learning rate (default: 1e-3)')
|
parser.add_argument('--wd', type=float, default=1e-4,

                    help='weight decay (default: 1e-4)')
|
parser.add_argument('--batch-size', type=int, default=16, |
|
help='input batch size for training') |
|
parser.add_argument('--lamda', type=float, default=1, |
|
help='coefficient of the reconstruction loss') |
|
parser.add_argument('--beta', type=float, default=1, |
|
help=('coefficient of the regularization term on ' |
|
'clustering')) |
|
parser.add_argument('--hidden-dims', nargs='+', type=int, default=[256, 128, 64, 32],

                    help='hidden layer dimensions of the autoencoder')
|
parser.add_argument('--latent_dim', type=int, default=2, |
|
help='latent space dimension') |
|
parser.add_argument('--n-clusters', type=int, default=2, |
|
help='number of clusters in the latent space') |
|
parser.add_argument('--n-classes', type=int, default=2, |
|
help='output dimension') |
|
parser.add_argument('--pretrain_epochs', type=int, default=50, |
|
help='pretraining step epochs') |
|
parser.add_argument('--pretrain_epochs_main', type=int, default=30, |
|
help='pretraining step epochs') |
|
# note: argparse's type=bool treats any non-empty string (e.g. "False") as True
parser.add_argument('--pretrain', type=bool, default=True,

                    help='whether to use pre-training')
|
parser.add_argument('--main_train_epochs', type=int, default=5, |
|
help='main_train epochs') |
|
parser.add_argument('--rho', type=float, default=0.2,

                    help='rho parameter')
|
parser.add_argument('--sparsity_param', type=float, default=0.2,

                    help='sparsity constraint param')
|
parser.add_argument('--cl_loss_param', type=float, default=0.05,

                    help='classification loss param')
|
args = parser.parse_args() |
|
|
|
|
|
def allowed_file(filename): |
|
return '.' in filename and filename.rsplit('.', 1)[1].lower() in ['wav'] |
|
|
|
|
|
def process_wav(audio_file, speaker_number, model_type, run_device='cpu', sr=16000):
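    """Transcribe a 16 kHz waveform with Whisper, embed each transcribed segment with
    the Whisper encoder, cluster the segment embeddings into speakers with the MoE
    sparse autoencoder (mix_sae), and return the Whisper result dict with a 'speaker'
    field added to every segment. speaker_number is currently unused."""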
|
    embedding_dims = {'tiny': 384, 'base': 512, 'small': 768, 'medium': 1024}
|
|
|
    # load the selected Whisper model and transcribe the full recording
    whisper_model = whisper.load_model(model_type, run_device)

    wp_results = whisper_model.transcribe(audio_file)
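    # drop verbose per-segment fields that are not needed in the UI output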
|
for ide in range(len(wp_results['segments'])): |
|
del wp_results['segments'][ide]['seek'] |
|
del wp_results['segments'][ide]['tokens'] |
|
del wp_results['segments'][ide]['compression_ratio'] |
|
del wp_results['segments'][ide]['temperature'] |
|
del wp_results['segments'][ide]['avg_logprob'] |
|
del wp_results['segments'][ide]['no_speech_prob'] |
|
|
|
|
|
segments = wp_results["segments"] |
|
|
|
|
|
if len(segments) > 1: |
|
embeddings = np.zeros(shape=(len(segments), embedding_dims[model_type])) |
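        # compute one Whisper encoder embedding for every transcribed segment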
|
|
|
for i, segment in enumerate(segments): |
|
start = int(segment["start"] * sr) |
|
end = int(segment["end"] * sr) |
|
|
|
|
|
audio = audio_file[start: end] |
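            # log-Mel spectrogram for the Whisper encoder (80 mel bins)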
|
mel = whisper.log_mel_spectrogram(audio).to(whisper_model.device) |
|
|
|
|
|
            # Whisper's encoder expects exactly 3000 mel frames (30 s of audio);
            # tile short segments by repetition, then crop to 3000 frames
            while True:

                nF, nT = np.shape(mel)

                if nT > 3000:

                    mel = mel[:, 0:3000]

                    break

                else:

                    mel = torch.cat((mel, mel), -1)
|
            # add a batch dimension and run the Whisper encoder
            mel = torch.unsqueeze(mel, 0)

            wp_emb = whisper_model.embed_audio(mel)
|
|
|
|
|
            # average over the batch and time axes to get one fixed-size embedding per segment
            emb_1d = np.mean(wp_emb.cpu().detach().numpy(), axis=0)

            emb_1d = np.mean(emb_1d, axis=0)
|
|
|
|
|
embeddings[i] = emb_1d |
|
        embeddings = np.array(embeddings, dtype="f")

        dataset = AutoEncoderDataset(embeddings)

        train_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)
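
        # cluster the segment embeddings into speakers with the mixture of sparse autoencoders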
|
|
|
|
|
moe_cl = MoESparseAutoencodersCL(args=args) |
|
mixture_moe_cl, full_latent_X = moe_cl.pretraining(train_loader) |
|
pre_label = moe_cl.psedo_label |
|
mixture_moe_cl = moe_cl.main_training(train_loader) |
|
moe_cl_pred = moe_cl.get_final_cluster(train_loader) |
|
|
|
|
|
|
|
|
|
        # label each segment with its final cluster assignment (speaker id)
        for i in range(len(segments)):

            wp_results['segments'][i]["speaker"] = 'SPEAKER ' + str(int(moe_cl_pred[i]) + 1)
|
|
|
|
|
    elif len(segments) == 1:

        # a single segment can only be attributed to one speaker
        wp_results['segments'][0]["speaker"] = 'SPEAKER 1'
|
|
|
return wp_results |
|
|
|
|
|
def main(): |
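    """Streamlit UI: upload a WAV file, choose a Whisper model size, and display
    the diarized transcript with per-segment speaker labels."""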
|
|
|
title_style = """ |
|
<style> |
|
.title { |
|
text-align: center; |
|
font-size: 40px; |
|
} |
|
</style> |
|
""" |
|
st.markdown( |
|
title_style, |
|
unsafe_allow_html=True |
|
) |
|
title = """ |
|
<h1 class = "title" >Telephone Calls Speaker Diarization</h1> |
|
</div> |
|
""" |
|
st.markdown(title, |
|
unsafe_allow_html=True) |
|
|
|
|
|
|
|
|
|
file = st.file_uploader("Upload a WAV file:", type=["wav"]) |
|
num_speakers = st.number_input("Number of speakers:", min_value=2, max_value=2) |
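    # the clustering pipeline currently assumes exactly two speakers (telephone calls)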
|
|
|
model_list = ['tiny', 'small'] |
|
model_type = st.selectbox("Select model type: ", model_list) |
|
|
|
|
|
st.write("Your uploaded wav file: ") |
|
st.audio(file, format = 'audio/wav') |
|
if st.button("Submit"): |
|
if file is not None: |
|
|
|
|
|
            # decode the upload as mono 16 kHz audio, the sample rate Whisper expects
            audio_file, _ = librosa.load(file, sr=16000)
|
|
|
|
|
wp_results = process_wav(audio_file, num_speakers, model_type) |
|
|
|
|
|
st.write("Segments:" ) |
|
for seg in wp_results['segments']: |
|
seg['start'] = np.round(seg['start'], 1) |
|
seg['end'] = np.round(seg['end'], 1) |
|
st.write(seg) |
|
st.write("Language: ", wp_results['language']) |
|
st.write("Full text:") |
|
st.write(wp_results['text']) |
|
        else:

            st.error("Please upload a WAV file before submitting.")
|
st.write("\n\n---\n\n") |
|
st.write("Built with Docker and Streamlit") |
|
st.link_button("Paper link: https://arxiv.org/abs/2407.01963", "https://arxiv.org/abs/2407.01963") |
|
return |
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
main() |
|
|