vumichien committed
Commit 19f7e21
1 Parent(s): c422373

Update app.py

Files changed (1)
  1. app.py +27 -9
app.py CHANGED
@@ -9,6 +9,7 @@ import time
 import os
 import numpy as np
 from sklearn.cluster import AgglomerativeClustering
+from sklearn.metrics import silhouette_score
 
 from pytube import YouTube
 import torch
@@ -191,7 +192,7 @@ def get_youtube(video_url):
     print(abs_video_path)
     return abs_video_path
 
-def speech_to_text(video_file_path, selected_source_lang, whisper_model, num_speakers):
+def speech_to_text(video_file_path, selected_source_lang, whisper_model, min_num_speakers, max_number_speakers):
     """
     # Transcribe youtube link using OpenAI Whisper
     1. Using Open AI's Whisper model to seperate audio into segments and generate transcripts.
@@ -249,8 +250,21 @@ def speech_to_text(video_file_path, selected_source_lang, whisper_model, num_speakers):
         embeddings = np.nan_to_num(embeddings)
         print(f'Embedding shape: {embeddings.shape}')
 
-        # Assign speaker label
-        clustering = AgglomerativeClustering(num_speakers).fit(embeddings)
+        # Find the best number of speakers
+        min_speakers, max_speakers = min_num_speakers, max_number_speakers
+        if min_num_speakers > max_number_speakers:
+            min_speakers, max_speakers = max_number_speakers, min_num_speakers
+        score_num_speakers = {}
+
+        for num_speakers in range(min_speakers, max_speakers + 1):
+            clustering = AgglomerativeClustering(num_speakers).fit(embeddings)
+            score = silhouette_score(embeddings, clustering.labels_, metric='euclidean')
+            score_num_speakers[num_speakers] = score
+        best_num_speaker = max(score_num_speakers, key=lambda x: score_num_speakers[x])
+        print(f"The best number of speakers: {best_num_speaker} with {score_num_speakers[best_num_speaker]} score")
+
+        # Assign speaker label
+        clustering = AgglomerativeClustering(best_num_speaker).fit(embeddings)
         labels = clustering.labels_
         for i in range(len(segments)):
             segments[i]["speaker"] = 'SPEAKER ' + str(labels[i] + 1)
@@ -289,7 +303,7 @@ def speech_to_text(video_file_path, selected_source_lang, whisper_model, num_speakers):
         save_path = "output/transcript_result.csv"
         df_results = pd.DataFrame(objects)
         df_results.to_csv(save_path)
-        return df_results, system_info, save_path
+        return df_results, system_info, save_path
 
     except Exception as e:
         raise RuntimeError("Error Running inference with local model", e)
@@ -303,7 +317,8 @@ df_init = pd.DataFrame(columns=['Start', 'End', 'Speaker', 'Text'])
 memory = psutil.virtual_memory()
 selected_source_lang = gr.Dropdown(choices=source_language_list, type="value", value="en", label="Spoken language in video", interactive=True)
 selected_whisper_model = gr.Dropdown(choices=whisper_models, type="value", value="base", label="Selected Whisper model", interactive=True)
-number_speakers = gr.Number(precision=0, value=2, label="Selected number of speakers", interactive=True)
+input_min_number_speakers = gr.Number(precision=0, value=2, label="Select assumed minimum number of speakers", interactive=True)
+input_max_number_speakers = gr.Number(precision=0, value=2, label="Select assumed maximum number of speakers", interactive=True)
 system_info = gr.Markdown(f"*Memory: {memory.total / (1024 * 1024 * 1024):.2f}GB, used: {memory.percent}%, available: {memory.available / (1024 * 1024 * 1024):.2f}GB*")
 download_transcript = gr.File(label="Download transcript")
 transcription_df = gr.DataFrame(value=df_init,label="Transcription dataframe", row_count=(0, "dynamic"), max_rows = 10, wrap=True, overflow_row_behaviour='paginate')
@@ -356,14 +371,17 @@ with demo:
             gr.Markdown('''
                 ##### Here you can start the transcription process.
                 ##### Please select the source language for transcription.
-                ##### You should select a number of speakers for getting better results.
+                ##### You can select a range of assumed numbers of speakers.
                 ''')
             selected_source_lang.render()
             selected_whisper_model.render()
-            number_speakers.render()
+            input_min_number_speakers.render()
+            input_max_number_speakers.render()
             transcribe_btn = gr.Button("Transcribe audio and diarization")
-            transcribe_btn.click(speech_to_text, [video_in, selected_source_lang, selected_whisper_model, number_speakers], [transcription_df, system_info, download_transcript])
-
+            transcribe_btn.click(speech_to_text,
+                                 [video_in, selected_source_lang, selected_whisper_model, input_min_number_speakers, input_max_number_speakers],
+                                 [transcription_df, system_info, download_transcript]
+                                 )
 
         with gr.Row():
             gr.Markdown('''
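
A minimal, self-contained sketch of the selection criterion this commit introduces: cluster the segment embeddings with AgglomerativeClustering for each candidate speaker count in an assumed min/max range, and keep the count with the highest silhouette score. The synthetic embeddings, dimensions, and range below are illustrative placeholders, not values taken from app.py.

# Illustrative sketch only: synthetic stand-ins for the speaker embeddings
# built earlier in speech_to_text; names and sizes here are hypothetical.
import numpy as np
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics import silhouette_score

rng = np.random.default_rng(0)
# Three well-separated "speakers", 40 segment embeddings each, 192 dims.
centers = 5.0 * rng.normal(size=(3, 192))
embeddings = np.vstack([c + 0.5 * rng.normal(size=(40, 192)) for c in centers])

min_speakers, max_speakers = 2, 5  # assumed user-supplied range
score_num_speakers = {}
for num_speakers in range(min_speakers, max_speakers + 1):
    labels = AgglomerativeClustering(num_speakers).fit(embeddings).labels_
    score_num_speakers[num_speakers] = silhouette_score(embeddings, labels, metric='euclidean')

best_num_speakers = max(score_num_speakers, key=score_num_speakers.get)
print(best_num_speakers)  # expected to recover 3 for this synthetic data

Note that silhouette_score requires at least two distinct labels and fewer clusters than samples, so keeping the assumed minimum at 2 (as the UI defaults suggest) keeps the loop valid.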
 
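A stripped-down sketch of the new UI wiring: two gr.Number inputs for the assumed minimum and maximum speaker counts, passed to the handler through Button.click. The handler, labels, and variable names below are placeholders, not the app's actual speech_to_text.

# Hypothetical stand-alone demo; only the wiring pattern mirrors the commit.
import gradio as gr

def fake_transcribe(min_speakers, max_speakers):
    # Mirror the guard in speech_to_text: swap if the range is reversed.
    if min_speakers > max_speakers:
        min_speakers, max_speakers = max_speakers, min_speakers
    return f"Would search for {min_speakers} to {max_speakers} speakers"

with gr.Blocks() as demo:
    input_min = gr.Number(precision=0, value=2, label="Assumed minimum number of speakers")
    input_max = gr.Number(precision=0, value=2, label="Assumed maximum number of speakers")
    result = gr.Textbox(label="Result")
    transcribe_btn = gr.Button("Transcribe audio and diarization")
    transcribe_btn.click(fake_transcribe, [input_min, input_max], [result])

# demo.launch()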