Spaces:
Sleeping
Sleeping
Eason Lu
commited on
Commit
·
b37d0d4
1
Parent(s):
6808a65
debugs
Browse filesFormer-commit-id: 2d7d950f54b1deb8b4dd9b68c98d65384954c47e
- configs/task_config.yaml +26 -9
- entries/run.py +10 -2
- src/srt_util/srt.py +1 -7
- src/task.py +37 -35
- src/translators/translation.py +5 -5
configs/task_config.yaml
CHANGED
@@ -1,18 +1,35 @@
|
|
1 |
# configuration for each task
|
2 |
-
model: gpt-4
|
3 |
-
# output type that user receive
|
4 |
-
output_type:
|
5 |
-
subtitle: srt
|
6 |
-
video: False
|
7 |
-
bilingal: False
|
8 |
source_lang: EN
|
9 |
target_lang: ZH
|
10 |
field: SC2
|
11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
pre_process:
|
13 |
-
ON: True
|
14 |
sentence_form: True
|
15 |
spell_check: False
|
16 |
term_correct: True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
post_process:
|
18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
# configuration for each task
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
source_lang: EN
|
3 |
target_lang: ZH
|
4 |
field: SC2
|
5 |
+
|
6 |
+
# ASR config
|
7 |
+
ASR:
|
8 |
+
ASR_model: whisper
|
9 |
+
whisper_config:
|
10 |
+
whisper_model: tiny
|
11 |
+
method: stable
|
12 |
+
|
13 |
+
# pre-process module config
|
14 |
pre_process:
|
|
|
15 |
sentence_form: True
|
16 |
spell_check: False
|
17 |
term_correct: True
|
18 |
+
|
19 |
+
# Translation module config
|
20 |
+
translation:
|
21 |
+
model: gpt-4
|
22 |
+
chunk_size: 1000
|
23 |
+
|
24 |
+
# post-process module config
|
25 |
post_process:
|
26 |
+
check_len_and_split: True
|
27 |
+
remove_trans_punctuation: True
|
28 |
+
|
29 |
+
# output type that user receive
|
30 |
+
output_type:
|
31 |
+
subtitle: srt
|
32 |
+
video: False
|
33 |
+
bilingal: False
|
34 |
+
|
35 |
+
|
entries/run.py
CHANGED
@@ -10,6 +10,13 @@ from datetime import datetime
|
|
10 |
import shutil
|
11 |
from uuid import uuid4
|
12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
def parse_args():
|
14 |
parser = argparse.ArgumentParser()
|
15 |
parser.add_argument("--link", help="youtube video link here", default=None, type=str, required=False)
|
@@ -42,8 +49,9 @@ if __name__ == "__main__":
|
|
42 |
task_dir.mkdir(parents=False, exist_ok=False)
|
43 |
task_dir.joinpath("results").mkdir(parents=False, exist_ok=False)
|
44 |
|
45 |
-
# logging
|
46 |
-
|
|
|
47 |
logging.FileHandler(
|
48 |
"{}/{}_{}.log".format(task_dir, f"task_{task_id}", datetime.now().strftime("%m%d%Y_%H%M%S")),
|
49 |
'w', encoding='utf-8')])
|
|
|
10 |
import shutil
|
11 |
from uuid import uuid4
|
12 |
|
13 |
+
"""
|
14 |
+
Main entry for terminal environment.
|
15 |
+
Use it for debug and development purpose.
|
16 |
+
Usage: python3 entries/run.py [-h] [--link LINK] [--video_file VIDEO_FILE] [--audio_file AUDIO_FILE] [--srt_file SRT_FILE] [--continue CONTINUE]
|
17 |
+
[--launch_cfg LAUNCH_CFG] [--task_cfg TASK_CFG]
|
18 |
+
"""
|
19 |
+
|
20 |
def parse_args():
|
21 |
parser = argparse.ArgumentParser()
|
22 |
parser.add_argument("--link", help="youtube video link here", default=None, type=str, required=False)
|
|
|
49 |
task_dir.mkdir(parents=False, exist_ok=False)
|
50 |
task_dir.joinpath("results").mkdir(parents=False, exist_ok=False)
|
51 |
|
52 |
+
# logging setting
|
53 |
+
logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
|
54 |
+
logging.basicConfig(level=logging.INFO, format=logfmt, handlers=[
|
55 |
logging.FileHandler(
|
56 |
"{}/{}_{}.log".format(task_dir, f"task_{task_id}", datetime.now().strftime("%m%d%Y_%H%M%S")),
|
57 |
'w', encoding='utf-8')])
|
src/srt_util/srt.py
CHANGED
@@ -185,7 +185,6 @@ class SrtScript(object):
|
|
185 |
|
186 |
def inner_func(target, input_str):
|
187 |
response = openai.ChatCompletion.create(
|
188 |
-
# model=model,
|
189 |
model="gpt-4",
|
190 |
messages=[
|
191 |
{"role": "system",
|
@@ -208,19 +207,13 @@ class SrtScript(object):
|
|
208 |
flag = True
|
209 |
while flag:
|
210 |
flag = False
|
211 |
-
# print("translate:")
|
212 |
-
# print(translate)
|
213 |
try:
|
214 |
-
# print("target")
|
215 |
-
# print(end_seg_id - start_seg_id + 1)
|
216 |
translate = inner_func(end_seg_id - start_seg_id + 1, translate)
|
217 |
except Exception as e:
|
218 |
print("An error has occurred during solving unmatched lines:", e)
|
219 |
print("Retrying...")
|
220 |
flag = True
|
221 |
lines = translate.split('\n')
|
222 |
-
# print("result")
|
223 |
-
# print(len(lines))
|
224 |
|
225 |
if len(lines) < (end_seg_id - start_seg_id + 1):
|
226 |
solved = False
|
@@ -264,6 +257,7 @@ class SrtScript(object):
|
|
264 |
# evenly split seg to 2 parts and add new seg into self.segments
|
265 |
|
266 |
# ignore the initial comma to solve the recursion problem
|
|
|
267 |
if len(seg.source_text) > 2:
|
268 |
if seg.source_text[:2] == ', ':
|
269 |
seg.source_text = seg.source_text[2:]
|
|
|
185 |
|
186 |
def inner_func(target, input_str):
|
187 |
response = openai.ChatCompletion.create(
|
|
|
188 |
model="gpt-4",
|
189 |
messages=[
|
190 |
{"role": "system",
|
|
|
207 |
flag = True
|
208 |
while flag:
|
209 |
flag = False
|
|
|
|
|
210 |
try:
|
|
|
|
|
211 |
translate = inner_func(end_seg_id - start_seg_id + 1, translate)
|
212 |
except Exception as e:
|
213 |
print("An error has occurred during solving unmatched lines:", e)
|
214 |
print("Retrying...")
|
215 |
flag = True
|
216 |
lines = translate.split('\n')
|
|
|
|
|
217 |
|
218 |
if len(lines) < (end_seg_id - start_seg_id + 1):
|
219 |
solved = False
|
|
|
257 |
# evenly split seg to 2 parts and add new seg into self.segments
|
258 |
|
259 |
# ignore the initial comma to solve the recursion problem
|
260 |
+
# FIXME: accomodate multilingual setting
|
261 |
if len(seg.source_text) > 2:
|
262 |
if seg.source_text[:2] == ', ':
|
263 |
seg.source_text = seg.source_text[2:]
|
src/task.py
CHANGED
@@ -55,7 +55,6 @@ class TaskStatus(str, Enum):
|
|
55 |
OUTPUT_MODULE = 'OUTPUT_MODULE'
|
56 |
|
57 |
|
58 |
-
|
59 |
class Task:
|
60 |
@property
|
61 |
def status(self):
|
@@ -70,69 +69,74 @@ class Task:
|
|
70 |
def __init__(self, task_id, task_local_dir, task_cfg):
|
71 |
self.__status_lock = threading.Lock()
|
72 |
self.__status = TaskStatus.CREATED
|
|
|
73 |
openai.api_key = getenv("OPENAI_API_KEY")
|
74 |
-
self.
|
|
|
75 |
self.task_local_dir = task_local_dir
|
76 |
-
self.
|
77 |
-
self.
|
|
|
|
|
78 |
self.output_type = task_cfg["output_type"]
|
79 |
self.target_lang = task_cfg["target_lang"]
|
80 |
self.source_lang = task_cfg["source_lang"]
|
81 |
self.field = task_cfg["field"]
|
82 |
self.pre_setting = task_cfg["pre_process"]
|
83 |
self.post_setting = task_cfg["post_process"]
|
84 |
-
|
85 |
self.audio_path = None
|
86 |
self.SRT_Script = None
|
87 |
self.result = None
|
88 |
self.s_t = None
|
89 |
self.t_e = None
|
90 |
|
91 |
-
print(f"
|
92 |
-
logging.info(f"
|
93 |
-
logging.info(f"
|
94 |
-
logging.info(f" Model:
|
95 |
-
logging.info(f"
|
96 |
-
logging.info(f"
|
97 |
-
logging.info(f"
|
98 |
-
logging.info("
|
99 |
-
for key
|
100 |
-
logging.info(f"
|
101 |
-
logging.info("
|
102 |
-
for key
|
103 |
-
logging.info(f"
|
104 |
|
105 |
@staticmethod
|
106 |
def fromYoutubeLink(youtube_url, task_id, task_dir, task_cfg):
|
107 |
# convert to audio
|
108 |
-
logging.info("
|
109 |
return YoutubeTask(task_id, task_dir, task_cfg, youtube_url)
|
110 |
|
111 |
@staticmethod
|
112 |
def fromAudioFile(audio_path, task_id, task_dir, task_cfg):
|
113 |
# get audio path
|
114 |
-
logging.info("
|
115 |
return AudioTask(task_id, task_dir, task_cfg, audio_path)
|
116 |
|
117 |
@staticmethod
|
118 |
def fromVideoFile(video_path, task_id, task_dir, task_cfg):
|
119 |
# get audio path
|
120 |
-
logging.info("
|
121 |
return VideoTask(task_id, task_dir, task_cfg, video_path)
|
122 |
|
123 |
# Module 1 ASR: audio --> SRT_script
|
124 |
-
def get_srt_class(self
|
125 |
# Instead of using the script_en variable directly, we'll use script_input
|
|
|
126 |
self.status = TaskStatus.INITIALIZING_ASR
|
127 |
self.t_s = time()
|
128 |
# self.SRT_Script = SrtScript
|
129 |
-
|
|
|
130 |
src_srt_path = self.task_local_dir.joinpath(f"task_{self.task_id})_{self.source_lang}.srt")
|
131 |
if not Path.exists(src_srt_path):
|
132 |
# extract script from audio
|
133 |
logging.info("extract script from audio")
|
134 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
135 |
-
# logging.info("device: ", device)
|
136 |
|
137 |
if method == "api":
|
138 |
with open(self.audio_path, 'rb') as audio_file:
|
@@ -158,7 +162,6 @@ class Task:
|
|
158 |
self.SRT_Script.write_srt_file_src(src_srt_path)
|
159 |
|
160 |
# Module 2: SRT preprocess: perform preprocess steps
|
161 |
-
# TODO: multi-lang and multi-field support according to task_cfg
|
162 |
def preprocess(self):
|
163 |
self.status = TaskStatus.PRE_PROCESSING
|
164 |
logging.info("--------------------Start Preprocessing SRT class--------------------")
|
@@ -183,18 +186,20 @@ class Task:
|
|
183 |
self.progress = TaskStatus.TRANSLATING.value[0], new_progress
|
184 |
|
185 |
# Module 3: perform srt translation
|
186 |
-
def translation(self
|
187 |
logging.info("---------------------Start Translation--------------------")
|
188 |
-
prompt = prompt_selector(self.source_lang,self.target_lang,
|
189 |
-
get_translation(self.SRT_Script, self.
|
190 |
|
191 |
# Module 4: perform srt post process steps
|
192 |
def postprocess(self):
|
193 |
self.status = TaskStatus.POST_PROCESSING
|
194 |
|
195 |
logging.info("---------------------Start Post-processing SRT class---------------------")
|
196 |
-
self.
|
197 |
-
|
|
|
|
|
198 |
logging.info("---------------------Post-processing SRT class finished---------------------")
|
199 |
|
200 |
# Module 5: output module
|
@@ -233,11 +238,9 @@ class Task:
|
|
233 |
|
234 |
def run_pipeline(self):
|
235 |
self.get_srt_class()
|
236 |
-
|
237 |
-
self.preprocess()
|
238 |
self.translation()
|
239 |
-
|
240 |
-
self.postprocess()
|
241 |
self.result = self.output_render()
|
242 |
print(self.result)
|
243 |
|
@@ -259,7 +262,6 @@ class YoutubeTask(Task):
|
|
259 |
audio = yt.streams.filter(only_audio=True).first()
|
260 |
if audio:
|
261 |
audio.download(str(self.task_local_dir), filename=f"task_{self.task_id}.mp3")
|
262 |
-
# logging.info(f'Audio download completed to {self.task_local_dir}!')
|
263 |
else:
|
264 |
logging.info(" download audio failed, using ffmpeg to extract audio")
|
265 |
subprocess.run(
|
|
|
55 |
OUTPUT_MODULE = 'OUTPUT_MODULE'
|
56 |
|
57 |
|
|
|
58 |
class Task:
|
59 |
@property
|
60 |
def status(self):
|
|
|
69 |
def __init__(self, task_id, task_local_dir, task_cfg):
|
70 |
self.__status_lock = threading.Lock()
|
71 |
self.__status = TaskStatus.CREATED
|
72 |
+
self.gpu_status = 0
|
73 |
openai.api_key = getenv("OPENAI_API_KEY")
|
74 |
+
self.task_id = task_id
|
75 |
+
|
76 |
self.task_local_dir = task_local_dir
|
77 |
+
self.ASR_setting = task_cfg["ASR"]
|
78 |
+
self.translation_setting = task_cfg["translation"]
|
79 |
+
self.translation_model = self.translation_setting["model"]
|
80 |
+
|
81 |
self.output_type = task_cfg["output_type"]
|
82 |
self.target_lang = task_cfg["target_lang"]
|
83 |
self.source_lang = task_cfg["source_lang"]
|
84 |
self.field = task_cfg["field"]
|
85 |
self.pre_setting = task_cfg["pre_process"]
|
86 |
self.post_setting = task_cfg["post_process"]
|
87 |
+
|
88 |
self.audio_path = None
|
89 |
self.SRT_Script = None
|
90 |
self.result = None
|
91 |
self.s_t = None
|
92 |
self.t_e = None
|
93 |
|
94 |
+
print(f"Task ID: {self.task_id}")
|
95 |
+
logging.info(f"Task ID: {self.task_id}")
|
96 |
+
logging.info(f"{self.source_lang} -> {self.target_lang} task in {self.field}")
|
97 |
+
logging.info(f"Translation Model: {self.translation_model}")
|
98 |
+
logging.info(f"subtitle_type: {self.output_type['subtitle']}")
|
99 |
+
logging.info(f"video_ouput: {self.output_type['video']}")
|
100 |
+
logging.info(f"bilingal_ouput: {self.output_type['bilingal']}")
|
101 |
+
logging.info("Pre-process setting:")
|
102 |
+
for key in self.pre_setting:
|
103 |
+
logging.info(f"{key}: {self.pre_setting[key]}")
|
104 |
+
logging.info("Post-process setting:")
|
105 |
+
for key in self.post_setting:
|
106 |
+
logging.info(f"{key}: {self.post_setting[key]}")
|
107 |
|
108 |
@staticmethod
|
109 |
def fromYoutubeLink(youtube_url, task_id, task_dir, task_cfg):
|
110 |
# convert to audio
|
111 |
+
logging.info("Task Creation method: Youtube Link")
|
112 |
return YoutubeTask(task_id, task_dir, task_cfg, youtube_url)
|
113 |
|
114 |
@staticmethod
|
115 |
def fromAudioFile(audio_path, task_id, task_dir, task_cfg):
|
116 |
# get audio path
|
117 |
+
logging.info("Task Creation method: Audio File")
|
118 |
return AudioTask(task_id, task_dir, task_cfg, audio_path)
|
119 |
|
120 |
@staticmethod
|
121 |
def fromVideoFile(video_path, task_id, task_dir, task_cfg):
|
122 |
# get audio path
|
123 |
+
logging.info("Task Creation method: Video File")
|
124 |
return VideoTask(task_id, task_dir, task_cfg, video_path)
|
125 |
|
126 |
# Module 1 ASR: audio --> SRT_script
|
127 |
+
def get_srt_class(self):
|
128 |
# Instead of using the script_en variable directly, we'll use script_input
|
129 |
+
# TODO: setup ASR module like translator
|
130 |
self.status = TaskStatus.INITIALIZING_ASR
|
131 |
self.t_s = time()
|
132 |
# self.SRT_Script = SrtScript
|
133 |
+
method = self.ASR_setting["whisper_config"]["method"]
|
134 |
+
whisper_model = self.ASR_setting["whisper_config"]["whisper_model"]
|
135 |
src_srt_path = self.task_local_dir.joinpath(f"task_{self.task_id})_{self.source_lang}.srt")
|
136 |
if not Path.exists(src_srt_path):
|
137 |
# extract script from audio
|
138 |
logging.info("extract script from audio")
|
139 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
140 |
|
141 |
if method == "api":
|
142 |
with open(self.audio_path, 'rb') as audio_file:
|
|
|
162 |
self.SRT_Script.write_srt_file_src(src_srt_path)
|
163 |
|
164 |
# Module 2: SRT preprocess: perform preprocess steps
|
|
|
165 |
def preprocess(self):
|
166 |
self.status = TaskStatus.PRE_PROCESSING
|
167 |
logging.info("--------------------Start Preprocessing SRT class--------------------")
|
|
|
186 |
self.progress = TaskStatus.TRANSLATING.value[0], new_progress
|
187 |
|
188 |
# Module 3: perform srt translation
|
189 |
+
def translation(self):
|
190 |
logging.info("---------------------Start Translation--------------------")
|
191 |
+
prompt = prompt_selector(self.source_lang, self.target_lang, self.field)
|
192 |
+
get_translation(self.SRT_Script, self.translation_model, self.task_id, prompt, self.translation_setting['chunk_size'])
|
193 |
|
194 |
# Module 4: perform srt post process steps
|
195 |
def postprocess(self):
|
196 |
self.status = TaskStatus.POST_PROCESSING
|
197 |
|
198 |
logging.info("---------------------Start Post-processing SRT class---------------------")
|
199 |
+
if self.post_setting["check_len_and_split"]:
|
200 |
+
self.SRT_Script.check_len_and_split()
|
201 |
+
if self.post_setting["remove_trans_punctuation"]:
|
202 |
+
self.SRT_Script.remove_trans_punctuation()
|
203 |
logging.info("---------------------Post-processing SRT class finished---------------------")
|
204 |
|
205 |
# Module 5: output module
|
|
|
238 |
|
239 |
def run_pipeline(self):
|
240 |
self.get_srt_class()
|
241 |
+
self.preprocess()
|
|
|
242 |
self.translation()
|
243 |
+
self.postprocess()
|
|
|
244 |
self.result = self.output_render()
|
245 |
print(self.result)
|
246 |
|
|
|
262 |
audio = yt.streams.filter(only_audio=True).first()
|
263 |
if audio:
|
264 |
audio.download(str(self.task_local_dir), filename=f"task_{self.task_id}.mp3")
|
|
|
265 |
else:
|
266 |
logging.info(" download audio failed, using ffmpeg to extract audio")
|
267 |
subprocess.run(
|
src/translators/translation.py
CHANGED
@@ -5,9 +5,9 @@ from tqdm import tqdm
|
|
5 |
from src.srt_util.srt import split_script
|
6 |
from .LLM_task import LLM_task
|
7 |
|
8 |
-
def get_translation(srt, model, video_name,
|
9 |
script_arr, range_arr = split_script(srt.get_source_only(),chunk_size)
|
10 |
-
translate(srt, script_arr, range_arr, model, video_name, task)
|
11 |
pass
|
12 |
|
13 |
def check_translation(sentence, translation):
|
@@ -39,7 +39,7 @@ def prompt_selector(src_lang, tgt_lang, domain):
|
|
39 |
"""
|
40 |
return prompt
|
41 |
|
42 |
-
def translate(srt, script_arr, range_arr, model_name, video_name, attempts_count=5, task=None, temp = 0.15):
|
43 |
"""
|
44 |
Translates the given script array into another language using the chatgpt and writes to the SRT file.
|
45 |
|
@@ -61,14 +61,14 @@ def translate(srt, script_arr, range_arr, model_name, video_name, attempts_count
|
|
61 |
raise Exception("Warning! No Input have passed to LLM!")
|
62 |
if task is None:
|
63 |
task = "你是一个翻译助理,你的任务是翻译星际争霸视频,你会被提供一个按行分割的英文段落,你需要在保证句意和行数的情况下输出翻译后的文本。"
|
64 |
-
|
65 |
previous_length = 0
|
66 |
for sentence, range_ in tqdm(zip(script_arr, range_arr)):
|
67 |
# update the range based on previous length
|
68 |
range_ = (range_[0] + previous_length, range_[1] + previous_length)
|
69 |
# using chatgpt model
|
70 |
print(f"now translating sentences {range_}")
|
71 |
-
|
72 |
flag = True
|
73 |
while flag:
|
74 |
flag = False
|
|
|
5 |
from src.srt_util.srt import split_script
|
6 |
from .LLM_task import LLM_task
|
7 |
|
8 |
+
def get_translation(srt, model, video_name, prompt, chunk_size = 1000):
|
9 |
script_arr, range_arr = split_script(srt.get_source_only(),chunk_size)
|
10 |
+
translate(srt, script_arr, range_arr, model, video_name, task=prompt)
|
11 |
pass
|
12 |
|
13 |
def check_translation(sentence, translation):
|
|
|
39 |
"""
|
40 |
return prompt
|
41 |
|
42 |
+
def translate(srt, script_arr, range_arr, model_name, video_name=None, attempts_count=5, task=None, temp = 0.15):
|
43 |
"""
|
44 |
Translates the given script array into another language using the chatgpt and writes to the SRT file.
|
45 |
|
|
|
61 |
raise Exception("Warning! No Input have passed to LLM!")
|
62 |
if task is None:
|
63 |
task = "你是一个翻译助理,你的任务是翻译星际争霸视频,你会被提供一个按行分割的英文段落,你需要在保证句意和行数的情况下输出翻译后的文本。"
|
64 |
+
logging.info(f"translation prompt: {task}")
|
65 |
previous_length = 0
|
66 |
for sentence, range_ in tqdm(zip(script_arr, range_arr)):
|
67 |
# update the range based on previous length
|
68 |
range_ = (range_[0] + previous_length, range_[1] + previous_length)
|
69 |
# using chatgpt model
|
70 |
print(f"now translating sentences {range_}")
|
71 |
+
logging.info(f"now translating sentences {range_}")
|
72 |
flag = True
|
73 |
while flag:
|
74 |
flag = False
|