aadnk commited on
Commit
af09fe7
·
1 Parent(s): f641aea

Set default initial prompt mode

Browse files
Files changed (1) hide show
  1. app.py +10 -4
app.py CHANGED
@@ -273,15 +273,21 @@ class WhisperTranscriber:
273
  if ('task' in decodeOptions):
274
  task = decodeOptions.pop('task')
275
 
276
- if (vadOptions.vadInitialPromptMode == VadInitialPromptMode.PREPEND_ALL_SEGMENTS or
277
- vadOptions.vadInitialPromptMode == VadInitialPromptMode.PREPREND_FIRST_SEGMENT):
 
 
 
 
 
 
278
  # Prepend initial prompt
279
- prompt_strategy = PrependPromptStrategy(initial_prompt, vadOptions.vadInitialPromptMode)
280
  elif (vadOptions.vadInitialPromptMode == VadInitialPromptMode.JSON_PROMPT_MODE):
281
  # Use a JSON format to specify the prompt for each segment
282
  prompt_strategy = JsonPromptStrategy(initial_prompt)
283
  else:
284
- raise ValueError("Invalid vadInitialPromptMode: " + vadOptions.vadInitialPromptMode)
285
 
286
  # Callable for processing an audio file
287
  whisperCallable = model.create_callback(language, task, prompt_strategy=prompt_strategy, **decodeOptions)
 
273
  if ('task' in decodeOptions):
274
  task = decodeOptions.pop('task')
275
 
276
+ initial_prompt_mode = vadOptions.vadInitialPromptMode
277
+
278
+ # Set default initial prompt mode
279
+ if (initial_prompt_mode is None):
280
+ initial_prompt_mode = VadInitialPromptMode.PREPREND_FIRST_SEGMENT
281
+
282
+ if (initial_prompt_mode == VadInitialPromptMode.PREPEND_ALL_SEGMENTS or
283
+ initial_prompt_mode == VadInitialPromptMode.PREPREND_FIRST_SEGMENT):
284
  # Prepend initial prompt
285
+ prompt_strategy = PrependPromptStrategy(initial_prompt, initial_prompt_mode)
286
  elif (vadOptions.vadInitialPromptMode == VadInitialPromptMode.JSON_PROMPT_MODE):
287
  # Use a JSON format to specify the prompt for each segment
288
  prompt_strategy = JsonPromptStrategy(initial_prompt)
289
  else:
290
+ raise ValueError("Invalid vadInitialPromptMode: " + initial_prompt_mode)
291
 
292
  # Callable for processing an audio file
293
  whisperCallable = model.create_callback(language, task, prompt_strategy=prompt_strategy, **decodeOptions)