Macrodove committed on
Commit
4f0065c
1 Parent(s): e4c138e

implemented prompt selector

Browse files

Former-commit-id: d821ee6dd6ad1d75bf3d28cb2074a88ce5a11bf0

Files changed (2) hide show
  1. src/task.py +4 -3
  2. src/translators/translation.py +16 -6
src/task.py CHANGED
@@ -11,7 +11,7 @@ import subprocess
11
  from src.srt_util.srt import SrtScript
12
  from src.srt_util.srt2ass import srt2ass
13
  from time import time, strftime, gmtime, sleep
14
- from src.translators.translation import get_translation, translate
15
 
16
  import torch
17
  import stable_whisper
@@ -176,9 +176,10 @@ class Task:
176
  self.progress = TaskStatus.TRANSLATING.value[0], new_progress
177
 
178
  # Module 3: perform srt translation
179
- def translation(self):
180
  logging.info("---------------------Start Translation--------------------")
181
- get_translation(self.SRT_Script, self.model, self.task_id)
 
182
 
183
  # Module 4: perform srt post process steps
184
  def postprocess(self):
 
11
  from src.srt_util.srt import SrtScript
12
  from src.srt_util.srt2ass import srt2ass
13
  from time import time, strftime, gmtime, sleep
14
+ from src.translators.translation import get_translation, prompt_selector
15
 
16
  import torch
17
  import stable_whisper
 
176
  self.progress = TaskStatus.TRANSLATING.value[0], new_progress
177
 
178
  # Module 3: perform srt translation
179
+ def translation(self,task_cfg):
180
  logging.info("---------------------Start Translation--------------------")
181
+ prompt = prompt_selector(self.source_lang,self.target_lang,task_cfg['field'])
182
+ get_translation(self.SRT_Script, self.model, self.task_id, prompt, task_cfg['chunk_size'])
183
 
184
  # Module 4: perform srt post process steps
185
  def postprocess(self):
src/translators/translation.py CHANGED
@@ -5,9 +5,9 @@ from tqdm import tqdm
5
  from src.srt_util.srt import split_script
6
  from .LLM_task import LLM_task
7
 
8
- def get_translation(srt, model, video_name):
9
- script_arr, range_arr = split_script(srt.get_source_only())
10
- translate(srt, script_arr, range_arr, model, video_name)
11
  pass
12
 
13
  def check_translation(sentence, translation):
@@ -26,8 +26,18 @@ def check_translation(sentence, translation):
26
 
27
  # TODO{david}: prompts selector
28
  def prompt_selector(src_lang, tgt_lang, domain):
29
-
30
- return ""
 
 
 
 
 
 
 
 
 
 
31
 
32
  def translate(srt, script_arr, range_arr, model_name, video_name, attempts_count=5, task=None, temp = 0.15):
33
  """
@@ -51,7 +61,7 @@ def translate(srt, script_arr, range_arr, model_name, video_name, attempts_count
51
  raise Exception("Warning! No Input have passed to LLM!")
52
  if task is None:
53
  task = "你是一个翻译助理,你的任务是翻译星际争霸视频,你会被提供一个按行分割的英文段落,你需要在保证句意和行数的情况下输出翻译后的文本。"
54
-
55
  previous_length = 0
56
  for sentence, range_ in tqdm(zip(script_arr, range_arr)):
57
  # update the range based on previous length
 
5
  from src.srt_util.srt import split_script
6
  from .LLM_task import LLM_task
7
 
8
+ def get_translation(srt, model, video_name, task, chunk_size = 1000):
9
+ script_arr, range_arr = split_script(srt.get_source_only(),chunk_size)
10
+ translate(srt, script_arr, range_arr, model, video_name, task)
11
  pass
12
 
13
  def check_translation(sentence, translation):
 
26
 
27
  # TODO{david}: prompts selector
28
  def prompt_selector(src_lang, tgt_lang, domain):
29
+ language_map = {
30
+ "EN": "English",
31
+ "ZH": "Chinese",
32
+ }
33
+ src_lang = language_map[src_lang]
34
+ tgt_lang = language_map[tgt_lang]
35
+ prompt = f"""
36
+ you are a translation assistant, your job is to translate a video in domain of {domain} from {src_lang} to {tgt_lang},
37
+ you will be provided with a segement in {[src_lang]} parsed by line, where your translation text should keep the original
38
+ meaning and the number of lines.
39
+ """
40
+ return prompt
41
 
42
  def translate(srt, script_arr, range_arr, model_name, video_name, attempts_count=5, task=None, temp = 0.15):
43
  """
 
61
  raise Exception("Warning! No Input have passed to LLM!")
62
  if task is None:
63
  task = "你是一个翻译助理,你的任务是翻译星际争霸视频,你会被提供一个按行分割的英文段落,你需要在保证句意和行数的情况下输出翻译后的文本。"
64
+ print(task)
65
  previous_length = 0
66
  for sentence, range_ in tqdm(zip(script_arr, range_arr)):
67
  # update the range based on previous length