yichenl5 committed
Commit 491821e
Parents: bf542d7, c9578de

Merge pull request #33 from project-kxkg/SRT_cleanup

Srt cleanup

Former-commit-id: 9a454ba96523d6202b5229ffd620c78a22c3db4b

Files changed (2)
  1. .gitignore +6 -5
  2. SRT.py +186 -143
.gitignore CHANGED
@@ -1,10 +1,11 @@
- /downloads
- /results
+ __pycache__/
  .DS_Store
- /__pycache__
+ .idea/
+ downloads/
+ results/
+ test/
  test.py
  test.srt
  test.txt
  log_*.csv
- log.csv
- /test
+ log.csv
 
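Note: the rewritten .gitignore patterns drop the leading slash and add a trailing slash (e.g. __pycache__/, downloads/, results/, test/), so these directories are now ignored at any depth in the tree rather than only at the repository root; .idea/ (IDE project metadata) is newly ignored as well.
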
SRT.py CHANGED
@@ -1,10 +1,11 @@
1
- from datetime import timedelta
2
- from csv import reader
3
- from datetime import datetime
4
  import re
5
  import openai
6
- import os
7
- from collections import deque
8
 
9
  class SRT_segment(object):
10
  def __init__(self, *args) -> None:
@@ -12,22 +13,24 @@ class SRT_segment(object):
12
  segment = args[0]
13
  self.start = segment['start']
14
  self.end = segment['end']
15
- self.start_ms = int((segment['start']*100)%100*10)
16
- self.end_ms = int((segment['end']*100)%100*10)
17
 
18
- if self.start_ms == self.end_ms and int(segment['start']) == int(segment['end']): # avoid empty time stamp
19
- self.end_ms+=500
20
 
21
  self.start_time = timedelta(seconds=int(segment['start']), milliseconds=self.start_ms)
22
  self.end_time = timedelta(seconds=int(segment['end']), milliseconds=self.end_ms)
23
  if self.start_ms == 0:
24
- self.start_time_str = str(0)+str(self.start_time).split('.')[0]+',000'
25
  else:
26
- self.start_time_str = str(0)+str(self.start_time).split('.')[0]+','+str(self.start_time).split('.')[1][:3]
 
27
  if self.end_ms == 0:
28
- self.end_time_str = str(0)+str(self.end_time).split('.')[0]+',000'
29
  else:
30
- self.end_time_str = str(0)+str(self.end_time).split('.')[0]+','+str(self.end_time).split('.')[1][:3]
 
31
  self.source_text = segment['text'].lstrip()
32
  self.duration = f"{self.start_time_str} --> {self.end_time_str}"
33
  self.translation = ""
@@ -39,15 +42,21 @@ class SRT_segment(object):
39
  self.end_time_str = self.duration.split(" --> ")[1]
40
 
41
  # parse the time to float
42
- self.start_ms = int(self.start_time_str.split(',')[1])/10
43
- self.end_ms = int(self.end_time_str.split(',')[1])/10
44
  start_list = self.start_time_str.split(',')[0].split(':')
45
- self.start = int(start_list[0])*3600 + int(start_list[1])*60 + int(start_list[2]) + self.start_ms/100
46
  end_list = self.end_time_str.split(',')[0].split(':')
47
- self.end = int(end_list[0])*3600 + int(end_list[1])*60 + int(end_list[2]) + self.end_ms/100
48
  self.translation = ""
49
-
50
  def merge_seg(self, seg):
51
  self.source_text += f' {seg.source_text}'
52
  self.translation += f' {seg.translation}'
53
  self.end_time_str = seg.end_time_str
@@ -56,22 +65,42 @@ class SRT_segment(object):
56
  self.duration = f"{self.start_time_str} --> {self.end_time_str}"
57
  pass
58
 
59
  def remove_trans_punc(self):
60
- # remove punctuations in translation text
61
- self.translation = self.translation.replace(',', ' ')
62
- self.translation = self.translation.replace('。', ' ')
63
- self.translation = self.translation.replace('!', ' ')
64
- self.translation = self.translation.replace('?', ' ')
 
 
65
 
66
  def __str__(self) -> str:
67
- return f'{self.duration}\n{self.source_text}\n\n'
68
-
69
  def get_trans_str(self) -> str:
70
  return f'{self.duration}\n{self.translation}\n\n'
71
-
72
  def get_bilingual_str(self) -> str:
73
  return f'{self.duration}\n{self.source_text}\n{self.translation}\n\n'
74
 
 
75
  class SRT_script():
76
  def __init__(self, segments) -> None:
77
  self.segments = []
@@ -80,29 +109,41 @@ class SRT_script():
80
  self.segments.append(srt_seg)
81
 
82
  @classmethod
83
- def parse_from_srt_file(cls, path:str):
84
  with open(path, 'r', encoding="utf-8") as f:
85
- script_lines = f.read().splitlines()
86
 
87
  segments = []
88
  for i in range(len(script_lines)):
89
  if i % 4 == 0:
90
- segments.append(list(script_lines[i:i+4]))
91
 
92
  return cls(segments)
93
 
94
  def merge_segs(self, idx_list) -> SRT_segment:
95
- final_seg = self.segments[idx_list[0]]
96
  if len(idx_list) == 1:
97
- return final_seg
98
-
99
  for idx in range(1, len(idx_list)):
100
- final_seg.merge_seg(self.segments[idx_list[idx]])
101
-
102
- return final_seg
103
 
104
  def form_whole_sentence(self):
105
- merge_list = [] # a list of indices that should be merged e.g. [[0], [1, 2, 3, 4], [5, 6], [7]]
106
  sentence = []
107
  for i, seg in enumerate(self.segments):
108
  if seg.source_text[-1] in ['.', '!', '?'] and len(seg.source_text) > 10 and 'vs.' not in seg.source_text:
@@ -117,113 +158,117 @@ class SRT_script():
117
  segments.append(self.merge_segs(idx_list))
118
 
119
  self.segments = segments
120
-
121
  def remove_trans_punctuation(self):
122
- # Post-process: remove all punc after translation and split
123
  for i, seg in enumerate(self.segments):
124
  seg.remove_trans_punc()
125
 
126
- def set_translation(self, translate:str, id_range:tuple, model, video_name, video_link=None):
127
  start_seg_id = id_range[0]
128
  end_seg_id = id_range[1]
129
-
130
  src_text = ""
131
- for i, seg in enumerate(self.segments[start_seg_id-1:end_seg_id]):
132
- src_text+=seg.source_text
133
- src_text+='\n\n'
134
 
135
- def inner_func(target,input_str):
136
  response = openai.ChatCompletion.create(
137
- #model=model,
138
- model = "gpt-3.5-turbo",
139
- messages = [
140
- #{"role": "system", "content": "You are a helpful assistant that help calibrates English to Chinese subtitle translations in starcraft2."},
141
- #{"role": "system", "content": "You are provided with a translated Chinese transcript; you must modify or split the Chinese sentence to match the meaning and the number of the English transcript exactly one by one. You must not merge ANY Chinese lines, you can only split them but the total Chinese lines MUST equals to number of English lines."},
142
- #{"role": "system", "content": "There is no need for you to add any comments or notes, and do not modify the English transcript."},
143
- #{"role": "user", "content": 'You are given the English transcript and line number, your task is to merge or split the Chinese to match the exact number of lines in English transcript, no more no less. For example, if there are more Chinese lines than English lines, merge some the Chinese lines to match the number of English lines. If Chinese lines is less than English lines, split some Chinese lines to match the english lines: "{}"'.format(input_str)}
144
-
145
- {"role": "system", "content": "你的任务是按照要求合并或拆分句子到指定行数,你需要尽可能保证句意,但必要时可以将一句话分为两行输出"},
146
- {"role": "system", "content": "注意:你只需要输出处理过的中文句子,如果你要输出序号,请使用冒号隔开"},
147
- {"role": "user", "content": '请将下面的句子拆分或组合为{}句:\n{}'.format(target,input_str)}
148
- # {"role": "system", "content": "请将以下中文与其英文句子一一对应并输出:"},
149
- # {"role": "system", "content": "英文:{}".format(src_text)},
150
- # {"role": "user", "content": "中文:{}\n\n".format(input_str)},
151
- ],
152
- temperature = 0.15
153
- )
 
154
  # print(src_text)
155
  # print(input_str)
156
  # print(response['choices'][0]['message']['content'].strip())
157
  # exit()
158
  return response['choices'][0]['message']['content'].strip()
159
-
160
-
161
  lines = translate.split('\n\n')
162
  if len(lines) < (end_seg_id - start_seg_id + 1):
163
  count = 0
164
  solved = True
165
- while count<5 and len(lines) != (end_seg_id - start_seg_id + 1):
166
  count += 1
167
  print("Solving Unmatched Lines|iteration {}".format(count))
168
- #input_str = "\n"
169
- #initialize GPT input
170
- #for i, seg in enumerate(self.segments[start_seg_id-1:end_seg_id]):
171
  # input_str += 'Sentence %d: ' %(i+1)+ seg.source_text + '\n'
172
  # #Append to prompt string
173
  # #Adds sentence index let GPT keep track of sentence breaks
174
- #input_str += translate
175
- #append translate to prompt
176
  flag = True
177
  while flag:
178
  flag = False
179
- #print("translate:")
180
- #print(translate)
181
  try:
182
- #print("target")
183
- #print(end_seg_id - start_seg_id + 1)
184
- translate = inner_func(end_seg_id - start_seg_id + 1,translate)
185
  except Exception as e:
186
- print("An error has occurred during solving unmatched lines:",e)
187
  print("Retrying...")
188
  flag = True
189
  lines = translate.split('\n')
190
- #print("result")
191
- #print(len(lines))
192
-
193
  if len(lines) < (end_seg_id - start_seg_id + 1):
194
  solved = False
195
  print("Failed Solving unmatched lines, Manually parse needed")
196
-
197
  if not os.path.exists("./logs"):
198
  os.mkdir("./logs")
199
  if video_link:
200
  log_file = "./logs/log_link.csv"
201
  log_exist = os.path.exists(log_file)
202
- with open(log_file,"a") as log:
203
  if not log_exist:
204
  log.write("range_of_text,iterations_solving,solved,file_length,video_link" + "\n")
205
- log.write(str(id_range)+','+str(count)+','+str(solved)+','+str(len(self.segments))+','+video_link + "\n")
 
206
  else:
207
  log_file = "./logs/log_name.csv"
208
  log_exist = os.path.exists(log_file)
209
- with open(log_file,"a") as log:
210
  if not log_exist:
211
  log.write("range_of_text,iterations_solving,solved,file_length,video_name" + "\n")
212
- log.write(str(id_range)+','+str(count)+','+str(solved)+','+str(len(self.segments))+','+video_name + "\n")
213
-
 
214
  print(lines)
215
- #print(id_range)
216
- #for i, seg in enumerate(self.segments[start_seg_id-1:end_seg_id]):
217
  # print(seg.source_text)
218
- #print(translate)
219
-
220
-
221
- for i, seg in enumerate(self.segments[start_seg_id-1:end_seg_id]):
222
  # naive way to due with merge translation problem
223
  # TODO: need a smarter solution
224
 
225
  if i < len(lines):
226
- if "Note:" in lines[i]: # to avoid note
227
  lines.remove(lines[i])
228
  max_num -= 1
229
  if i == len(lines) - 1:
@@ -233,7 +278,6 @@ class SRT_script():
233
  except:
234
  seg.translation = lines[i]
235
 
236
-
237
  def split_seg(self, seg, text_threshold, time_threshold):
238
  # evenly split seg to 2 parts and add new seg into self.segments
239
 
@@ -251,21 +295,24 @@ class SRT_script():
251
  src_commas = [m.start() for m in re.finditer(',', source_text)]
252
  trans_commas = [m.start() for m in re.finditer(',', translation)]
253
  if len(src_commas) != 0:
254
- src_split_idx = src_commas[len(src_commas)//2] if len(src_commas) % 2 == 1 else src_commas[len(src_commas)//2 - 1]
 
255
  else:
256
  src_space = [m.start() for m in re.finditer(' ', source_text)]
257
- if len(src_space) > 0:
258
- src_split_idx = src_space[len(src_space)//2] if len(src_space) % 2 == 1 else src_space[len(src_space)//2 - 1]
 
259
  else:
260
  src_split_idx = 0
261
 
262
  if len(trans_commas) != 0:
263
- trans_split_idx = trans_commas[len(trans_commas)//2] if len(trans_commas) % 2 == 1 else trans_commas[len(trans_commas)//2 - 1]
 
264
  else:
265
- trans_split_idx = len(translation)//2
266
-
267
  # split the time duration based on text length
268
- time_split_ratio = trans_split_idx/(len(seg.translation) - 1)
269
 
270
  src_seg1 = source_text[:src_split_idx]
271
  src_seg2 = source_text[src_split_idx:]
@@ -273,7 +320,7 @@ class SRT_script():
273
  trans_seg2 = translation[trans_split_idx:]
274
 
275
  start_seg1 = seg.start
276
- end_seg1 = start_seg2 = seg.start + (seg.end - seg.start)*time_split_ratio
277
  end_seg2 = seg.end
278
 
279
  seg1_dict = {}
@@ -295,7 +342,7 @@ class SRT_script():
295
  result_list += self.split_seg(seg1, text_threshold, time_threshold)
296
  else:
297
  result_list.append(seg1)
298
-
299
  if len(seg2.translation) > text_threshold and (seg2.end - seg2.start) > time_threshold:
300
  result_list += self.split_seg(seg2, text_threshold, time_threshold)
301
  else:
@@ -303,7 +350,6 @@ class SRT_script():
303
 
304
  return result_list
305
 
306
-
307
  def check_len_and_split(self, text_threshold=30, time_threshold=1.0):
308
  # DEPRECATED
309
  # if sentence length >= threshold and sentence duration > time_threshold, split this segments to two
@@ -314,7 +360,7 @@ class SRT_script():
314
  segments += seg_list
315
  else:
316
  segments.append(seg)
317
-
318
  self.segments = segments
319
 
320
  pass
@@ -325,23 +371,23 @@ class SRT_script():
325
  end_seg_id = range[1]
326
  extra_len = 0
327
  segments = []
328
- for i, seg in enumerate(self.segments[start_seg_id-1:end_seg_id]):
329
  if len(seg.translation) > text_threshold and (seg.end - seg.start) > time_threshold:
330
  seg_list = self.split_seg(seg, text_threshold, time_threshold)
331
  segments += seg_list
332
- extra_len += len(seg_list) - 1
333
  else:
334
  segments.append(seg)
335
-
336
- self.segments[start_seg_id-1:end_seg_id] = segments
337
  return extra_len
338
 
339
  def correct_with_force_term(self):
340
  ## force term correction
341
 
342
  # load term dictionary
343
- with open("./finetune_data/dict_enzh.csv",'r', encoding='utf-8') as f:
344
- term_enzh_dict = {rows[0]:rows[1] for rows in reader(f)}
345
 
346
  # change term
347
  for seg in self.segments:
@@ -359,7 +405,7 @@ class SRT_script():
359
 
360
  def spell_check_term(self):
361
  ## known bug: I've will be replaced because i've is not in the dict
362
-
363
  import enchant
364
  dict = enchant.Dict('en_US')
365
  term_spellDict = enchant.PyPWL('./finetune_data/dict_freq.txt')
@@ -371,6 +417,7 @@ class SRT_script():
371
  [real_word, pos] = self.get_real_word(word)
372
  if not dict.check(word[:pos]):
373
  suggest = term_spellDict.suggest(real_word)
 
374
  if suggest and enchant.utils.levenshtein(word, suggest[0]) < (len(word)+len(suggest[0]))/4: # relax spell check
375
 
376
  #with open("dislog.log","a") as log:
@@ -389,14 +436,13 @@ class SRT_script():
389
  seg.source_text = " ".join(ready_words)
390
  pass
391
 
392
- def spell_correction(self, word:str, arg:int):
393
  try:
394
- arg in [0,1]
395
  except ValueError:
396
  print('only 0 or 1 for argument')
397
 
398
-
399
- def uncover(word:str):
400
  if word[-2:] == ".\n":
401
  real_word = word[:-2].lower()
402
  n = -2
@@ -406,14 +452,14 @@ class SRT_script():
406
  else:
407
  real_word = word.lower()
408
  n = 0
409
- return real_word, len(word)+n
410
-
411
  real_word = uncover(word)[0]
412
  pos = uncover(word)[1]
413
  new_word = word
414
  if arg == 0: # term translate mode
415
- with open("finetune_data/dict_enzh.csv",'r', encoding='utf-8') as f:
416
- term_enzh_dict = {rows[0]:rows[1] for rows in reader(f)}
417
  if real_word in term_enzh_dict:
418
  new_word = word.replace(word[:pos], term_enzh_dict.get(real_word))
419
  elif arg == 1: # term spell check mode
@@ -422,10 +468,10 @@ class SRT_script():
422
  term_spellDict = enchant.PyPWL('./finetune_data/dict_freq.txt')
423
  if not dict.check(real_word):
424
  if term_spellDict.suggest(real_word): # relax spell check
425
- new_word = word.replace(word[:pos],term_spellDict.suggest(real_word)[0])
426
  return new_word
427
-
428
- def get_real_word(self, word:str):
429
  if word[-2:] == ".\n":
430
  real_word = word[:-2].lower()
431
  n = -2
@@ -435,8 +481,7 @@ class SRT_script():
435
  else:
436
  real_word = word.lower()
437
  n = 0
438
- return real_word, len(word)+n
439
-
440
 
441
  ## WRITE AND READ FUNCTIONS ##
442
 
@@ -444,48 +489,48 @@ class SRT_script():
444
  # return a string with pure source text
445
  result = ""
446
  for i, seg in enumerate(self.segments):
447
- result+=f'SENTENCE {i+1}: {seg.source_text}\n\n\n'
448
-
449
  return result
450
-
451
  def reform_src_str(self):
452
  result = ""
453
  for i, seg in enumerate(self.segments):
454
- result += f'{i+1}\n'
455
  result += str(seg)
456
  return result
457
 
458
  def reform_trans_str(self):
459
  result = ""
460
  for i, seg in enumerate(self.segments):
461
- result += f'{i+1}\n'
462
  result += seg.get_trans_str()
463
  return result
464
-
465
  def form_bilingual_str(self):
466
  result = ""
467
  for i, seg in enumerate(self.segments):
468
- result += f'{i+1}\n'
469
  result += seg.get_bilingual_str()
470
  return result
471
 
472
- def write_srt_file_src(self, path:str):
473
  # write srt file to path
474
  with open(path, "w", encoding='utf-8') as f:
475
  f.write(self.reform_src_str())
476
  pass
477
 
478
- def write_srt_file_translate(self, path:str):
479
  with open(path, "w", encoding='utf-8') as f:
480
  f.write(self.reform_trans_str())
481
  pass
482
 
483
- def write_srt_file_bilingual(self, path:str):
484
  with open(path, "w", encoding='utf-8') as f:
485
  f.write(self.form_bilingual_str())
486
  pass
487
 
488
- def realtime_write_srt(self,path,range,length, idx):
489
  # DEPRECATED
490
  start_seg_id = range[0]
491
  end_seg_id = range[1]
@@ -494,22 +539,20 @@ class SRT_script():
494
  # f.write(f'{i+idx}\n')
495
  # f.write(seg.get_trans_str())
496
  for i, seg in enumerate(self.segments):
497
- if i<range[0]-1 : continue
498
- if i>=range[1] + length :break
499
- f.write(f'{i+idx}\n')
500
  f.write(seg.get_trans_str())
501
  pass
502
 
503
- def realtime_bilingual_write_srt(self,path,range, length,idx):
504
  # DEPRECATED
505
  start_seg_id = range[0]
506
  end_seg_id = range[1]
507
  with open(path, "a", encoding='utf-8') as f:
508
  for i, seg in enumerate(self.segments):
509
- if i<range[0]-1 : continue
510
- if i>=range[1] + length :break
511
- f.write(f'{i+idx}\n')
512
  f.write(seg.get_bilingual_str())
513
  pass
514
-
515
-
 
1
+ import os
 
 
2
  import re
3
+ from copy import copy, deepcopy
4
+ from csv import reader
5
+ from datetime import timedelta
6
+
7
  import openai
8
+
 
9
 
10
  class SRT_segment(object):
11
  def __init__(self, *args) -> None:
 
13
  segment = args[0]
14
  self.start = segment['start']
15
  self.end = segment['end']
16
+ self.start_ms = int((segment['start'] * 100) % 100 * 10)
17
+ self.end_ms = int((segment['end'] * 100) % 100 * 10)
18
 
19
+ if self.start_ms == self.end_ms and int(segment['start']) == int(segment['end']): # avoid empty time stamp
20
+ self.end_ms += 500
21
 
22
  self.start_time = timedelta(seconds=int(segment['start']), milliseconds=self.start_ms)
23
  self.end_time = timedelta(seconds=int(segment['end']), milliseconds=self.end_ms)
24
  if self.start_ms == 0:
25
+ self.start_time_str = str(0) + str(self.start_time).split('.')[0] + ',000'
26
  else:
27
+ self.start_time_str = str(0) + str(self.start_time).split('.')[0] + ',' + \
28
+ str(self.start_time).split('.')[1][:3]
29
  if self.end_ms == 0:
30
+ self.end_time_str = str(0) + str(self.end_time).split('.')[0] + ',000'
31
  else:
32
+ self.end_time_str = str(0) + str(self.end_time).split('.')[0] + ',' + str(self.end_time).split('.')[1][
33
+ :3]
34
  self.source_text = segment['text'].lstrip()
35
  self.duration = f"{self.start_time_str} --> {self.end_time_str}"
36
  self.translation = ""
 
42
  self.end_time_str = self.duration.split(" --> ")[1]
43
 
44
  # parse the time to float
45
+ self.start_ms = int(self.start_time_str.split(',')[1]) / 10
46
+ self.end_ms = int(self.end_time_str.split(',')[1]) / 10
47
  start_list = self.start_time_str.split(',')[0].split(':')
48
+ self.start = int(start_list[0]) * 3600 + int(start_list[1]) * 60 + int(start_list[2]) + self.start_ms / 100
49
  end_list = self.end_time_str.split(',')[0].split(':')
50
+ self.end = int(end_list[0]) * 3600 + int(end_list[1]) * 60 + int(end_list[2]) + self.end_ms / 100
51
  self.translation = ""
52
+
53
  def merge_seg(self, seg):
54
+ """
55
+ Merge the segment seg with the current segment in place.
56
+ :param seg: Another segment that is strictly next to current one.
57
+ :return: None
58
+ """
59
+ # assert seg.start_ms == self.end_ms, f"cannot merge discontinuous segments."
60
  self.source_text += f' {seg.source_text}'
61
  self.translation += f' {seg.translation}'
62
  self.end_time_str = seg.end_time_str
 
65
  self.duration = f"{self.start_time_str} --> {self.end_time_str}"
66
  pass
67
 
68
+ def __add__(self, other):
69
+ """
70
+ Merge the segment seg with the current segment, and return the new constructed segment.
71
+ No in-place modification.
72
+ :param other: Another segment that is strictly next to added segment.
73
+ :return: new segment of the two sub-segments
74
+ """
75
+ # assert other.start_ms == self.end_ms, f"cannot merge discontinuous segments."
76
+ result = deepcopy(self)
77
+ result.source_text += f' {other.source_text}'
78
+ result.translation += f' {other.translation}'
79
+ result.end_time_str = other.end_time_str
80
+ result.end = other.end
81
+ result.end_ms = other.end_ms
82
+ result.duration = f"{self.start_time_str} --> {self.end_time_str}"
83
+ return result
84
+
85
  def remove_trans_punc(self):
86
+ """
87
+ remove punctuations in translation text
88
+ :return: None
89
+ """
90
+ punc_cn = ",。!?"
91
+ translator = str.maketrans(punc_cn, ' ' * len(punc_cn))
92
+ self.translation = self.translation.translate(translator)
93
 
94
  def __str__(self) -> str:
95
+ return f'{self.duration}\n{self.source_text}\n\n'
96
+
97
  def get_trans_str(self) -> str:
98
  return f'{self.duration}\n{self.translation}\n\n'
99
+
100
  def get_bilingual_str(self) -> str:
101
  return f'{self.duration}\n{self.source_text}\n{self.translation}\n\n'
102
 
103
+
104
  class SRT_script():
105
  def __init__(self, segments) -> None:
106
  self.segments = []
 
109
  self.segments.append(srt_seg)
110
 
111
  @classmethod
112
+ def parse_from_srt_file(cls, path: str):
113
  with open(path, 'r', encoding="utf-8") as f:
114
+ script_lines = [line.rstrip() for line in f.readlines()]
115
 
116
  segments = []
117
  for i in range(len(script_lines)):
118
  if i % 4 == 0:
119
+ segments.append(list(script_lines[i:i + 4]))
120
 
121
  return cls(segments)
122
 
123
  def merge_segs(self, idx_list) -> SRT_segment:
124
+ """
125
+ Merge entire segment list to a single segment
126
+ :param idx_list: List of index to merge
127
+ :return: Merged list
128
+ """
129
+ if not idx_list:
130
+ raise NotImplementedError('Empty idx_list')
131
+ seg_result = deepcopy(self.segments[idx_list[0]])
132
  if len(idx_list) == 1:
133
+ return seg_result
134
+
135
  for idx in range(1, len(idx_list)):
136
+ seg_result += self.segments[idx_list[idx]]
137
+
138
+ return seg_result
139
 
140
  def form_whole_sentence(self):
141
+ """
142
+ Concatenate or Strip sentences and reconstruct segments list. This is because of
143
+ improper segmentation from openai-whisper.
144
+ :return: None
145
+ """
146
+ merge_list = [] # a list of indices that should be merged e.g. [[0], [1, 2, 3, 4], [5, 6], [7]]
147
  sentence = []
148
  for i, seg in enumerate(self.segments):
149
  if seg.source_text[-1] in ['.', '!', '?'] and len(seg.source_text) > 10 and 'vs.' not in seg.source_text:
 
158
  segments.append(self.merge_segs(idx_list))
159
 
160
  self.segments = segments
161
+
162
  def remove_trans_punctuation(self):
163
+ """
164
+ Post-process: remove all punc after translation and split
165
+ :return: None
166
+ """
167
  for i, seg in enumerate(self.segments):
168
  seg.remove_trans_punc()
169
 
170
+ def set_translation(self, translate: str, id_range: tuple, model, video_name, video_link=None):
171
  start_seg_id = id_range[0]
172
  end_seg_id = id_range[1]
173
+
174
  src_text = ""
175
+ for i, seg in enumerate(self.segments[start_seg_id - 1:end_seg_id]):
176
+ src_text += seg.source_text
177
+ src_text += '\n\n'
178
 
179
+ def inner_func(target, input_str):
180
  response = openai.ChatCompletion.create(
181
+ # model=model,
182
+ model="gpt-3.5-turbo",
183
+ messages=[
184
+ # {"role": "system", "content": "You are a helpful assistant that help calibrates English to Chinese subtitle translations in starcraft2."},
185
+ # {"role": "system", "content": "You are provided with a translated Chinese transcript; you must modify or split the Chinese sentence to match the meaning and the number of the English transcript exactly one by one. You must not merge ANY Chinese lines, you can only split them but the total Chinese lines MUST equals to number of English lines."},
186
+ # {"role": "system", "content": "There is no need for you to add any comments or notes, and do not modify the English transcript."},
187
+ # {"role": "user", "content": 'You are given the English transcript and line number, your task is to merge or split the Chinese to match the exact number of lines in English transcript, no more no less. For example, if there are more Chinese lines than English lines, merge some the Chinese lines to match the number of English lines. If Chinese lines is less than English lines, split some Chinese lines to match the english lines: "{}"'.format(input_str)}
188
+
189
+ {"role": "system",
190
+ "content": "你的任务是按照要求合并或拆分句子到指定行数,你需要尽可能保证句意,但必要时可以将一句话分为两行输出"},
191
+ {"role": "system", "content": "注意:你只需要输出处理过的中文句子,如果你要输出序号,请使用冒号隔开"},
192
+ {"role": "user", "content": '请将下面的句子拆分或组合为{}句:\n{}'.format(target, input_str)}
193
+ # {"role": "system", "content": "请将以下中文与其英文句子一一对应并输出:"},
194
+ # {"role": "system", "content": "英文:{}".format(src_text)},
195
+ # {"role": "user", "content": "中文:{}\n\n".format(input_str)},
196
+ ],
197
+ temperature=0.15
198
+ )
199
  # print(src_text)
200
  # print(input_str)
201
  # print(response['choices'][0]['message']['content'].strip())
202
  # exit()
203
  return response['choices'][0]['message']['content'].strip()
204
+
 
205
  lines = translate.split('\n\n')
206
  if len(lines) < (end_seg_id - start_seg_id + 1):
207
  count = 0
208
  solved = True
209
+ while count < 5 and len(lines) != (end_seg_id - start_seg_id + 1):
210
  count += 1
211
  print("Solving Unmatched Lines|iteration {}".format(count))
212
+ # input_str = "\n"
213
+ # initialize GPT input
214
+ # for i, seg in enumerate(self.segments[start_seg_id-1:end_seg_id]):
215
  # input_str += 'Sentence %d: ' %(i+1)+ seg.source_text + '\n'
216
  # #Append to prompt string
217
  # #Adds sentence index let GPT keep track of sentence breaks
218
+ # input_str += translate
219
+ # append translate to prompt
220
  flag = True
221
  while flag:
222
  flag = False
223
+ # print("translate:")
224
+ # print(translate)
225
  try:
226
+ # print("target")
227
+ # print(end_seg_id - start_seg_id + 1)
228
+ translate = inner_func(end_seg_id - start_seg_id + 1, translate)
229
  except Exception as e:
230
+ print("An error has occurred during solving unmatched lines:", e)
231
  print("Retrying...")
232
  flag = True
233
  lines = translate.split('\n')
234
+ # print("result")
235
+ # print(len(lines))
236
+
237
  if len(lines) < (end_seg_id - start_seg_id + 1):
238
  solved = False
239
  print("Failed Solving unmatched lines, Manually parse needed")
240
+
241
  if not os.path.exists("./logs"):
242
  os.mkdir("./logs")
243
  if video_link:
244
  log_file = "./logs/log_link.csv"
245
  log_exist = os.path.exists(log_file)
246
+ with open(log_file, "a") as log:
247
  if not log_exist:
248
  log.write("range_of_text,iterations_solving,solved,file_length,video_link" + "\n")
249
+ log.write(str(id_range) + ',' + str(count) + ',' + str(solved) + ',' + str(
250
+ len(self.segments)) + ',' + video_link + "\n")
251
  else:
252
  log_file = "./logs/log_name.csv"
253
  log_exist = os.path.exists(log_file)
254
+ with open(log_file, "a") as log:
255
  if not log_exist:
256
  log.write("range_of_text,iterations_solving,solved,file_length,video_name" + "\n")
257
+ log.write(str(id_range) + ',' + str(count) + ',' + str(solved) + ',' + str(
258
+ len(self.segments)) + ',' + video_name + "\n")
259
+
260
  print(lines)
261
+ # print(id_range)
262
+ # for i, seg in enumerate(self.segments[start_seg_id-1:end_seg_id]):
263
  # print(seg.source_text)
264
+ # print(translate)
265
+
266
+ for i, seg in enumerate(self.segments[start_seg_id - 1:end_seg_id]):
 
267
  # naive way to due with merge translation problem
268
  # TODO: need a smarter solution
269
 
270
  if i < len(lines):
271
+ if "Note:" in lines[i]: # to avoid note
272
  lines.remove(lines[i])
273
  max_num -= 1
274
  if i == len(lines) - 1:
 
278
  except:
279
  seg.translation = lines[i]
280
 
 
281
  def split_seg(self, seg, text_threshold, time_threshold):
282
  # evenly split seg to 2 parts and add new seg into self.segments
283
 
 
295
  src_commas = [m.start() for m in re.finditer(',', source_text)]
296
  trans_commas = [m.start() for m in re.finditer(',', translation)]
297
  if len(src_commas) != 0:
298
+ src_split_idx = src_commas[len(src_commas) // 2] if len(src_commas) % 2 == 1 else src_commas[
299
+ len(src_commas) // 2 - 1]
300
  else:
301
  src_space = [m.start() for m in re.finditer(' ', source_text)]
302
+ if len(src_space) > 0:
303
+ src_split_idx = src_space[len(src_space) // 2] if len(src_space) % 2 == 1 else src_space[
304
+ len(src_space) // 2 - 1]
305
  else:
306
  src_split_idx = 0
307
 
308
  if len(trans_commas) != 0:
309
+ trans_split_idx = trans_commas[len(trans_commas) // 2] if len(trans_commas) % 2 == 1 else trans_commas[
310
+ len(trans_commas) // 2 - 1]
311
  else:
312
+ trans_split_idx = len(translation) // 2
313
+
314
  # split the time duration based on text length
315
+ time_split_ratio = trans_split_idx / (len(seg.translation) - 1)
316
 
317
  src_seg1 = source_text[:src_split_idx]
318
  src_seg2 = source_text[src_split_idx:]
 
320
  trans_seg2 = translation[trans_split_idx:]
321
 
322
  start_seg1 = seg.start
323
+ end_seg1 = start_seg2 = seg.start + (seg.end - seg.start) * time_split_ratio
324
  end_seg2 = seg.end
325
 
326
  seg1_dict = {}
 
342
  result_list += self.split_seg(seg1, text_threshold, time_threshold)
343
  else:
344
  result_list.append(seg1)
345
+
346
  if len(seg2.translation) > text_threshold and (seg2.end - seg2.start) > time_threshold:
347
  result_list += self.split_seg(seg2, text_threshold, time_threshold)
348
  else:
 
350
 
351
  return result_list
352
 
 
353
  def check_len_and_split(self, text_threshold=30, time_threshold=1.0):
354
  # DEPRECATED
355
  # if sentence length >= threshold and sentence duration > time_threshold, split this segments to two
 
360
  segments += seg_list
361
  else:
362
  segments.append(seg)
363
+
364
  self.segments = segments
365
 
366
  pass
 
371
  end_seg_id = range[1]
372
  extra_len = 0
373
  segments = []
374
+ for i, seg in enumerate(self.segments[start_seg_id - 1:end_seg_id]):
375
  if len(seg.translation) > text_threshold and (seg.end - seg.start) > time_threshold:
376
  seg_list = self.split_seg(seg, text_threshold, time_threshold)
377
  segments += seg_list
378
+ extra_len += len(seg_list) - 1
379
  else:
380
  segments.append(seg)
381
+
382
+ self.segments[start_seg_id - 1:end_seg_id] = segments
383
  return extra_len
384
 
385
  def correct_with_force_term(self):
386
  ## force term correction
387
 
388
  # load term dictionary
389
+ with open("./finetune_data/dict_enzh.csv", 'r', encoding='utf-8') as f:
390
+ term_enzh_dict = {rows[0]: rows[1] for rows in reader(f)}
391
 
392
  # change term
393
  for seg in self.segments:
 
405
 
406
  def spell_check_term(self):
407
  ## known bug: I've will be replaced because i've is not in the dict
408
+
409
  import enchant
410
  dict = enchant.Dict('en_US')
411
  term_spellDict = enchant.PyPWL('./finetune_data/dict_freq.txt')
 
417
  [real_word, pos] = self.get_real_word(word)
418
  if not dict.check(word[:pos]):
419
  suggest = term_spellDict.suggest(real_word)
420
+
421
  if suggest and enchant.utils.levenshtein(word, suggest[0]) < (len(word)+len(suggest[0]))/4: # relax spell check
422
 
423
  #with open("dislog.log","a") as log:
 
436
  seg.source_text = " ".join(ready_words)
437
  pass
438
 
439
+ def spell_correction(self, word: str, arg: int):
440
  try:
441
+ arg in [0, 1]
442
  except ValueError:
443
  print('only 0 or 1 for argument')
444
 
445
+ def uncover(word: str):
 
446
  if word[-2:] == ".\n":
447
  real_word = word[:-2].lower()
448
  n = -2
 
452
  else:
453
  real_word = word.lower()
454
  n = 0
455
+ return real_word, len(word) + n
456
+
457
  real_word = uncover(word)[0]
458
  pos = uncover(word)[1]
459
  new_word = word
460
  if arg == 0: # term translate mode
461
+ with open("finetune_data/dict_enzh.csv", 'r', encoding='utf-8') as f:
462
+ term_enzh_dict = {rows[0]: rows[1] for rows in reader(f)}
463
  if real_word in term_enzh_dict:
464
  new_word = word.replace(word[:pos], term_enzh_dict.get(real_word))
465
  elif arg == 1: # term spell check mode
 
468
  term_spellDict = enchant.PyPWL('./finetune_data/dict_freq.txt')
469
  if not dict.check(real_word):
470
  if term_spellDict.suggest(real_word): # relax spell check
471
+ new_word = word.replace(word[:pos], term_spellDict.suggest(real_word)[0])
472
  return new_word
473
+
474
+ def get_real_word(self, word: str):
475
  if word[-2:] == ".\n":
476
  real_word = word[:-2].lower()
477
  n = -2
 
481
  else:
482
  real_word = word.lower()
483
  n = 0
484
+ return real_word, len(word) + n
 
485
 
486
  ## WRITE AND READ FUNCTIONS ##
487
 
 
489
  # return a string with pure source text
490
  result = ""
491
  for i, seg in enumerate(self.segments):
492
+ result += f'SENTENCE {i + 1}: {seg.source_text}\n\n\n'
493
+
494
  return result
495
+
496
  def reform_src_str(self):
497
  result = ""
498
  for i, seg in enumerate(self.segments):
499
+ result += f'{i + 1}\n'
500
  result += str(seg)
501
  return result
502
 
503
  def reform_trans_str(self):
504
  result = ""
505
  for i, seg in enumerate(self.segments):
506
+ result += f'{i + 1}\n'
507
  result += seg.get_trans_str()
508
  return result
509
+
510
  def form_bilingual_str(self):
511
  result = ""
512
  for i, seg in enumerate(self.segments):
513
+ result += f'{i + 1}\n'
514
  result += seg.get_bilingual_str()
515
  return result
516
 
517
+ def write_srt_file_src(self, path: str):
518
  # write srt file to path
519
  with open(path, "w", encoding='utf-8') as f:
520
  f.write(self.reform_src_str())
521
  pass
522
 
523
+ def write_srt_file_translate(self, path: str):
524
  with open(path, "w", encoding='utf-8') as f:
525
  f.write(self.reform_trans_str())
526
  pass
527
 
528
+ def write_srt_file_bilingual(self, path: str):
529
  with open(path, "w", encoding='utf-8') as f:
530
  f.write(self.form_bilingual_str())
531
  pass
532
 
533
+ def realtime_write_srt(self, path, range, length, idx):
534
  # DEPRECATED
535
  start_seg_id = range[0]
536
  end_seg_id = range[1]
 
539
  # f.write(f'{i+idx}\n')
540
  # f.write(seg.get_trans_str())
541
  for i, seg in enumerate(self.segments):
542
+ if i < range[0] - 1: continue
543
+ if i >= range[1] + length: break
544
+ f.write(f'{i + idx}\n')
545
  f.write(seg.get_trans_str())
546
  pass
547
 
548
+ def realtime_bilingual_write_srt(self, path, range, length, idx):
549
  # DEPRECATED
550
  start_seg_id = range[0]
551
  end_seg_id = range[1]
552
  with open(path, "a", encoding='utf-8') as f:
553
  for i, seg in enumerate(self.segments):
554
+ if i < range[0] - 1: continue
555
+ if i >= range[1] + length: break
556
+ f.write(f'{i + idx}\n')
557
  f.write(seg.get_bilingual_str())
558
  pass
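
A minimal usage sketch (not part of this commit) of the refactored SRT_segment API introduced above, assuming SRT.py from this revision is importable and that a whisper-style dict ('start'/'end'/'text') takes the dict branch of the constructor:

from SRT import SRT_segment

# Two whisper-style segments, as consumed by SRT_segment.__init__
a = SRT_segment({'start': 0.0, 'end': 1.5, 'text': 'Hello'})
b = SRT_segment({'start': 1.5, 'end': 3.0, 'text': 'world.'})

merged = a + b              # new __add__: returns a fresh segment, inputs unchanged
a.merge_seg(b)              # in-place variant kept from the previous API

merged.translation = '你好,世界!'
merged.remove_trans_punc()  # now strips Chinese punctuation via str.maketrans

print(merged.source_text)    # 'Hello world.'
print(merged.get_trans_str())  # SRT-style block: duration line, translation, blank line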