jhj0517 committed
Commit ddbe0b6 · 1 Parent(s): 78d8e18

Apply Segment model to the pipeline

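The diffs below construct `Segment(start=..., end=..., text=...)` objects and call `.dict()` on them, which suggests a pydantic-style model in `modules/whisper/data_classes`. That class is not part of this commit; the following is only a minimal sketch consistent with how it is used here.

```python
# Hypothetical sketch of the Segment model assumed by this commit. The real
# class lives in modules/whisper/data_classes and may define more fields;
# only start/end/text and a pydantic-style .dict() are relied on below.
from typing import Optional

from pydantic import BaseModel


class Segment(BaseModel):
    start: Optional[float] = None  # segment start time in seconds
    end: Optional[float] = None    # segment end time in seconds
    text: Optional[str] = None     # transcribed text of the segment
```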
modules/diarize/diarize_pipeline.py CHANGED
@@ -44,6 +44,7 @@ class DiarizationPipeline:
 def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
     transcript_segments = transcript_result["segments"]
     for seg in transcript_segments:
+        seg = seg.dict()
         # assign speaker to segment (if any)
         diarize_df['intersection'] = np.minimum(diarize_df['end'], seg['end']) - np.maximum(diarize_df['start'],
                                                                                             seg['start'])
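For context on the new `seg = seg.dict()` line: once the pipeline passes pydantic-style `Segment` objects in, converting each one back to a plain dict keeps the existing pandas/numpy key lookups working. A small self-contained illustration (the dataframe values are made up):

```python
import numpy as np
import pandas as pd

# Made-up diarization turns, one row per speaker turn.
diarize_df = pd.DataFrame({
    "start": [0.0, 5.0],
    "end": [4.0, 9.0],
    "speaker": ["SPEAKER_00", "SPEAKER_01"],
})

# What seg.dict() would yield for one transcribed segment.
seg = {"start": 3.0, "end": 8.0, "text": "hello there"}

# Overlap in seconds between the transcribed segment and each speaker turn;
# the turn with the largest overlap wins the speaker assignment.
diarize_df["intersection"] = np.minimum(diarize_df["end"], seg["end"]) - np.maximum(
    diarize_df["start"], seg["start"])
print(diarize_df.sort_values("intersection", ascending=False).iloc[0]["speaker"])
# -> SPEAKER_01
```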
modules/diarize/diarizer.py CHANGED
@@ -1,6 +1,6 @@
 import os
 import torch
-from typing import List, Union, BinaryIO, Optional
+from typing import List, Union, BinaryIO, Optional, Tuple
 import numpy as np
 import time
 import logging
@@ -8,6 +8,7 @@ import logging
 from modules.utils.paths import DIARIZATION_MODELS_DIR
 from modules.diarize.diarize_pipeline import DiarizationPipeline, assign_word_speakers
 from modules.diarize.audio_loader import load_audio
+from modules.whisper.data_classes import *
 
 
 class Diarizer:
@@ -23,10 +24,10 @@ class Diarizer:
 
     def run(self,
             audio: Union[str, BinaryIO, np.ndarray],
-            transcribed_result: List[dict],
+            transcribed_result: List[Segment],
             use_auth_token: str,
             device: Optional[str] = None
-            ):
+            ) -> Tuple[List[Segment], float]:
         """
         Diarize transcribed result as a post-processing
 
@@ -34,7 +35,7 @@ class Diarizer:
         ----------
         audio: Union[str, BinaryIO, np.ndarray]
             Audio input. This can be file path or binary type.
-        transcribed_result: List[dict]
+        transcribed_result: List[Segment]
             transcribed result through whisper.
         use_auth_token: str
             Huggingface token with READ permission. This is only needed the first time you download the model.
@@ -44,8 +45,8 @@ class Diarizer:
 
         Returns
         ----------
-        segments_result: List[dict]
-            list of dicts that includes start, end timestamps and transcribed text
+        segments_result: List[Segment]
+            list of Segment that includes start, end timestamps and transcribed text
         elapsed_time: float
             elapsed time for running
         """
@@ -68,14 +69,21 @@ class Diarizer:
             {"segments": transcribed_result}
         )
 
+        segments_result = []
         for segment in diarized_result["segments"]:
+            segment = segment.dict()
             speaker = "None"
             if "speaker" in segment:
                 speaker = segment["speaker"]
-            segment["text"] = speaker + "|" + segment["text"].strip()
+            diarized_text = speaker + "|" + segment["text"].strip()
+            segments_result.append(Segment(
+                start=segment["start"],
+                end=segment["end"],
+                text=diarized_text
+            ))
 
         elapsed_time = time.time() - start_time
-        return diarized_result["segments"], elapsed_time
+        return segments_result, elapsed_time
 
     def update_pipe(self,
                     use_auth_token: str,
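`Diarizer.run()` now returns `Segment` objects whose `text` carries the speaker label joined to the transcription with a `|` separator (for example `"SPEAKER_00|hello there"`). A hedged sketch of how a caller might split that convention back apart; the helper name is made up:

```python
def split_speaker(segment_text: str) -> tuple[str, str]:
    """Split "SPEAKER_XX|text" back into (speaker, text).

    Falls back to ("None", text) when no separator is present, mirroring the
    "None" placeholder the diarizer uses for segments without a speaker.
    """
    speaker, sep, text = segment_text.partition("|")
    return (speaker, text) if sep else ("None", segment_text)


print(split_speaker("SPEAKER_00|hello there"))  # ('SPEAKER_00', 'hello there')
print(split_speaker("no separator here"))       # ('None', 'no separator here')
```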
modules/utils/subtitle_manager.py CHANGED
@@ -1,5 +1,7 @@
 import re
 
+from modules.whisper.data_classes import Segment
+
 
 def timeformat_srt(time):
     hours = time // 3600
@@ -23,6 +25,9 @@ def write_file(subtitle, output_file):
 
 
 def get_srt(segments):
+    if segments and isinstance(segments[0], Segment):
+        segments = [seg.dict() for seg in segments]
+
     output = ""
     for i, segment in enumerate(segments):
         output += f"{i + 1}\n"
@@ -34,6 +39,9 @@ def get_srt(segments):
 
 
 def get_vtt(segments):
+    if segments and isinstance(segments[0], Segment):
+        segments = [seg.dict() for seg in segments]
+
     output = "WEBVTT\n\n"
     for i, segment in enumerate(segments):
         output += f"{timeformat_vtt(segment['start'])} --> {timeformat_vtt(segment['end'])}\n"
@@ -44,6 +52,9 @@ def get_vtt(segments):
 
 
 def get_txt(segments):
+    if segments and isinstance(segments[0], Segment):
+        segments = [seg.dict() for seg in segments]
+
     output = ""
     for i, segment in enumerate(segments):
         if segment['text'].startswith(' '):
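The three subtitle writers now share the same guard: if the first element is a `Segment`, the whole list is converted to plain dicts before formatting. A minimal standalone sketch of that idiom; the `Segment` definition here is an illustrative stand-in, not the repo's class:

```python
from pydantic import BaseModel


class Segment(BaseModel):  # stand-in for modules.whisper.data_classes.Segment
    start: float
    end: float
    text: str


def normalize(segments):
    # Accept either Segment objects or plain dicts; always hand dicts onward.
    if segments and isinstance(segments[0], Segment):
        segments = [seg.dict() for seg in segments]
    return segments


print(normalize([Segment(start=0.0, end=1.5, text="hi")]))
print(normalize([{"start": 0.0, "end": 1.5, "text": "hi"}]))
```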
modules/vad/silero_vad.py CHANGED
@@ -5,7 +5,8 @@ import numpy as np
 from typing import BinaryIO, Union, List, Optional, Tuple
 import warnings
 import faster_whisper
-from faster_whisper.transcribe import SpeechTimestampsMap, Segment
+from modules.whisper.data_classes import *
+from faster_whisper.transcribe import SpeechTimestampsMap
 import gradio as gr
 
 
@@ -247,18 +248,18 @@ class SileroVAD:
 
     def restore_speech_timestamps(
         self,
-        segments: List[dict],
+        segments: List[Segment],
         speech_chunks: List[dict],
         sampling_rate: Optional[int] = None,
-    ) -> List[dict]:
+    ) -> List[Segment]:
         if sampling_rate is None:
             sampling_rate = self.sampling_rate
 
         ts_map = SpeechTimestampsMap(speech_chunks, sampling_rate)
 
         for segment in segments:
-            segment["start"] = ts_map.get_original_time(segment["start"])
-            segment["end"] = ts_map.get_original_time(segment["end"])
+            segment.start = ts_map.get_original_time(segment.start)
+            segment.end = ts_map.get_original_time(segment.end)
 
         return segments
 
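Because the segments are now pydantic-style `Segment` objects rather than dicts, the restored timestamps are written through attribute access. A condensed, hedged sketch of the same logic outside the class, using only the `SpeechTimestampsMap` calls the diff itself relies on; the default sampling rate is illustrative:

```python
from faster_whisper.transcribe import SpeechTimestampsMap


def restore(segments, speech_chunks, sampling_rate=16000):
    # Map times measured on the VAD-trimmed audio back onto the original audio.
    ts_map = SpeechTimestampsMap(speech_chunks, sampling_rate)
    for segment in segments:
        segment.start = ts_map.get_original_time(segment.start)
        segment.end = ts_map.get_original_time(segment.end)
    return segments
```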
modules/whisper/faster_whisper_inference.py CHANGED
@@ -40,7 +40,7 @@ class FasterWhisperInference(BaseTranscriptionPipeline):
                    audio: Union[str, BinaryIO, np.ndarray],
                    progress: gr.Progress = gr.Progress(),
                    *whisper_params,
-                   ) -> Tuple[List[dict], float]:
+                   ) -> Tuple[List[Segment], float]:
         """
         transcribe method for faster-whisper.
 
@@ -55,8 +55,8 @@ class FasterWhisperInference(BaseTranscriptionPipeline):
 
         Returns
         ----------
-        segments_result: List[dict]
-            list of dicts that includes start, end timestamps and transcribed text
+        segments_result: List[Segment]
+            list of Segment that includes start, end timestamps and transcribed text
         elapsed_time: float
             elapsed time for transcription
         """
@@ -102,11 +102,11 @@ class FasterWhisperInference(BaseTranscriptionPipeline):
         segments_result = []
         for segment in segments:
             progress(segment.start / info.duration, desc="Transcribing..")
-            segments_result.append({
-                "start": segment.start,
-                "end": segment.end,
-                "text": segment.text
-            })
+            segments_result.append(Segment(
+                start=segment.start,
+                end=segment.end,
+                text=segment.text
+            ))
 
         elapsed_time = time.time() - start_time
         return segments_result, elapsed_time
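faster-whisper's `transcribe()` returns a lazy generator of segments plus an info object, so the loop above both drives the transcription and converts each yielded segment into a `Segment`. A hedged standalone sketch of that pattern; the model size, device settings and audio path are placeholders:

```python
from faster_whisper import WhisperModel

model = WhisperModel("tiny", device="cpu", compute_type="int8")  # placeholder settings
segments, info = model.transcribe("audio.wav")                   # placeholder path

collected = []
for seg in segments:  # transcription actually happens while iterating
    print(f"{seg.start / info.duration:.0%} done")
    collected.append({"start": seg.start, "end": seg.end, "text": seg.text})
```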
modules/whisper/insanely_fast_whisper_inference.py CHANGED
@@ -40,7 +40,7 @@ class InsanelyFastWhisperInference(BaseTranscriptionPipeline):
                    audio: Union[str, np.ndarray, torch.Tensor],
                    progress: gr.Progress = gr.Progress(),
                    *whisper_params,
-                   ) -> Tuple[List[dict], float]:
+                   ) -> Tuple[List[Segment], float]:
         """
         transcribe method for faster-whisper.
 
@@ -55,8 +55,8 @@ class InsanelyFastWhisperInference(BaseTranscriptionPipeline):
 
         Returns
         ----------
-        segments_result: List[dict]
-            list of dicts that includes start, end timestamps and transcribed text
+        segments_result: List[Segment]
+            list of Segment that includes start, end timestamps and transcribed text
         elapsed_time: float
             elapsed time for transcription
         """
@@ -95,9 +95,17 @@ class InsanelyFastWhisperInference(BaseTranscriptionPipeline):
             generate_kwargs=kwargs
         )
 
-        segments_result = self.format_result(
-            transcribed_result=segments,
-        )
+        segments_result = []
+        for item in segments["chunks"]:
+            start, end = item["timestamp"][0], item["timestamp"][1]
+            if end is None:
+                end = start
+            segments_result.append(Segment(
+                text=item["text"],
+                start=start,
+                end=end
+            ))
+
         elapsed_time = time.time() - start_time
         return segments_result, elapsed_time
 
modules/whisper/whisper_Inference.py CHANGED
@@ -30,7 +30,7 @@ class WhisperInference(BaseTranscriptionPipeline):
                    audio: Union[str, np.ndarray, torch.Tensor],
                    progress: gr.Progress = gr.Progress(),
                    *whisper_params,
-                   ) -> Tuple[List[dict], float]:
+                   ) -> Tuple[List[Segment], float]:
         """
         transcribe method for faster-whisper.
 
@@ -45,8 +45,8 @@ class WhisperInference(BaseTranscriptionPipeline):
 
         Returns
         ----------
-        segments_result: List[dict]
-            list of dicts that includes start, end timestamps and transcribed text
+        segments_result: List[Segment]
+            list of Segment that includes start, end timestamps and transcribed text
         elapsed_time: float
             elapsed time for transcription
         """
@@ -59,21 +59,28 @@ class WhisperInference(BaseTranscriptionPipeline):
         def progress_callback(progress_value):
             progress(progress_value, desc="Transcribing..")
 
-        segments_result = self.model.transcribe(audio=audio,
-                                                language=params.lang,
-                                                verbose=False,
-                                                beam_size=params.beam_size,
-                                                logprob_threshold=params.log_prob_threshold,
-                                                no_speech_threshold=params.no_speech_threshold,
-                                                task="translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe",
-                                                fp16=True if params.compute_type == "float16" else False,
-                                                best_of=params.best_of,
-                                                patience=params.patience,
-                                                temperature=params.temperature,
-                                                compression_ratio_threshold=params.compression_ratio_threshold,
-                                                progress_callback=progress_callback,)["segments"]
-        elapsed_time = time.time() - start_time
+        result = self.model.transcribe(audio=audio,
+                                       language=params.lang,
+                                       verbose=False,
+                                       beam_size=params.beam_size,
+                                       logprob_threshold=params.log_prob_threshold,
+                                       no_speech_threshold=params.no_speech_threshold,
+                                       task="translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe",
+                                       fp16=True if params.compute_type == "float16" else False,
+                                       best_of=params.best_of,
+                                       patience=params.patience,
+                                       temperature=params.temperature,
+                                       compression_ratio_threshold=params.compression_ratio_threshold,
+                                       progress_callback=progress_callback,)["segments"]
+        segments_result = []
+        for segment in result:
+            segments_result.append(Segment(
+                start=segment["start"],
+                end=segment["end"],
+                text=segment["text"]
+            ))
 
+        elapsed_time = time.time() - start_time
         return segments_result, elapsed_time
 
     def update_model(self,
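For reference, openai-whisper's `transcribe()` returns a plain dict whose `"segments"` entries are dicts, which is why the loop above indexes by key before wrapping each entry in `Segment`. A hedged sketch of that result shape; the model size and audio path are placeholders:

```python
import whisper  # openai-whisper

model = whisper.load_model("tiny")       # placeholder model size
result = model.transcribe("audio.wav")   # placeholder audio path
for seg in result["segments"]:
    print(seg["start"], seg["end"], seg["text"])
```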