Eason Lu committed on
Commit
9b3283d
1 Parent(s): 84e3cbf

solve split problem; add punctuation removal

Browse files

Former-commit-id: 8125cf63dcc38aa00e2eb10c513daa1d2ca3aeb5

Files changed (2) hide show
  1. SRT.py +94 -68
  2. pipeline.py +2 -1
SRT.py CHANGED
@@ -51,9 +51,18 @@ class SRT_segment(object):
51
  self.source_text += seg.source_text
52
  self.translation += seg.translation
53
  self.end_time_str = seg.end_time_str
 
 
54
  self.duration = f"{self.start_time_str} --> {self.end_time_str}"
55
  pass
56
-
 
 
 
 
 
 
 
57
  def __str__(self) -> str:
58
  return f'{self.duration}\n{self.source_text}\n\n'
59
 
@@ -96,7 +105,7 @@ class SRT_script():
96
  merge_list = [] # a list of indices that should be merged e.g. [[0], [1, 2, 3, 4], [5, 6], [7]]
97
  sentence = []
98
  for i, seg in enumerate(self.segments):
99
- if seg.source_text[-1] == '.':
100
  sentence.append(i)
101
  merge_list.append(sentence)
102
  sentence = []
@@ -108,7 +117,11 @@ class SRT_script():
108
  segments.append(self.merge_segs(idx_list))
109
 
110
  self.segments = segments # need memory release?
111
-
 
 
 
 
112
 
113
  def set_translation(self, translate:str, id_range:tuple, model, video_name, video_link=None):
114
  start_seg_id = id_range[0]
@@ -197,10 +210,12 @@ class SRT_script():
197
  for i, seg in enumerate(self.segments[start_seg_id-1:end_seg_id]):
198
  # naive way to deal with merge translation problem
199
  # TODO: need a smarter solution
200
-
201
- if i < len(lines):
202
  if "(Note:" in lines[i]: # to avoid note
 
203
  lines.remove(lines[i])
 
204
  if i == len(lines) - 1:
205
  break
206
  try:
@@ -209,7 +224,7 @@ class SRT_script():
209
  seg.translation = lines[i]
210
 
211
 
212
- def split_seg(self, seg, threshold):
213
  # evenly split seg to 2 parts and add new seg into self.segments
214
 
215
  # ignore the initial comma to solve the recursion problem
@@ -221,6 +236,8 @@ class SRT_script():
221
 
222
  source_text = seg.source_text
223
  translation = seg.translation
 
 
224
  src_commas = [m.start() for m in re.finditer(',', source_text)]
225
  trans_commas = [m.start() for m in re.finditer(',', translation)]
226
  if len(src_commas) != 0:
@@ -237,13 +254,18 @@ class SRT_script():
237
  else:
238
  trans_split_idx = len(translation)//2
239
 
 
 
 
240
  src_seg1 = source_text[:src_split_idx]
241
  src_seg2 = source_text[src_split_idx:]
242
  trans_seg1 = translation[:trans_split_idx]
243
  trans_seg2 = translation[trans_split_idx:]
 
244
  start_seg1 = seg.start
245
- end_seg1 = start_seg2 = seg.start + (seg.end - seg.start)/2
246
  end_seg2 = seg.end
 
247
  seg1_dict = {}
248
  seg1_dict['text'] = src_seg1
249
  seg1_dict['start'] = start_seg1
@@ -259,26 +281,26 @@ class SRT_script():
259
  seg2.translation = trans_seg2
260
 
261
  result_list = []
262
- if len(seg1.translation) > threshold:
263
- result_list += self.split_seg(seg1, threshold)
264
  else:
265
  result_list.append(seg1)
266
 
267
- if len(seg2.translation) > threshold:
268
- result_list += self.split_seg(seg2, threshold)
269
  else:
270
  result_list.append(seg2)
271
 
272
  return result_list
273
 
274
 
275
- def check_len_and_split(self, threshold=30):
276
  # DEPRECATED
277
- # if sentence length >= threshold, split this segments to two
278
  segments = []
279
  for seg in self.segments:
280
- if len(seg.translation) > threshold:
281
- seg_list = self.split_seg(seg, threshold)
282
  segments += seg_list
283
  else:
284
  segments.append(seg)
@@ -287,73 +309,25 @@ class SRT_script():
287
 
288
  pass
289
 
290
- def check_len_and_split_range(self, range, threshold=30):
291
- # if sentence length >= threshold, split this segments to two
292
  start_seg_id = range[0]
293
  end_seg_id = range[1]
294
  extra_len = 0
295
  segments = []
296
  for i, seg in enumerate(self.segments[start_seg_id-1:end_seg_id]):
297
- if len(seg.translation) > threshold:
298
- seg_list = self.split_seg(seg, threshold)
299
  segments += seg_list
300
  extra_len += len(seg_list) - 1
301
  else:
302
  segments.append(seg)
303
 
304
  self.segments[start_seg_id-1:end_seg_id] = segments
305
-
306
  return extra_len
307
-
308
- def get_source_only(self):
309
- # return a string with pure source text
310
- result = ""
311
- for i, seg in enumerate(self.segments):
312
- result+=f'SENTENCE {i+1}: {seg.source_text}\n\n\n'
313
-
314
- return result
315
-
316
- def reform_src_str(self):
317
- result = ""
318
- for i, seg in enumerate(self.segments):
319
- result += f'{i+1}\n'
320
- result += str(seg)
321
- return result
322
-
323
- def reform_trans_str(self):
324
- result = ""
325
- for i, seg in enumerate(self.segments):
326
- result += f'{i+1}\n'
327
- result += seg.get_trans_str()
328
- return result
329
-
330
- def form_bilingual_str(self):
331
- result = ""
332
- for i, seg in enumerate(self.segments):
333
- result += f'{i+1}\n'
334
- result += seg.get_bilingual_str()
335
- return result
336
-
337
- def write_srt_file_src(self, path:str):
338
- # write srt file to path
339
- with open(path, "w", encoding='utf-8') as f:
340
- f.write(self.reform_src_str())
341
- pass
342
-
343
- def write_srt_file_translate(self, path:str):
344
- with open(path, "w", encoding='utf-8') as f:
345
- f.write(self.reform_trans_str())
346
- pass
347
-
348
- def write_srt_file_bilingual(self, path:str):
349
- with open(path, "w", encoding='utf-8') as f:
350
- f.write(self.form_bilingual_str())
351
- pass
352
 
353
  def correct_with_force_term(self):
354
  ## force term correction
355
- # TODO: shortcut translation i.e. VA, ob
356
- # TODO: variety of translation
357
 
358
  # load term dictionary
359
  with open("./finetune_data/dict_enzh.csv",'r', encoding='utf-8') as f:
@@ -442,8 +416,57 @@ class SRT_script():
442
  real_word = word.lower()
443
  n = 0
444
  return real_word, len(word)+n
 
 
 
 
 
 
 
 
 
 
 
445
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
446
  def realtime_write_srt(self,path,range,length, idx):
 
447
  start_seg_id = range[0]
448
  end_seg_id = range[1]
449
  with open(path, "a", encoding='utf-8') as f:
@@ -458,6 +481,7 @@ class SRT_script():
458
  pass
459
 
460
  def realtime_bilingual_write_srt(self,path,range, length,idx):
 
461
  start_seg_id = range[0]
462
  end_seg_id = range[1]
463
  with open(path, "a", encoding='utf-8') as f:
@@ -466,4 +490,6 @@ class SRT_script():
466
  if i>=range[1] + length :break
467
  f.write(f'{i+idx}\n')
468
  f.write(seg.get_bilingual_str())
469
- pass
 
 
 
51
  self.source_text += seg.source_text
52
  self.translation += seg.translation
53
  self.end_time_str = seg.end_time_str
54
+ self.end = seg.end
55
+ self.end_ms = seg.end_ms
56
  self.duration = f"{self.start_time_str} --> {self.end_time_str}"
57
  pass
58
+
59
+ def remove_trans_punc(self):
60
+ # remove punctuations in translation text
61
+ self.translation = self.translation.replace(',', ' ')
62
+ self.translation = self.translation.replace('。', ' ')
63
+ self.translation = self.translation.replace('!', ' ')
64
+ self.translation = self.translation.replace('?', ' ')
65
+
66
  def __str__(self) -> str:
67
  return f'{self.duration}\n{self.source_text}\n\n'
68
 
 
105
  merge_list = [] # a list of indices that should be merged e.g. [[0], [1, 2, 3, 4], [5, 6], [7]]
106
  sentence = []
107
  for i, seg in enumerate(self.segments):
108
+ if seg.source_text[-1] in ['.', '!', '?']:
109
  sentence.append(i)
110
  merge_list.append(sentence)
111
  sentence = []
 
117
  segments.append(self.merge_segs(idx_list))
118
 
119
  self.segments = segments # need memory release?
120
+
121
+ def remove_trans_punctuation(self):
122
+ # Post-process: remove all punc after translation and split
123
+ for i, seg in enumerate(self.segments):
124
+ seg.remove_trans_punc()
125
 
126
  def set_translation(self, translate:str, id_range:tuple, model, video_name, video_link=None):
127
  start_seg_id = id_range[0]
 
210
  for i, seg in enumerate(self.segments[start_seg_id-1:end_seg_id]):
211
  # naive way to deal with merge translation problem
212
  # TODO: need a smarter solution
213
+ max_num = len(lines)
214
+ if i < max_num:
215
  if "(Note:" in lines[i]: # to avoid note
216
+ print(f'remove: {lines[i]}')
217
  lines.remove(lines[i])
218
+ max_num -= 1
219
  if i == len(lines) - 1:
220
  break
221
  try:
 
224
  seg.translation = lines[i]
225
 
226
 
227
+ def split_seg(self, seg, text_threshold, time_threshold):
228
  # evenly split seg to 2 parts and add new seg into self.segments
229
 
230
  # ignore the initial comma to solve the recursion problem
 
236
 
237
  source_text = seg.source_text
238
  translation = seg.translation
239
+
240
+ # split the text based on commas
241
  src_commas = [m.start() for m in re.finditer(',', source_text)]
242
  trans_commas = [m.start() for m in re.finditer(',', translation)]
243
  if len(src_commas) != 0:
 
254
  else:
255
  trans_split_idx = len(translation)//2
256
 
257
+ # split the time duration based on text length
258
+ time_split_ratio = trans_split_idx/(len(seg.translation) - 1)
259
+
260
  src_seg1 = source_text[:src_split_idx]
261
  src_seg2 = source_text[src_split_idx:]
262
  trans_seg1 = translation[:trans_split_idx]
263
  trans_seg2 = translation[trans_split_idx:]
264
+
265
  start_seg1 = seg.start
266
+ end_seg1 = start_seg2 = seg.start + (seg.end - seg.start)*time_split_ratio
267
  end_seg2 = seg.end
268
+
269
  seg1_dict = {}
270
  seg1_dict['text'] = src_seg1
271
  seg1_dict['start'] = start_seg1
 
281
  seg2.translation = trans_seg2
282
 
283
  result_list = []
284
+ if len(seg1.translation) > text_threshold and (seg1.end - seg1.start) > time_threshold:
285
+ result_list += self.split_seg(seg1, text_threshold, time_threshold)
286
  else:
287
  result_list.append(seg1)
288
 
289
+ if len(seg2.translation) > text_threshold and (seg2.end - seg2.start) > time_threshold:
290
+ result_list += self.split_seg(seg2, text_threshold, time_threshold)
291
  else:
292
  result_list.append(seg2)
293
 
294
  return result_list
295
 
296
 
297
+ def check_len_and_split(self, text_threshold=30, time_threshold=1.0):
298
  # DEPRECATED
299
+ # if sentence length >= threshold and sentence duration > time_threshold, split this segments to two
300
  segments = []
301
  for seg in self.segments:
302
+ if len(seg.translation) > text_threshold and (seg.end - seg.start) > time_threshold:
303
+ seg_list = self.split_seg(seg, text_threshold, time_threshold)
304
  segments += seg_list
305
  else:
306
  segments.append(seg)
 
309
 
310
  pass
311
 
312
+ def check_len_and_split_range(self, range, text_threshold=30, time_threshold=1.0):
313
+ # if sentence length >= text_threshold, split this segments to two
314
  start_seg_id = range[0]
315
  end_seg_id = range[1]
316
  extra_len = 0
317
  segments = []
318
  for i, seg in enumerate(self.segments[start_seg_id-1:end_seg_id]):
319
+ if len(seg.translation) > text_threshold and (seg.end - seg.start) > time_threshold:
320
+ seg_list = self.split_seg(seg, text_threshold, time_threshold)
321
  segments += seg_list
322
  extra_len += len(seg_list) - 1
323
  else:
324
  segments.append(seg)
325
 
326
  self.segments[start_seg_id-1:end_seg_id] = segments
 
327
  return extra_len
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
328
 
329
  def correct_with_force_term(self):
330
  ## force term correction
 
 
331
 
332
  # load term dictionary
333
  with open("./finetune_data/dict_enzh.csv",'r', encoding='utf-8') as f:
 
416
  real_word = word.lower()
417
  n = 0
418
  return real_word, len(word)+n
419
+
420
+
421
+ ## WRITE AND READ FUNCTIONS ##
422
+
423
+ def get_source_only(self):
424
+ # return a string with pure source text
425
+ result = ""
426
+ for i, seg in enumerate(self.segments):
427
+ result+=f'SENTENCE {i+1}: {seg.source_text}\n\n\n'
428
+
429
+ return result
430
 
431
+ def reform_src_str(self):
432
+ result = ""
433
+ for i, seg in enumerate(self.segments):
434
+ result += f'{i+1}\n'
435
+ result += str(seg)
436
+ return result
437
+
438
+ def reform_trans_str(self):
439
+ result = ""
440
+ for i, seg in enumerate(self.segments):
441
+ result += f'{i+1}\n'
442
+ result += seg.get_trans_str()
443
+ return result
444
+
445
+ def form_bilingual_str(self):
446
+ result = ""
447
+ for i, seg in enumerate(self.segments):
448
+ result += f'{i+1}\n'
449
+ result += seg.get_bilingual_str()
450
+ return result
451
+
452
+ def write_srt_file_src(self, path:str):
453
+ # write srt file to path
454
+ with open(path, "w", encoding='utf-8') as f:
455
+ f.write(self.reform_src_str())
456
+ pass
457
+
458
+ def write_srt_file_translate(self, path:str):
459
+ with open(path, "w", encoding='utf-8') as f:
460
+ f.write(self.reform_trans_str())
461
+ pass
462
+
463
+ def write_srt_file_bilingual(self, path:str):
464
+ with open(path, "w", encoding='utf-8') as f:
465
+ f.write(self.form_bilingual_str())
466
+ pass
467
+
468
  def realtime_write_srt(self,path,range,length, idx):
469
+ # DEPRECATED
470
  start_seg_id = range[0]
471
  end_seg_id = range[1]
472
  with open(path, "a", encoding='utf-8') as f:
 
481
  pass
482
 
483
  def realtime_bilingual_write_srt(self,path,range, length,idx):
484
+ # DEPRECATED
485
  start_seg_id = range[0]
486
  end_seg_id = range[1]
487
  with open(path, "a", encoding='utf-8') as f:
 
490
  if i>=range[1] + length :break
491
  f.write(f'{i+idx}\n')
492
  f.write(seg.get_bilingual_str())
493
+ pass
494
+
495
+
pipeline.py CHANGED
@@ -49,7 +49,7 @@ if args.video_name == 'placeholder' :
49
  elif args.audio_file is not None:
50
  VIDEO_NAME = args.audio_file.split('/')[-1].split('.')[0]
51
  elif args.srt_file is not None:
52
- VIDEO_NAME = args.srt_file.split('/')[-1].split('.')[0].split("_")[:-1]
53
  else:
54
  VIDEO_NAME = args.video_name
55
  else:
@@ -257,6 +257,7 @@ for sentence, range in tqdm(zip(script_arr, range_arr)):
257
  # srt.realtime_bilingual_write_srt(f"{RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}_bi.srt",range, add_length,segidx)
258
 
259
  srt.check_len_and_split()
 
260
  srt.write_srt_file_translate(f"{RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}_zh.srt")
261
  srt.write_srt_file_bilingual(f"{RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}_bi.srt")
262
 
 
49
  elif args.audio_file is not None:
50
  VIDEO_NAME = args.audio_file.split('/')[-1].split('.')[0]
51
  elif args.srt_file is not None:
52
+ VIDEO_NAME = args.srt_file.split('/')[-1].split('.')[0].split("_")[0]
53
  else:
54
  VIDEO_NAME = args.video_name
55
  else:
 
257
  # srt.realtime_bilingual_write_srt(f"{RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}_bi.srt",range, add_length,segidx)
258
 
259
  srt.check_len_and_split()
260
+ srt.remove_trans_punctuation()
261
  srt.write_srt_file_translate(f"{RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}_zh.srt")
262
  srt.write_srt_file_bilingual(f"{RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}_bi.srt")
263