Eason Lu committed
Commit cd67dcd • 1 Parent(s): f144427

debug pipeline

Former-commit-id: 651fae66447937370d60073672ac834b0b2481de

configs/task_config.yaml CHANGED
@@ -8,3 +8,4 @@ output_type:
 source_lang: EN
 target_lang: ZH
 field: SC2
+chunk_size: 1000
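For reference, a minimal sketch of how the updated config could be consumed, assuming PyYAML and assuming these four keys sit at the top level of task_config.yaml (consistent with task.py reading self.source_lang and self.output_type separately). The loader code is illustrative, not part of this commit:

```python
# Illustrative only: load task_config.yaml and pick up the new chunk_size key.
import yaml

with open("configs/task_config.yaml") as f:
    task_cfg = yaml.safe_load(f)

source_lang = task_cfg["source_lang"]          # "EN"
target_lang = task_cfg["target_lang"]          # "ZH"
field = task_cfg["field"]                      # "SC2"
chunk_size = task_cfg.get("chunk_size", 1000)  # key added by this commit
print(f"{source_lang}->{target_lang}, field={field}, chunk_size={chunk_size}")
```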
src/task.py CHANGED
@@ -11,7 +11,7 @@ import subprocess
 from src.srt_util.srt import SrtScript
 from src.srt_util.srt2ass import srt2ass
 from time import time, strftime, gmtime, sleep
-from translation.translation import get_translation, translate
+from src.translators.translation import get_translation, translate
 
 import torch
 import stable_whisper
@@ -119,17 +119,15 @@ class Task:
         self.t_s = time()
         # self.SRT_Script = SrtScript
 
-        en_srt_path = self.task_local_dir.joinpath(f"task_{self.task_id})_en.srt")
-        if not Path.exists(en_srt_path):
+        src_srt_path = self.task_local_dir.joinpath(f"task_{self.task_id})_{self.source_lang}.srt")
+        if not Path.exists(src_srt_path):
             # extract script from audio
             logging.info("extract script from audio")
             device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-            logging.info("device: ", device)
-
-            audio_path = self.task_local_dir.joinpath(f"task_{self.task_id}.mp3")
+            # logging.info("device: ", device)
 
             if method == "api":
-                with open(audio_path, 'rb') as audio_file:
+                with open(self.audio_path, 'rb') as audio_file:
                     transcript = openai.Audio.transcribe(model="whisper-1", file=audio_file, response_format="srt")
             elif method == "stable":
                 model = stable_whisper.load_model(whisper_model, device)
@@ -147,11 +145,9 @@ class Task:
             # after get the transcript, release the gpu resource
             torch.cuda.empty_cache()
 
-            self.SRT_Script = SrtScript(transcript)
+            self.SRT_Script = SrtScript(transcript['segments'])
             # save the srt script to local
-            self.SRT_Script.write_srt_file_src(en_srt_path)
-            time.sleep(5)
-            pass
+            self.SRT_Script.write_srt_file_src(src_srt_path)
 
     # Module 2: SRT preprocess: perform preprocess steps
     # TODO: multi-lang and multi-field support according to task_cfg
@@ -161,67 +157,67 @@ class Task:
         self.SRT_Script.form_whole_sentence()
         # self.SRT_Script.spell_check_term()
         self.SRT_Script.correct_with_force_term()
-        processed_srt_path_en = str(Path(self.task_local_dir).with_suffix('')) + '_processed.srt'
-        self.SRT_Script.write_srt_file_src(processed_srt_path_en)
+        processed_srt_path_src = str(Path(self.task_local_dir) / f'{self.task_id}_processed.srt')
+        self.SRT_Script.write_srt_file_src(processed_srt_path_src)
 
         if self.output_type["subtitle"] == "ass":
             logging.info("write English .srt file to .ass")
-            assSub_en = srt2ass(processed_srt_path_en)
-            logging.info('ASS subtitle saved as: ' + assSub_en)
+            assSub_src = srt2ass(processed_srt_path_src)
+            logging.info('ASS subtitle saved as: ' + assSub_src)
         self.script_input = self.SRT_Script.get_source_only()
         pass
 
     def update_translation_progress(self, new_progress):
         if self.progress == TaskStatus.TRANSLATING:
             self.progress = TaskStatus.TRANSLATING.value[0], new_progress
-        time.sleep(5)
 
     # Module 3: perform srt translation
    def translation(self):
         logging.info("---------------------Start Translation--------------------")
-        get_translation(self.srt,self.model, self.video_name, self.video_link)
-        time.sleep(5)
-        pass
+        get_translation(self.SRT_Script, self.model, self.task_id)
 
     # Module 4: perform srt post process steps
-    def postprocess(self, encode=False, srt_only=False):
+    def postprocess(self):
         self.status = TaskStatus.POST_PROCESSING
 
         logging.info("---------------------Start Post-processing SRT class---------------------")
         self.SRT_Script.check_len_and_split()
         self.SRT_Script.remove_trans_punctuation()
+        logging.info("---------------------Post-processing SRT class finished---------------------")
 
-        base_path = Path(self.dir_result).joinpath(self.video_name).joinpath(self.video_name)
+    # Module 5: output module
+    def output_render(self):
+        self.status = TaskStatus.OUTPUT_MODULE
+        video_out = self.output_type["video"]
+        subtitle_type = self.output_type["subtitle"]
+        is_bilingal = self.output_type["bilingal"]
+
+        results_dir = Path(self.task_local_dir)/ "results"
 
-        self.SRT_Script.write_srt_file_translate(f"{base_path}_zh.srt")
-        self.SRT_Script.write_srt_file_bilingual(f"{base_path}_bi.srt")
+        subtitle_path = f"{results_dir}/{self.task_id}_{self.target_lang}.srt"
+        self.SRT_Script.write_srt_file_translate(subtitle_path)
+        if is_bilingal:
+            subtitle_path = f"{results_dir}/{self.task_id}_{self.source_lang}_{self.target_lang}.srt"
+            self.SRT_Script.write_srt_file_bilingual(subtitle_path)
 
-        logging.info("write Chinese .srt file to .ass")
-        assSub_zh = srt2ass(f"{base_path}_zh.srt", "default", "No", "Modest")
-        logging.info('ASS subtitle saved as: ' + assSub_zh)
+        if subtitle_type == "ass":
+            logging.info("write .srt file to .ass")
+            subtitle_path = srt2ass(subtitle_path, "default", "No", "Modest")
+            logging.info('ASS subtitle saved as: ' + subtitle_path)
+
+        final_res = subtitle_path
 
         # encode to .mp4 video file
-        if encode:
+        if video_out and self.video_path is not None:
             logging.info("encoding video file")
-            if srt_only:
-                subprocess.run(
-                    f'ffmpeg -i {self.video_path} -vf "subtitles={base_path}_zh.srt" {base_path}.mp4')
-            else:
-                subprocess.run(
-                    f'ffmpeg -i {self.video_path} -vf "subtitles={base_path}_zh.ass" {base_path}.mp4')
+            subprocess.run(
+                f'ffmpeg -i {self.video_path} -vf "subtitles={subtitle_path}" {results_dir}/{self.task_id}.mp4')
+            final_res = f"{results_dir}/{self.task_id}.mp4"
 
         self.t_e = time()
         logging.info(
             "Pipeline finished, time duration:{}".format(strftime("%H:%M:%S", gmtime(self.t_e - self.t_s))))
-
-
-        time.sleep(5)
-        pass
-
-    # Module 5: output module
-    def output_render(self):
-        self.status = TaskStatus.OUTPUT_MODULE
-        return "TODO"
+        return final_res
 
     def run_pipeline(self):
         self.get_srt_class()
@@ -229,6 +225,7 @@ class Task:
         self.translation()
         self.postprocess()
         self.result = self.output_render()
+        print(self.result)
 
 class YoutubeTask(Task):
     def __init__(self, task_id, task_local_dir, task_cfg, youtube_url):
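To make the new output_render() branching easier to follow outside the diff, here is a self-contained sketch of the same decision flow. The function name render_outputs and its parameter list are hypothetical stand-ins; only the branching mirrors the code above (the real code calls srt2ass, which returns the .ass path; the sketch approximates that with with_suffix):

```python
# Hypothetical standalone version of the Task.output_render() decision flow
# in the diff above; names other than the output_type keys are invented.
from pathlib import Path

def render_outputs(task_id, results_dir, output_type, source_lang="EN",
                   target_lang="ZH", video_path=None):
    results_dir = Path(results_dir)
    # Pick the subtitle file: translated-only, or bilingual if requested.
    subtitle_path = results_dir / f"{task_id}_{target_lang}.srt"
    if output_type["bilingal"]:  # spelling follows the config key in the diff
        subtitle_path = results_dir / f"{task_id}_{source_lang}_{target_lang}.srt"
    # Optionally convert .srt -> .ass (the real code calls srt2ass here).
    if output_type["subtitle"] == "ass":
        subtitle_path = subtitle_path.with_suffix(".ass")
    final_res = subtitle_path
    # Optionally burn subtitles into a video (real code: subprocess.run + ffmpeg).
    if output_type["video"] and video_path is not None:
        final_res = results_dir / f"{task_id}.mp4"
    return final_res

# Example: bilingual .ass output, no video encoding.
print(render_outputs("42", "/tmp/task_42/results",
                     {"video": False, "subtitle": "ass", "bilingal": True}))
```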
src/{translation → translators}/LLM_task.py RENAMED
File without changes
src/translators/__init__.py ADDED
File without changes
src/{translation → translators}/translation.py RENAMED
@@ -3,11 +3,11 @@ import logging
 from time import sleep
 from tqdm import tqdm
 from src.srt_util.srt import split_script
-from LLM_task import LLM_task
+from .LLM_task import LLM_task
 
-def get_translation(srt,model,video_name,video_link):
-    script_arr, range_arr = split_script(srt)
-    translate(srt, script_arr, range_arr, model, video_name, video_link)
+def get_translation(srt, model, video_name):
+    script_arr, range_arr = split_script(srt.get_source_only())
+    translate(srt, script_arr, range_arr, model, video_name)
     pass
 
 def check_translation(sentence, translation):
@@ -25,7 +25,7 @@ def check_translation(sentence, translation):
     return True
 
 
-def translate(srt, script_arr, range_arr, model_name, video_name, video_link, attempts_count=5, task=None, temp = 0.15):
+def translate(srt, script_arr, range_arr, model_name, video_name, attempts_count=5, task=None, temp = 0.15):
     """
     Translates the given script array into another language using the chatgpt and writes to the SRT file.
 
@@ -38,7 +38,6 @@ def translate(srt, script_arr, range_arr, model_name, video_name, video_link, at
     :param range_arr: A list of tuples representing the start and end positions of sentences in the script.
     :param model_name: The name of the translation model to be used.
     :param video_name: The name of the video.
-    :param video_link: The link to the video.
     :param attempts_count: Number of attempts of failures for unmatched sentences.
     :param task: Prompt.
     :param temp: Model temperature.
@@ -60,10 +59,10 @@ def translate(srt, script_arr, range_arr, model_name, video_name, video_link, at
         while flag:
             flag = False
             try:
-                translate = LLM_task(model_name, sentence)
+                translate = LLM_task(model_name, sentence, task, temp)
                 # detect merge sentence issue and try to solve for five times:
                 while not check_translation(sentence, translate) and attempts_count > 0:
-                    translate = LLM_task(model_name,sentence,task,temp)
+                    translate = LLM_task(model_name, sentence, task, temp)
                     attempts_count -= 1
 
                 # if failure still happen, split into smaller tokens
@@ -85,4 +84,4 @@ def translate(srt, script_arr, range_arr, model_name, video_name, video_link, at
                 sleep(30)
                 flag = True
 
-        srt.set_translation(translate, range_, model_name, video_name, video_link)
+        srt.set_translation(translate, range_, model_name, video_name)
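Two notes on this rename. First, the new src/translators/__init__.py is what lets the relative import `from .LLM_task import LLM_task` resolve, since it turns the directory into a package. Second, the retry loop inside translate() is worth seeing in isolation; below is a self-contained sketch of that pattern, where fake_llm_task is a purely illustrative stand-in for the real LLM_task (which calls the chat model):

```python
# Self-contained sketch of translate()'s retry pattern: re-ask the model while
# the line counts of sentence and translation disagree, and back off on errors.
# fake_llm_task stands in for src.translators.LLM_task.LLM_task.
from time import sleep

def fake_llm_task(model_name, sentence, task=None, temp=0.15):
    # Stand-in: "translate" by echoing one output line per input line.
    return "\n".join(f"[{model_name}] {line}" for line in sentence.splitlines())

def check_translation(sentence, translation):
    # Same spirit as the real check: sentence counts must match.
    return len(sentence.splitlines()) == len(translation.splitlines())

def translate_with_retry(sentence, model_name="gpt-4", attempts_count=5,
                         task=None, temp=0.15):
    translation = None
    flag = True
    while flag:
        flag = False
        try:
            translation = fake_llm_task(model_name, sentence, task, temp)
            # Detect merged-sentence issues; re-ask up to attempts_count times.
            while not check_translation(sentence, translation) and attempts_count > 0:
                translation = fake_llm_task(model_name, sentence, task, temp)
                attempts_count -= 1
        except Exception:
            sleep(30)  # the real code waits 30s after an API failure, then retries
            flag = True
    return translation

print(translate_with_retry("Hello there.\nGeneral Kenobi."))
```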