Yuhan-Lu commited on
Commit
e90d25c
1 Parent(s): d231d79

fix logging path; enable CUDA for whispher

Browse files

Former-commit-id: 0ea328b0a70e95f6061b93083ed418eafa4857c8

Files changed (1) hide show
  1. pipeline.py +8 -2
pipeline.py CHANGED
@@ -9,6 +9,7 @@ import whisper
9
  from srt2ass import srt2ass
10
  import logging
11
  from datetime import datetime
 
12
 
13
  import subprocess
14
 
@@ -109,7 +110,10 @@ def get_srt_class(srt_file_en, result_path, video_name, audio_path, audio_file =
109
 
110
  # use stable-whisper
111
  elif method == "stable":
112
- model = stable_whisper.load_model(whisper_model)
 
 
 
113
  transcript = model.transcribe(audio_path, regroup = False, initial_prompt="Hello, welcome to my lecture. Are you good my friend?")
114
  (
115
  transcript
@@ -265,7 +269,9 @@ def main():
265
 
266
  audio_path, audio_file, video_path, VIDEO_NAME = get_sources(args, DOWNLOAD_PATH, RESULT_PATH, VIDEO_NAME)
267
 
268
- logging.basicConfig(level=logging.INFO, handlers=[logging.FileHandler("{}/{}_{}.log".format(args.log_dir, VIDEO_NAME, datetime.now().strftime("%m%d%Y_%H%M%S")))], encoding='utf-8')
 
 
269
  logging.info("---------------------Video Info---------------------")
270
  logging.info("Video name: {}, translation model: {}, video link: {}".format(VIDEO_NAME, args.model_name, args.link))
271
 
 
9
  from srt2ass import srt2ass
10
  import logging
11
  from datetime import datetime
12
+ import torch
13
 
14
  import subprocess
15
 
 
110
 
111
  # use stable-whisper
112
  elif method == "stable":
113
+
114
+ # use cuda if available
115
+ devices = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
116
+ model = stable_whisper.load_model(whisper_model, device = devices)
117
  transcript = model.transcribe(audio_path, regroup = False, initial_prompt="Hello, welcome to my lecture. Are you good my friend?")
118
  (
119
  transcript
 
269
 
270
  audio_path, audio_file, video_path, VIDEO_NAME = get_sources(args, DOWNLOAD_PATH, RESULT_PATH, VIDEO_NAME)
271
 
272
+ if not os.path.exists(args.log_dir):
273
+ os.makedirs(args.log_dir)
274
+ logging.basicConfig(level=logging.INFO, handlers=[logging.FileHandler("{}/{}_{}.log".format(args.log_dir, VIDEO_NAME, datetime.now().strftime("%m%d%Y_%H%M%S")), 'w', encoding='utf-8')])
275
  logging.info("---------------------Video Info---------------------")
276
  logging.info("Video name: {}, translation model: {}, video link: {}".format(VIDEO_NAME, args.model_name, args.link))
277