Spaces:
Running
Running
jhj0517
commited on
Commit
•
e284444
1
Parent(s):
8a43431
Add gradio parameter `file_format` to cache
Browse files
modules/whisper/base_transcription_pipeline.py
CHANGED
@@ -71,6 +71,7 @@ class BaseTranscriptionPipeline(ABC):
|
|
71 |
def run(self,
|
72 |
audio: Union[str, BinaryIO, np.ndarray],
|
73 |
progress: gr.Progress = gr.Progress(),
|
|
|
74 |
add_timestamp: bool = True,
|
75 |
*pipeline_params,
|
76 |
) -> Tuple[List[Segment], float]:
|
@@ -86,6 +87,8 @@ class BaseTranscriptionPipeline(ABC):
|
|
86 |
Audio input. This can be file path or binary type.
|
87 |
progress: gr.Progress
|
88 |
Indicator to show progress directly in gradio.
|
|
|
|
|
89 |
add_timestamp: bool
|
90 |
Whether to add a timestamp at the end of the filename.
|
91 |
*pipeline_params: tuple
|
@@ -168,6 +171,7 @@ class BaseTranscriptionPipeline(ABC):
|
|
168 |
|
169 |
self.cache_parameters(
|
170 |
params=params,
|
|
|
171 |
add_timestamp=add_timestamp
|
172 |
)
|
173 |
return result, elapsed_time
|
@@ -224,6 +228,7 @@ class BaseTranscriptionPipeline(ABC):
|
|
224 |
transcribed_segments, time_for_task = self.run(
|
225 |
file,
|
226 |
progress,
|
|
|
227 |
add_timestamp,
|
228 |
*pipeline_params,
|
229 |
)
|
@@ -298,6 +303,7 @@ class BaseTranscriptionPipeline(ABC):
|
|
298 |
transcribed_segments, time_for_task = self.run(
|
299 |
mic_audio,
|
300 |
progress,
|
|
|
301 |
add_timestamp,
|
302 |
*pipeline_params,
|
303 |
)
|
@@ -364,6 +370,7 @@ class BaseTranscriptionPipeline(ABC):
|
|
364 |
transcribed_segments, time_for_task = self.run(
|
365 |
audio,
|
366 |
progress,
|
|
|
367 |
add_timestamp,
|
368 |
*pipeline_params,
|
369 |
)
|
@@ -513,7 +520,8 @@ class BaseTranscriptionPipeline(ABC):
|
|
513 |
@staticmethod
|
514 |
def cache_parameters(
|
515 |
params: TranscriptionPipelineParams,
|
516 |
-
|
|
|
517 |
):
|
518 |
"""Cache parameters to the yaml file"""
|
519 |
cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
|
@@ -521,6 +529,7 @@ class BaseTranscriptionPipeline(ABC):
|
|
521 |
|
522 |
cached_yaml = {**cached_params, **param_to_cache}
|
523 |
cached_yaml["whisper"]["add_timestamp"] = add_timestamp
|
|
|
524 |
|
525 |
supress_token = cached_yaml["whisper"].get("suppress_tokens", None)
|
526 |
if supress_token and isinstance(supress_token, list):
|
|
|
71 |
def run(self,
|
72 |
audio: Union[str, BinaryIO, np.ndarray],
|
73 |
progress: gr.Progress = gr.Progress(),
|
74 |
+
file_format: str = "SRT",
|
75 |
add_timestamp: bool = True,
|
76 |
*pipeline_params,
|
77 |
) -> Tuple[List[Segment], float]:
|
|
|
87 |
Audio input. This can be file path or binary type.
|
88 |
progress: gr.Progress
|
89 |
Indicator to show progress directly in gradio.
|
90 |
+
file_format: str
|
91 |
+
Subtitle file format between ["SRT", "WebVTT", "txt", "lrc"]
|
92 |
add_timestamp: bool
|
93 |
Whether to add a timestamp at the end of the filename.
|
94 |
*pipeline_params: tuple
|
|
|
171 |
|
172 |
self.cache_parameters(
|
173 |
params=params,
|
174 |
+
file_format=file_format,
|
175 |
add_timestamp=add_timestamp
|
176 |
)
|
177 |
return result, elapsed_time
|
|
|
228 |
transcribed_segments, time_for_task = self.run(
|
229 |
file,
|
230 |
progress,
|
231 |
+
file_format,
|
232 |
add_timestamp,
|
233 |
*pipeline_params,
|
234 |
)
|
|
|
303 |
transcribed_segments, time_for_task = self.run(
|
304 |
mic_audio,
|
305 |
progress,
|
306 |
+
file_format,
|
307 |
add_timestamp,
|
308 |
*pipeline_params,
|
309 |
)
|
|
|
370 |
transcribed_segments, time_for_task = self.run(
|
371 |
audio,
|
372 |
progress,
|
373 |
+
file_format,
|
374 |
add_timestamp,
|
375 |
*pipeline_params,
|
376 |
)
|
|
|
520 |
@staticmethod
|
521 |
def cache_parameters(
|
522 |
params: TranscriptionPipelineParams,
|
523 |
+
file_format: str = "SRT",
|
524 |
+
add_timestamp: bool = True
|
525 |
):
|
526 |
"""Cache parameters to the yaml file"""
|
527 |
cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
|
|
|
529 |
|
530 |
cached_yaml = {**cached_params, **param_to_cache}
|
531 |
cached_yaml["whisper"]["add_timestamp"] = add_timestamp
|
532 |
+
cached_yaml["whisper"]["file_format"] = file_format
|
533 |
|
534 |
supress_token = cached_yaml["whisper"].get("suppress_tokens", None)
|
535 |
if supress_token and isinstance(supress_token, list):
|