yichenl5 commited on
Commit
6051f87
2 Parent(s): b2ca465 ce7a58b

Merge pull request #13 from project-kxkg/MergeFix

Browse files

Merge fix

Former-commit-id: 9554b4cceaef7d5a173731b003afddcd23eae48b

Files changed (2) hide show
  1. SRT.py +54 -7
  2. pipeline.py +19 -7
SRT.py CHANGED
@@ -2,6 +2,7 @@ from datetime import timedelta
2
  from csv import reader
3
  from datetime import datetime
4
  import re
 
5
 
6
  class SRT_segment(object):
7
  def __init__(self, *args) -> None:
@@ -105,18 +106,64 @@ class SRT_script():
105
  segments.append(self.merge_segs(idx_list))
106
 
107
  self.segments = segments # need memory release?
 
 
108
 
109
  def set_translation(self, translate:str, id_range:tuple):
110
  start_seg_id = id_range[0]
111
  end_seg_id = id_range[1]
112
-
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  lines = translate.split('\n\n')
114
- if len(lines) != (end_seg_id - start_seg_id + 1):
115
- print(id_range)
116
- for i, seg in enumerate(self.segments[start_seg_id-1:end_seg_id]):
117
- print(seg.source_text)
118
- print(translate)
119
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  for i, seg in enumerate(self.segments[start_seg_id-1:end_seg_id]):
121
  # naive way to due with merge translation problem
122
  # TODO: need a smarter solution
 
2
  from csv import reader
3
  from datetime import datetime
4
  import re
5
+ import openai
6
 
7
  class SRT_segment(object):
8
  def __init__(self, *args) -> None:
 
106
  segments.append(self.merge_segs(idx_list))
107
 
108
  self.segments = segments # need memory release?
109
+
110
+
111
 
112
  def set_translation(self, translate:str, id_range:tuple):
113
  start_seg_id = id_range[0]
114
  end_seg_id = id_range[1]
115
+
116
+ def inner_func(input_str):
117
+ response = openai.ChatCompletion.create(
118
+ model="gpt-3.5-turbo",
119
+ messages = [
120
+ {"role": "system", "content": "You are a helpful assistant that help calibrates English to Chinese subtitle translations in starcraft2."},
121
+ {"role": "system", "content": "You are provided with a translated Chinese transcript; you must modify or split the Chinese sentence to match the meaning and the number of the English transcript exactly one by one. You must not merge ANY Chinese lines, you can only split them but the total Chinese lines MUST equals to number of English lines."},
122
+ {"role": "system", "content": "There is no need for you to add any comments or notes, and do not modify the English transcript."},
123
+ {"role": "user", "content": 'You are given the English transcript and line number, your task is to merge or split the Chinese to match the exact number of lines in English transcript, no more no less. For example, if there are more Chinese lines than English lines, merge some the Chinese lines to match the number of English lines. If Chinese lines is less than English lines, split some Chinese lines to match the english lines: "{}"'.format(input_str)}
124
+ ],
125
+ temperature=0.7
126
+ )
127
+ return response['choices'][0]['message']['content'].strip()
128
+
129
  lines = translate.split('\n\n')
130
+ if len(lines) < (end_seg_id - start_seg_id + 1):
131
+ count = 0
132
+ while count<5 and len(lines) != (end_seg_id - start_seg_id + 1):
133
+
134
+ count += 1
135
+ print("Solving Unmatched Lines|iteration {}".format(count))
136
+ input_str = "\n"
137
+ #initialize GPT input
138
+ for i, seg in enumerate(self.segments[start_seg_id-1:end_seg_id]):
139
+ input_str += 'Sentence %d: ' %(i+1)+ seg.source_text + '\n'
140
+ #Append to prompt string
141
+ #Adds sentence index let GPT keep track of sentence breaks
142
+ input_str += translate
143
+ #append translate to prompt
144
+
145
+ flag = True
146
+ while flag:
147
+ flag = False
148
+ try:
149
+ translate = inner_func(input_str)
150
+ except Exception as e:
151
+ print("An error has occurred during solving unmatched lines:",e)
152
+ print("Retrying...")
153
+ flag = True
154
+
155
+ lines = translate.split('\n\n')
156
+ if len(lines) < (end_seg_id - start_seg_id + 1):
157
+ print("Failed Solving unmatched lines, Manually parse needed")
158
+
159
+ print(lines)
160
+ #print(id_range)
161
+ #for i, seg in enumerate(self.segments[start_seg_id-1:end_seg_id]):
162
+ # print(seg.source_text)
163
+ #print(translate)
164
+
165
+
166
+
167
  for i, seg in enumerate(self.segments[start_seg_id-1:end_seg_id]):
168
  # naive way to due with merge translation problem
169
  # TODO: need a smarter solution
pipeline.py CHANGED
@@ -163,10 +163,7 @@ def script_split(script_in, chunk_size = 1000):
163
 
164
  script_arr, range_arr = script_split(script_input)
165
 
166
- # Translate and save
167
- for s, range in tqdm(zip(script_arr, range_arr)):
168
- # using chatgpt model
169
- print(f"now translating sentences {range}")
170
  if model_name == "gpt-3.5-turbo":
171
  # print(s + "\n")
172
  response = openai.ChatCompletion.create(
@@ -180,7 +177,7 @@ for s, range in tqdm(zip(script_arr, range_arr)):
180
  temperature=0.15
181
  )
182
 
183
- translate = response['choices'][0]['message']['content'].strip()
184
 
185
  if model_name == "text-davinci-003":
186
  prompt = f"Please help me translate this into Chinese:\n\n{s}\n\n"
@@ -194,8 +191,23 @@ for s, range in tqdm(zip(script_arr, range_arr)):
194
  frequency_penalty=0.0,
195
  presence_penalty=0.0
196
  )
197
- translate = response['choices'][0]['text'].strip()
198
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  srt.set_translation(translate, range)
200
 
201
  srt.check_len_and_split()
 
163
 
164
  script_arr, range_arr = script_split(script_input)
165
 
166
+ def get_response(model_name):
 
 
 
167
  if model_name == "gpt-3.5-turbo":
168
  # print(s + "\n")
169
  response = openai.ChatCompletion.create(
 
177
  temperature=0.15
178
  )
179
 
180
+ return response['choices'][0]['message']['content'].strip()
181
 
182
  if model_name == "text-davinci-003":
183
  prompt = f"Please help me translate this into Chinese:\n\n{s}\n\n"
 
191
  frequency_penalty=0.0,
192
  presence_penalty=0.0
193
  )
194
+ return response['choices'][0]['text'].strip()
195
+ pass
196
+
197
+
198
+ # Translate and save
199
+ for s, range in tqdm(zip(script_arr, range_arr)):
200
+ # using chatgpt model
201
+ print(f"now translating sentences {range}")
202
+ flag = True
203
+ while flag:
204
+ flag = False
205
+ try:
206
+ translate = get_response(model_name)
207
+ except Exception as e:
208
+ print("An error has occurred during translation:",e)
209
+ print("Retrying...")
210
+ flag = True
211
  srt.set_translation(translate, range)
212
 
213
  srt.check_len_and_split()