jhj0517 committed on
Commit
e5ef0df
·
1 Parent(s): 19c3dbd

Add `to_yaml()`

Browse files
modules/whisper/whisper_parameter.py CHANGED
@@ -1,6 +1,7 @@
1
  from dataclasses import dataclass, fields
2
  import gradio as gr
3
  from typing import Optional
 
4
 
5
 
6
  @dataclass
@@ -274,4 +275,54 @@ class WhisperValues:
274
  language_detection_segments: int
275
  """
276
  A data class to use Whisper parameters.
277
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from dataclasses import dataclass, fields
2
  import gradio as gr
3
  from typing import Optional
4
+ import yaml
5
 
6
 
7
  @dataclass
 
275
  language_detection_segments: int
276
  """
277
  A data class to use Whisper parameters.
278
+ """
279
+
280
def to_yaml(self) -> str:
    """
    Serialize the parameter values held by this instance to a YAML string.

    The values are grouped under three top-level keys: "whisper", "vad"
    and "diarization". Inside each group, every key has the same name as
    the attribute its value is read from.

    Note that "hf_token" is written into the "diarization" group as-is.

    Returns
    ----------
    str
        The text returned by ``yaml.dump`` with ``default_flow_style=False``.
    """
    # Attribute names per section; each name doubles as the YAML key.
    whisper_keys = (
        "model_size",
        "lang",
        "is_translate",
        "beam_size",
        "log_prob_threshold",
        "no_speech_threshold",
        "best_of",
        "patience",
        "condition_on_previous_text",
        "prompt_reset_on_temperature",
        "initial_prompt",
        "temperature",
        "compression_ratio_threshold",
        "chunk_length_s",
        "batch_size",
        "length_penalty",
        "repetition_penalty",
        "no_repeat_ngram_size",
        "prefix",
        "suppress_blank",
        "suppress_tokens",
        "max_initial_timestamp",
        "word_timestamps",
        "prepend_punctuations",
        "append_punctuations",
        "max_new_tokens",
        "chunk_length",
        "hallucination_silence_threshold",
        "hotwords",
        "language_detection_threshold",
        "language_detection_segments",
    )
    vad_keys = (
        "vad_filter",
        "threshold",
        "min_speech_duration_ms",
        "max_speech_duration_s",
        "min_silence_duration_ms",
        "speech_pad_ms",
    )
    diarization_keys = (
        "is_diarize",
        "hf_token",
    )

    sections = {
        "whisper": whisper_keys,
        "vad": vad_keys,
        "diarization": diarization_keys,
    }
    # Build the nested mapping section by section, reading each value
    # from the attribute of the same name.
    payload = {
        section: {key: getattr(self, key) for key in keys}
        for section, keys in sections.items()
    }
    return yaml.dump(payload, default_flow_style=False)