jhj0517 committed on
Commit e284444
1 Parent(s): 8a43431

Add gradio parameter `file_format` to cache

modules/whisper/base_transcription_pipeline.py CHANGED
@@ -71,6 +71,7 @@ class BaseTranscriptionPipeline(ABC):
     def run(self,
             audio: Union[str, BinaryIO, np.ndarray],
             progress: gr.Progress = gr.Progress(),
+            file_format: str = "SRT",
             add_timestamp: bool = True,
             *pipeline_params,
             ) -> Tuple[List[Segment], float]:
@@ -86,6 +87,8 @@ class BaseTranscriptionPipeline(ABC):
             Audio input. This can be file path or binary type.
         progress: gr.Progress
             Indicator to show progress directly in gradio.
+        file_format: str
+            Subtitle file format between ["SRT", "WebVTT", "txt", "lrc"]
         add_timestamp: bool
             Whether to add a timestamp at the end of the filename.
         *pipeline_params: tuple
@@ -168,6 +171,7 @@ class BaseTranscriptionPipeline(ABC):
 
         self.cache_parameters(
             params=params,
+            file_format=file_format,
             add_timestamp=add_timestamp
         )
         return result, elapsed_time
@@ -224,6 +228,7 @@ class BaseTranscriptionPipeline(ABC):
             transcribed_segments, time_for_task = self.run(
                 file,
                 progress,
+                file_format,
                 add_timestamp,
                 *pipeline_params,
             )
@@ -298,6 +303,7 @@ class BaseTranscriptionPipeline(ABC):
             transcribed_segments, time_for_task = self.run(
                 mic_audio,
                 progress,
+                file_format,
                 add_timestamp,
                 *pipeline_params,
             )
@@ -364,6 +370,7 @@ class BaseTranscriptionPipeline(ABC):
             transcribed_segments, time_for_task = self.run(
                 audio,
                 progress,
+                file_format,
                 add_timestamp,
                 *pipeline_params,
             )
@@ -513,7 +520,8 @@ class BaseTranscriptionPipeline(ABC):
     @staticmethod
     def cache_parameters(
             params: TranscriptionPipelineParams,
-            add_timestamp: bool
+            file_format: str = "SRT",
+            add_timestamp: bool = True
     ):
         """Cache parameters to the yaml file"""
         cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
@@ -521,6 +529,7 @@ class BaseTranscriptionPipeline(ABC):
 
         cached_yaml = {**cached_params, **param_to_cache}
         cached_yaml["whisper"]["add_timestamp"] = add_timestamp
+        cached_yaml["whisper"]["file_format"] = file_format
 
         supress_token = cached_yaml["whisper"].get("suppress_tokens", None)
         if supress_token and isinstance(supress_token, list):
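
For context on what the new value ends up doing, here is a minimal standalone sketch of the caching step this commit extends: merge the current parameters into the cached YAML and persist `file_format` next to `add_timestamp` under the `whisper` section. The config path, the plain-dict `params`, and the direct PyYAML calls below are illustrative assumptions; the repository's own helpers (`load_yaml`, `DEFAULT_PARAMETERS_CONFIG_PATH`, `TranscriptionPipelineParams`) are not reproduced here.

# Standalone sketch (assumptions: PyYAML is available, the config path is
# hypothetical, and `params` is a plain dict rather than the repo's
# TranscriptionPipelineParams object).
import yaml
from pathlib import Path

CONFIG_PATH = Path("configs/default_parameters.yaml")  # hypothetical path

def cache_parameters(params: dict, file_format: str = "SRT", add_timestamp: bool = True) -> None:
    """Merge `params` into the cached YAML and persist the gradio values."""
    cached = yaml.safe_load(CONFIG_PATH.read_text()) if CONFIG_PATH.exists() else None
    cached_yaml = {**(cached or {}), **params}
    cached_yaml.setdefault("whisper", {})
    cached_yaml["whisper"]["add_timestamp"] = add_timestamp
    cached_yaml["whisper"]["file_format"] = file_format  # the value added by this commit
    CONFIG_PATH.parent.mkdir(parents=True, exist_ok=True)
    CONFIG_PATH.write_text(yaml.safe_dump(cached_yaml))

if __name__ == "__main__":
    cache_parameters({"whisper": {"model_size": "large-v2"}}, file_format="WebVTT")

Note that `file_format` defaults to "SRT" in both `run()` and `cache_parameters()`, so existing callers that do not pass the new argument keep their previous behaviour.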