Spaces:
Sleeping
Sleeping
Merge pull request #13 from project-kxkg/MergeFix
Browse filesMerge fix
Former-commit-id: 9554b4cceaef7d5a173731b003afddcd23eae48b
- SRT.py +54 -7
- pipeline.py +19 -7
SRT.py
CHANGED
@@ -2,6 +2,7 @@ from datetime import timedelta
|
|
2 |
from csv import reader
|
3 |
from datetime import datetime
|
4 |
import re
|
|
|
5 |
|
6 |
class SRT_segment(object):
|
7 |
def __init__(self, *args) -> None:
|
@@ -105,18 +106,64 @@ class SRT_script():
|
|
105 |
segments.append(self.merge_segs(idx_list))
|
106 |
|
107 |
self.segments = segments # need memory release?
|
|
|
|
|
108 |
|
109 |
def set_translation(self, translate:str, id_range:tuple):
|
110 |
start_seg_id = id_range[0]
|
111 |
end_seg_id = id_range[1]
|
112 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
113 |
lines = translate.split('\n\n')
|
114 |
-
if len(lines)
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
120 |
for i, seg in enumerate(self.segments[start_seg_id-1:end_seg_id]):
|
121 |
# naive way to due with merge translation problem
|
122 |
# TODO: need a smarter solution
|
|
|
2 |
from csv import reader
|
3 |
from datetime import datetime
|
4 |
import re
|
5 |
+
import openai
|
6 |
|
7 |
class SRT_segment(object):
|
8 |
def __init__(self, *args) -> None:
|
|
|
106 |
segments.append(self.merge_segs(idx_list))
|
107 |
|
108 |
self.segments = segments # need memory release?
|
109 |
+
|
110 |
+
|
111 |
|
112 |
def set_translation(self, translate:str, id_range:tuple):
|
113 |
start_seg_id = id_range[0]
|
114 |
end_seg_id = id_range[1]
|
115 |
+
|
116 |
+
def inner_func(input_str):
|
117 |
+
response = openai.ChatCompletion.create(
|
118 |
+
model="gpt-3.5-turbo",
|
119 |
+
messages = [
|
120 |
+
{"role": "system", "content": "You are a helpful assistant that help calibrates English to Chinese subtitle translations in starcraft2."},
|
121 |
+
{"role": "system", "content": "You are provided with a translated Chinese transcript; you must modify or split the Chinese sentence to match the meaning and the number of the English transcript exactly one by one. You must not merge ANY Chinese lines, you can only split them but the total Chinese lines MUST equals to number of English lines."},
|
122 |
+
{"role": "system", "content": "There is no need for you to add any comments or notes, and do not modify the English transcript."},
|
123 |
+
{"role": "user", "content": 'You are given the English transcript and line number, your task is to merge or split the Chinese to match the exact number of lines in English transcript, no more no less. For example, if there are more Chinese lines than English lines, merge some the Chinese lines to match the number of English lines. If Chinese lines is less than English lines, split some Chinese lines to match the english lines: "{}"'.format(input_str)}
|
124 |
+
],
|
125 |
+
temperature=0.7
|
126 |
+
)
|
127 |
+
return response['choices'][0]['message']['content'].strip()
|
128 |
+
|
129 |
lines = translate.split('\n\n')
|
130 |
+
if len(lines) < (end_seg_id - start_seg_id + 1):
|
131 |
+
count = 0
|
132 |
+
while count<5 and len(lines) != (end_seg_id - start_seg_id + 1):
|
133 |
+
|
134 |
+
count += 1
|
135 |
+
print("Solving Unmatched Lines|iteration {}".format(count))
|
136 |
+
input_str = "\n"
|
137 |
+
#initialize GPT input
|
138 |
+
for i, seg in enumerate(self.segments[start_seg_id-1:end_seg_id]):
|
139 |
+
input_str += 'Sentence %d: ' %(i+1)+ seg.source_text + '\n'
|
140 |
+
#Append to prompt string
|
141 |
+
#Adds sentence index let GPT keep track of sentence breaks
|
142 |
+
input_str += translate
|
143 |
+
#append translate to prompt
|
144 |
+
|
145 |
+
flag = True
|
146 |
+
while flag:
|
147 |
+
flag = False
|
148 |
+
try:
|
149 |
+
translate = inner_func(input_str)
|
150 |
+
except Exception as e:
|
151 |
+
print("An error has occurred during solving unmatched lines:",e)
|
152 |
+
print("Retrying...")
|
153 |
+
flag = True
|
154 |
+
|
155 |
+
lines = translate.split('\n\n')
|
156 |
+
if len(lines) < (end_seg_id - start_seg_id + 1):
|
157 |
+
print("Failed Solving unmatched lines, Manually parse needed")
|
158 |
+
|
159 |
+
print(lines)
|
160 |
+
#print(id_range)
|
161 |
+
#for i, seg in enumerate(self.segments[start_seg_id-1:end_seg_id]):
|
162 |
+
# print(seg.source_text)
|
163 |
+
#print(translate)
|
164 |
+
|
165 |
+
|
166 |
+
|
167 |
for i, seg in enumerate(self.segments[start_seg_id-1:end_seg_id]):
|
168 |
# naive way to due with merge translation problem
|
169 |
# TODO: need a smarter solution
|
pipeline.py
CHANGED
@@ -163,10 +163,7 @@ def script_split(script_in, chunk_size = 1000):
|
|
163 |
|
164 |
script_arr, range_arr = script_split(script_input)
|
165 |
|
166 |
-
|
167 |
-
for s, range in tqdm(zip(script_arr, range_arr)):
|
168 |
-
# using chatgpt model
|
169 |
-
print(f"now translating sentences {range}")
|
170 |
if model_name == "gpt-3.5-turbo":
|
171 |
# print(s + "\n")
|
172 |
response = openai.ChatCompletion.create(
|
@@ -180,7 +177,7 @@ for s, range in tqdm(zip(script_arr, range_arr)):
|
|
180 |
temperature=0.15
|
181 |
)
|
182 |
|
183 |
-
|
184 |
|
185 |
if model_name == "text-davinci-003":
|
186 |
prompt = f"Please help me translate this into Chinese:\n\n{s}\n\n"
|
@@ -194,8 +191,23 @@ for s, range in tqdm(zip(script_arr, range_arr)):
|
|
194 |
frequency_penalty=0.0,
|
195 |
presence_penalty=0.0
|
196 |
)
|
197 |
-
|
198 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
199 |
srt.set_translation(translate, range)
|
200 |
|
201 |
srt.check_len_and_split()
|
|
|
163 |
|
164 |
script_arr, range_arr = script_split(script_input)
|
165 |
|
166 |
+
def get_response(model_name):
|
|
|
|
|
|
|
167 |
if model_name == "gpt-3.5-turbo":
|
168 |
# print(s + "\n")
|
169 |
response = openai.ChatCompletion.create(
|
|
|
177 |
temperature=0.15
|
178 |
)
|
179 |
|
180 |
+
return response['choices'][0]['message']['content'].strip()
|
181 |
|
182 |
if model_name == "text-davinci-003":
|
183 |
prompt = f"Please help me translate this into Chinese:\n\n{s}\n\n"
|
|
|
191 |
frequency_penalty=0.0,
|
192 |
presence_penalty=0.0
|
193 |
)
|
194 |
+
return response['choices'][0]['text'].strip()
|
195 |
+
pass
|
196 |
+
|
197 |
+
|
198 |
+
# Translate and save
|
199 |
+
for s, range in tqdm(zip(script_arr, range_arr)):
|
200 |
+
# using chatgpt model
|
201 |
+
print(f"now translating sentences {range}")
|
202 |
+
flag = True
|
203 |
+
while flag:
|
204 |
+
flag = False
|
205 |
+
try:
|
206 |
+
translate = get_response(model_name)
|
207 |
+
except Exception as e:
|
208 |
+
print("An error has occurred during translation:",e)
|
209 |
+
print("Retrying...")
|
210 |
+
flag = True
|
211 |
srt.set_translation(translate, range)
|
212 |
|
213 |
srt.check_len_and_split()
|