Eason Lu committed on
Commit · cd67dcd
1 Parent(s): f144427

debug pipeline
Former-commit-id: 651fae66447937370d60073672ac834b0b2481de
configs/task_config.yaml
CHANGED
@@ -8,3 +8,4 @@ output_type:
 source_lang: EN
 target_lang: ZH
 field: SC2
+chunk_size: 1000
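For orientation, a minimal sketch of how a pipeline like this one might consume the new key, assuming the file is loaded with PyYAML; the loader code and variable names below are illustrative, not part of this commit:

    import yaml

    # Load the task configuration; chunk_size is the key this commit adds.
    with open("configs/task_config.yaml") as f:
        task_cfg = yaml.safe_load(f)

    source_lang = task_cfg["source_lang"]  # "EN"
    target_lang = task_cfg["target_lang"]  # "ZH"
    chunk_size = task_cfg["chunk_size"]    # 1000; plausibly a script chunk length for translation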
src/task.py
CHANGED
@@ -11,7 +11,7 @@ import subprocess
 from src.srt_util.srt import SrtScript
 from src.srt_util.srt2ass import srt2ass
 from time import time, strftime, gmtime, sleep
-from
+from src.translators.translation import get_translation, translate
 
 import torch
 import stable_whisper
@@ -119,17 +119,15 @@ class Task:
         self.t_s = time()
         # self.SRT_Script = SrtScript
 
-
-        if not Path.exists(
+        src_srt_path = self.task_local_dir.joinpath(f"task_{self.task_id})_{self.source_lang}.srt")
+        if not Path.exists(src_srt_path):
             # extract script from audio
             logging.info("extract script from audio")
             device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-            logging.info("device: ", device)
-
-            audio_path = self.task_local_dir.joinpath(f"task_{self.task_id}.mp3")
+            # logging.info("device: ", device)
 
             if method == "api":
-                with open(audio_path, 'rb') as audio_file:
+                with open(self.audio_path, 'rb') as audio_file:
                     transcript = openai.Audio.transcribe(model="whisper-1", file=audio_file, response_format="srt")
             elif method == "stable":
                 model = stable_whisper.load_model(whisper_model, device)
@@ -147,11 +145,9 @@ class Task:
         # after get the transcript, release the gpu resource
         torch.cuda.empty_cache()
 
-        self.SRT_Script = SrtScript(transcript)
+        self.SRT_Script = SrtScript(transcript['segments'])
         # save the srt script to local
-        self.SRT_Script.write_srt_file_src(
-        time.sleep(5)
-        pass
+        self.SRT_Script.write_srt_file_src(src_srt_path)
 
     # Module 2: SRT preprocess: perform preprocess steps
     # TODO: multi-lang and multi-field support according to task_cfg
@@ -161,67 +157,67 @@ class Task:
         self.SRT_Script.form_whole_sentence()
         # self.SRT_Script.spell_check_term()
         self.SRT_Script.correct_with_force_term()
-
-        self.SRT_Script.write_srt_file_src(
+        processed_srt_path_src = str(Path(self.task_local_dir) / f'{self.task_id}_processed.srt')
+        self.SRT_Script.write_srt_file_src(processed_srt_path_src)
 
         if self.output_type["subtitle"] == "ass":
             logging.info("write English .srt file to .ass")
-
-            logging.info('ASS subtitle saved as: ' +
+            assSub_src = srt2ass(processed_srt_path_src)
+            logging.info('ASS subtitle saved as: ' + assSub_src)
         self.script_input = self.SRT_Script.get_source_only()
         pass
 
     def update_translation_progress(self, new_progress):
         if self.progress == TaskStatus.TRANSLATING:
             self.progress = TaskStatus.TRANSLATING.value[0], new_progress
-            time.sleep(5)
 
     # Module 3: perform srt translation
     def translation(self):
         logging.info("---------------------Start Translation--------------------")
-        get_translation(self.
-        time.sleep(5)
-        pass
+        get_translation(self.SRT_Script, self.model, self.task_id)
 
     # Module 4: perform srt post process steps
-    def postprocess(self
+    def postprocess(self):
         self.status = TaskStatus.POST_PROCESSING
 
         logging.info("---------------------Start Post-processing SRT class---------------------")
         self.SRT_Script.check_len_and_split()
         self.SRT_Script.remove_trans_punctuation()
+        logging.info("---------------------Post-processing SRT class finished---------------------")
 
-
+    # Module 5: output module
+    def output_render(self):
+        self.status = TaskStatus.OUTPUT_MODULE
+        video_out = self.output_type["video"]
+        subtitle_type = self.output_type["subtitle"]
+        is_bilingal = self.output_type["bilingal"]
+
+        results_dir = Path(self.task_local_dir)/ "results"
 
-
-
-        assSub_zh = srt2ass(f"{base_path}_zh.srt", "default", "No", "Modest")
-        logging.info('ASS subtitle saved as: ' + assSub_zh)
+        subtitle_path = f"{results_dir}/{self.task_id}_{self.target_lang}.srt"
+        self.SRT_Script.write_srt_file_translate(subtitle_path)
+        if is_bilingal:
+            subtitle_path = f"{results_dir}/{self.task_id}_{self.source_lang}_{self.target_lang}.srt"
+            self.SRT_Script.write_srt_file_bilingual(subtitle_path)
 
+        if subtitle_type == "ass":
+            logging.info("write .srt file to .ass")
+            subtitle_path = srt2ass(subtitle_path, "default", "No", "Modest")
+            logging.info('ASS subtitle saved as: ' + subtitle_path)
 
+        final_res = subtitle_path
 
         # encode to .mp4 video file
-        if
+        if video_out and self.video_path is not None:
             logging.info("encoding video file")
-
-
-
-        else:
-            subprocess.run(
-                f'ffmpeg -i {self.video_path} -vf "subtitles={base_path}_zh.ass" {base_path}.mp4')
+            subprocess.run(
+                f'ffmpeg -i {self.video_path} -vf "subtitles={subtitle_path}" {results_dir}/{self.task_id}.mp4')
+            final_res = f"{results_dir}/{self.task_id}.mp4"
 
         self.t_e = time()
         logging.info(
             "Pipeline finished, time duration:{}".format(strftime("%H:%M:%S", gmtime(self.t_e - self.t_s))))
-
-
-        time.sleep(5)
-        pass
-
-        # Module 5: output module
-        def output_render(self):
-            self.status = TaskStatus.OUTPUT_MODULE
-            return "TODO"
+        return final_res
 
     def run_pipeline(self):
         self.get_srt_class()
@@ -229,6 +225,7 @@ class Task:
         self.translation()
         self.postprocess()
         self.result = self.output_render()
+        print(self.result)
 
 class YoutubeTask(Task):
     def __init__(self, task_id, task_local_dir, task_cfg, youtube_url):
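A note on the encoding step in output_render above: subprocess.run is handed a single f-string, which on POSIX systems only launches if shell=True is also passed. A minimal equivalent sketch using an argument list instead (the paths here are illustrative, not repo values):

    import subprocess

    video_path = "task_123.mp4"                # hypothetical input video
    subtitle_path = "results/task_123_ZH.ass"  # hypothetical subtitle file
    out_path = "results/task_123.mp4"

    # Burn the subtitles into the video, as output_render does; the list form
    # avoids shell parsing and quoting issues around spaces in paths.
    subprocess.run(
        ["ffmpeg", "-i", video_path, "-vf", f"subtitles={subtitle_path}", out_path],
        check=True,
    )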
src/{translation → translators}/LLM_task.py
RENAMED
File without changes

src/translators/__init__.py
ADDED
File without changes
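The rename from src/translation to src/translators, together with the empty __init__.py added above, makes the directory a regular Python package, which is what permits the relative import in the diff below. A sketch of the two import styles this enables (module names are the ones in this commit):

    # Inside src/translators/translation.py: reach the sibling module relatively,
    # which works regardless of the caller's working directory.
    from .LLM_task import LLM_task

    # From elsewhere in the repo (e.g. src/task.py): import via the package path.
    from src.translators.translation import get_translation, translate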
src/{translation → translators}/translation.py
RENAMED
@@ -3,11 +3,11 @@ import logging
 from time import sleep
 from tqdm import tqdm
 from src.srt_util.srt import split_script
-from LLM_task import LLM_task
+from .LLM_task import LLM_task
 
-def get_translation(srt,model,video_name
-    script_arr, range_arr = split_script(srt)
-    translate(srt, script_arr, range_arr, model, video_name
+def get_translation(srt, model, video_name):
+    script_arr, range_arr = split_script(srt.get_source_only())
+    translate(srt, script_arr, range_arr, model, video_name)
     pass
 
 def check_translation(sentence, translation):
@@ -25,7 +25,7 @@ def check_translation(sentence, translation):
     return True
 
 
-def translate(srt, script_arr, range_arr, model_name, video_name,
+def translate(srt, script_arr, range_arr, model_name, video_name, attempts_count=5, task=None, temp = 0.15):
     """
     Translates the given script array into another language using the chatgpt and writes to the SRT file.
 
@@ -38,7 +38,6 @@ def translate(srt, script_arr, range_arr, model_name, video_name, video_link, at
     :param range_arr: A list of tuples representing the start and end positions of sentences in the script.
     :param model_name: The name of the translation model to be used.
     :param video_name: The name of the video.
-    :param video_link: The link to the video.
     :param attempts_count: Number of attemps of failures for unmatched sentences.
     :param task: Prompt.
     :param temp: Model temperature.
@@ -60,10 +59,10 @@ def translate(srt, script_arr, range_arr, model_name, video_name, video_link, at
         while flag:
             flag = False
             try:
-                translate = LLM_task(model_name, sentence)
+                translate = LLM_task(model_name, sentence, task, temp)
                 # detect merge sentence issue and try to solve for five times:
                 while not check_translation(sentence, translate) and attempts_count > 0:
-                    translate = LLM_task(model_name,sentence,task,temp)
+                    translate = LLM_task(model_name, sentence, task, temp)
                     attempts_count -= 1
 
                 # if failure still happen, split into smaller tokens
@@ -85,4 +84,4 @@ def translate(srt, script_arr, range_arr, model_name, video_name, video_link, at
                 sleep(30)
                 flag = True
 
-    srt.set_translation(translate, range_, model_name, video_name
+    srt.set_translation(translate, range_, model_name, video_name)
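The retry pattern in translate — call the model, validate the result against the source sentence, retry a bounded number of times — reduces to the self-contained sketch below. LLM_task and check_translation are stubbed out here, so only the control flow mirrors the diff; the stub bodies are hypothetical:

    def llm_stub(model_name, sentence, task=None, temp=0.15):
        # Hypothetical stand-in for LLM_task: echo one "translated" line per input line.
        return "\n".join(f"[{model_name}] {line}" for line in sentence.split("\n"))

    def lines_match(sentence, translation):
        # Core idea of check_translation: sentence counts must line up.
        return len(sentence.split("\n")) == len(translation.split("\n"))

    def translate_with_retry(model_name, sentence, task=None, temp=0.15, attempts_count=5):
        translation = llm_stub(model_name, sentence, task, temp)
        # Retry while the model merged or dropped sentences, up to attempts_count times.
        while not lines_match(sentence, translation) and attempts_count > 0:
            translation = llm_stub(model_name, sentence, task, temp)
            attempts_count -= 1
        return translation

    print(translate_with_retry("gpt-3.5-turbo", "Hello.\nGG, well played."))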