jhj0517 commited on
Commit
9fdfe53
·
1 Parent(s): dbca1ee

Rename test script file

Browse files
tests/{transcription.py → test_transcription.py} RENAMED
@@ -1,7 +1,9 @@
1
  from modules.whisper.whisper_factory import WhisperFactory
2
  from modules.whisper.whisper_parameter import WhisperValues
 
3
  from test_config import *
4
 
 
5
  import pytest
6
  import gradio as gr
7
  import os
@@ -9,39 +11,63 @@ import os
9
 
10
  @pytest.mark.parametrize("whisper_type", ["whisper", "faster-whisper", "insanely_fast_whisper"])
11
  def test_transcribe(whisper_type: str):
12
- audio_path = os.path.join("test.wav")
 
13
  if not os.path.exists(audio_path):
14
- download_file(TEST_FILE_DOWNLOAD_URL, audio_path)
15
 
16
  whisper_inferencer = WhisperFactory.create_whisper_inference(
17
  whisper_type=whisper_type,
18
  )
19
-
20
  print("Device : ", whisper_inferencer.device)
21
 
22
  hparams = WhisperValues(
23
  model_size=TEST_WHISPER_MODEL,
24
  ).as_list()
25
 
26
- whisper_inferencer.transcribe_file(
27
- files=[audio_path],
28
- progress=gr.Progress(),
 
 
 
29
  *hparams,
30
  )
31
 
 
 
 
32
  whisper_inferencer.transcribe_youtube(
33
- youtube_link=TEST_YOUTUBE_URL,
34
- progress=gr.Progress(),
 
 
35
  *hparams,
36
  )
 
 
37
 
38
  whisper_inferencer.transcribe_mic(
39
- mic_audio=audio_path,
40
- progress=gr.Progress(),
 
 
41
  *hparams,
42
  )
 
 
 
 
 
 
 
 
 
 
 
 
43
 
 
 
44
 
45
- def download_file(url: str, path: str):
46
- if not os.path.exists(path):
47
- os.system(f"wget {url} -O {path}")
 
1
  from modules.whisper.whisper_factory import WhisperFactory
2
  from modules.whisper.whisper_parameter import WhisperValues
3
+ from modules.utils.paths import WEBUI_DIR
4
  from test_config import *
5
 
6
+ import requests
7
  import pytest
8
  import gradio as gr
9
  import os
 
11
 
12
  @pytest.mark.parametrize("whisper_type", ["whisper", "faster-whisper", "insanely_fast_whisper"])
13
  def test_transcribe(whisper_type: str):
14
+ audio_path_dir = os.path.join(WEBUI_DIR, "tests")
15
+ audio_path = os.path.join(audio_path_dir, "jfk.wav")
16
  if not os.path.exists(audio_path):
17
+ download_file(TEST_FILE_DOWNLOAD_URL, audio_path_dir)
18
 
19
  whisper_inferencer = WhisperFactory.create_whisper_inference(
20
  whisper_type=whisper_type,
21
  )
 
22
  print("Device : ", whisper_inferencer.device)
23
 
24
  hparams = WhisperValues(
25
  model_size=TEST_WHISPER_MODEL,
26
  ).as_list()
27
 
28
+ subtitle_str, file_path = whisper_inferencer.transcribe_file(
29
+ [audio_path],
30
+ None,
31
+ "SRT",
32
+ False,
33
+ gr.Progress(),
34
  *hparams,
35
  )
36
 
37
+ assert isinstance(subtitle_str, str) and subtitle_str
38
+ assert isinstance(file_path[0], str) and file_path
39
+
40
  whisper_inferencer.transcribe_youtube(
41
+ TEST_YOUTUBE_URL,
42
+ "SRT",
43
+ False,
44
+ gr.Progress(),
45
  *hparams,
46
  )
47
+ assert isinstance(subtitle_str, str) and subtitle_str
48
+ assert isinstance(file_path[0], str) and file_path
49
 
50
  whisper_inferencer.transcribe_mic(
51
+ audio_path,
52
+ "SRT",
53
+ False,
54
+ gr.Progress(),
55
  *hparams,
56
  )
57
+ assert isinstance(subtitle_str, str) and subtitle_str
58
+ assert isinstance(file_path[0], str) and file_path
59
+
60
+
61
+ def download_file(url, save_dir):
62
+ if not os.path.exists(save_dir):
63
+ os.makedirs(save_dir)
64
+
65
+ file_name = url.split("/")[-1]
66
+ file_path = os.path.join(save_dir, file_name)
67
+
68
+ response = requests.get(url)
69
 
70
+ with open(file_path, "wb") as file:
71
+ file.write(response.content)
72
 
73
+ print(f"File downloaded to: {file_path}")