Spaces:
Sleeping
Sleeping
Eason Lu
committed on
Commit
•
9b3283d
1
Parent(s):
84e3cbf
solve split problem; add punctuation removal
Browse files
Former-commit-id: 8125cf63dcc38aa00e2eb10c513daa1d2ca3aeb5
- SRT.py +94 -68
- pipeline.py +2 -1
SRT.py
CHANGED
@@ -51,9 +51,18 @@ class SRT_segment(object):
|
|
51 |
self.source_text += seg.source_text
|
52 |
self.translation += seg.translation
|
53 |
self.end_time_str = seg.end_time_str
|
|
|
|
|
54 |
self.duration = f"{self.start_time_str} --> {self.end_time_str}"
|
55 |
pass
|
56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
def __str__(self) -> str:
|
58 |
return f'{self.duration}\n{self.source_text}\n\n'
|
59 |
|
@@ -96,7 +105,7 @@ class SRT_script():
|
|
96 |
merge_list = [] # a list of indices that should be merged e.g. [[0], [1, 2, 3, 4], [5, 6], [7]]
|
97 |
sentence = []
|
98 |
for i, seg in enumerate(self.segments):
|
99 |
-
if seg.source_text[-1]
|
100 |
sentence.append(i)
|
101 |
merge_list.append(sentence)
|
102 |
sentence = []
|
@@ -108,7 +117,11 @@ class SRT_script():
|
|
108 |
segments.append(self.merge_segs(idx_list))
|
109 |
|
110 |
self.segments = segments # need memory release?
|
111 |
-
|
|
|
|
|
|
|
|
|
112 |
|
113 |
def set_translation(self, translate:str, id_range:tuple, model, video_name, video_link=None):
|
114 |
start_seg_id = id_range[0]
|
@@ -197,10 +210,12 @@ class SRT_script():
|
|
197 |
for i, seg in enumerate(self.segments[start_seg_id-1:end_seg_id]):
|
198 |
# naive way to deal with the merge translation problem
|
199 |
# TODO: need a smarter solution
|
200 |
-
|
201 |
-
if i <
|
202 |
if "(Note:" in lines[i]: # to avoid note
|
|
|
203 |
lines.remove(lines[i])
|
|
|
204 |
if i == len(lines) - 1:
|
205 |
break
|
206 |
try:
|
@@ -209,7 +224,7 @@ class SRT_script():
|
|
209 |
seg.translation = lines[i]
|
210 |
|
211 |
|
212 |
-
def split_seg(self, seg,
|
213 |
# evenly split seg to 2 parts and add new seg into self.segments
|
214 |
|
215 |
# ignore the initial comma to solve the recursion problem
|
@@ -221,6 +236,8 @@ class SRT_script():
|
|
221 |
|
222 |
source_text = seg.source_text
|
223 |
translation = seg.translation
|
|
|
|
|
224 |
src_commas = [m.start() for m in re.finditer(',', source_text)]
|
225 |
trans_commas = [m.start() for m in re.finditer(',', translation)]
|
226 |
if len(src_commas) != 0:
|
@@ -237,13 +254,18 @@ class SRT_script():
|
|
237 |
else:
|
238 |
trans_split_idx = len(translation)//2
|
239 |
|
|
|
|
|
|
|
240 |
src_seg1 = source_text[:src_split_idx]
|
241 |
src_seg2 = source_text[src_split_idx:]
|
242 |
trans_seg1 = translation[:trans_split_idx]
|
243 |
trans_seg2 = translation[trans_split_idx:]
|
|
|
244 |
start_seg1 = seg.start
|
245 |
-
end_seg1 = start_seg2 = seg.start + (seg.end - seg.start)
|
246 |
end_seg2 = seg.end
|
|
|
247 |
seg1_dict = {}
|
248 |
seg1_dict['text'] = src_seg1
|
249 |
seg1_dict['start'] = start_seg1
|
@@ -259,26 +281,26 @@ class SRT_script():
|
|
259 |
seg2.translation = trans_seg2
|
260 |
|
261 |
result_list = []
|
262 |
-
if len(seg1.translation) >
|
263 |
-
result_list += self.split_seg(seg1,
|
264 |
else:
|
265 |
result_list.append(seg1)
|
266 |
|
267 |
-
if len(seg2.translation) >
|
268 |
-
result_list += self.split_seg(seg2,
|
269 |
else:
|
270 |
result_list.append(seg2)
|
271 |
|
272 |
return result_list
|
273 |
|
274 |
|
275 |
-
def check_len_and_split(self,
|
276 |
# DEPRECATED
|
277 |
-
# if sentence length >= threshold, split this segments to two
|
278 |
segments = []
|
279 |
for seg in self.segments:
|
280 |
-
if len(seg.translation) >
|
281 |
-
seg_list = self.split_seg(seg,
|
282 |
segments += seg_list
|
283 |
else:
|
284 |
segments.append(seg)
|
@@ -287,73 +309,25 @@ class SRT_script():
|
|
287 |
|
288 |
pass
|
289 |
|
290 |
-
def check_len_and_split_range(self, range,
|
291 |
-
# if sentence length >=
|
292 |
start_seg_id = range[0]
|
293 |
end_seg_id = range[1]
|
294 |
extra_len = 0
|
295 |
segments = []
|
296 |
for i, seg in enumerate(self.segments[start_seg_id-1:end_seg_id]):
|
297 |
-
if len(seg.translation) >
|
298 |
-
seg_list = self.split_seg(seg,
|
299 |
segments += seg_list
|
300 |
extra_len += len(seg_list) - 1
|
301 |
else:
|
302 |
segments.append(seg)
|
303 |
|
304 |
self.segments[start_seg_id-1:end_seg_id] = segments
|
305 |
-
|
306 |
return extra_len
|
307 |
-
|
308 |
-
def get_source_only(self):
|
309 |
-
# return a string with pure source text
|
310 |
-
result = ""
|
311 |
-
for i, seg in enumerate(self.segments):
|
312 |
-
result+=f'SENTENCE {i+1}: {seg.source_text}\n\n\n'
|
313 |
-
|
314 |
-
return result
|
315 |
-
|
316 |
-
def reform_src_str(self):
|
317 |
-
result = ""
|
318 |
-
for i, seg in enumerate(self.segments):
|
319 |
-
result += f'{i+1}\n'
|
320 |
-
result += str(seg)
|
321 |
-
return result
|
322 |
-
|
323 |
-
def reform_trans_str(self):
|
324 |
-
result = ""
|
325 |
-
for i, seg in enumerate(self.segments):
|
326 |
-
result += f'{i+1}\n'
|
327 |
-
result += seg.get_trans_str()
|
328 |
-
return result
|
329 |
-
|
330 |
-
def form_bilingual_str(self):
|
331 |
-
result = ""
|
332 |
-
for i, seg in enumerate(self.segments):
|
333 |
-
result += f'{i+1}\n'
|
334 |
-
result += seg.get_bilingual_str()
|
335 |
-
return result
|
336 |
-
|
337 |
-
def write_srt_file_src(self, path:str):
|
338 |
-
# write srt file to path
|
339 |
-
with open(path, "w", encoding='utf-8') as f:
|
340 |
-
f.write(self.reform_src_str())
|
341 |
-
pass
|
342 |
-
|
343 |
-
def write_srt_file_translate(self, path:str):
|
344 |
-
with open(path, "w", encoding='utf-8') as f:
|
345 |
-
f.write(self.reform_trans_str())
|
346 |
-
pass
|
347 |
-
|
348 |
-
def write_srt_file_bilingual(self, path:str):
|
349 |
-
with open(path, "w", encoding='utf-8') as f:
|
350 |
-
f.write(self.form_bilingual_str())
|
351 |
-
pass
|
352 |
|
353 |
def correct_with_force_term(self):
|
354 |
## force term correction
|
355 |
-
# TODO: shortcut translation i.e. VA, ob
|
356 |
-
# TODO: variety of translation
|
357 |
|
358 |
# load term dictionary
|
359 |
with open("./finetune_data/dict_enzh.csv",'r', encoding='utf-8') as f:
|
@@ -442,8 +416,57 @@ class SRT_script():
|
|
442 |
real_word = word.lower()
|
443 |
n = 0
|
444 |
return real_word, len(word)+n
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
445 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
446 |
def realtime_write_srt(self,path,range,length, idx):
|
|
|
447 |
start_seg_id = range[0]
|
448 |
end_seg_id = range[1]
|
449 |
with open(path, "a", encoding='utf-8') as f:
|
@@ -458,6 +481,7 @@ class SRT_script():
|
|
458 |
pass
|
459 |
|
460 |
def realtime_bilingual_write_srt(self,path,range, length,idx):
|
|
|
461 |
start_seg_id = range[0]
|
462 |
end_seg_id = range[1]
|
463 |
with open(path, "a", encoding='utf-8') as f:
|
@@ -466,4 +490,6 @@ class SRT_script():
|
|
466 |
if i>=range[1] + length :break
|
467 |
f.write(f'{i+idx}\n')
|
468 |
f.write(seg.get_bilingual_str())
|
469 |
-
pass
|
|
|
|
|
|
51 |
self.source_text += seg.source_text
|
52 |
self.translation += seg.translation
|
53 |
self.end_time_str = seg.end_time_str
|
54 |
+
self.end = seg.end
|
55 |
+
self.end_ms = seg.end_ms
|
56 |
self.duration = f"{self.start_time_str} --> {self.end_time_str}"
|
57 |
pass
|
58 |
+
|
59 |
+
def remove_trans_punc(self):
|
60 |
+
# remove punctuation from the translation text
|
61 |
+
self.translation = self.translation.replace(',', ' ')
|
62 |
+
self.translation = self.translation.replace('。', ' ')
|
63 |
+
self.translation = self.translation.replace('!', ' ')
|
64 |
+
self.translation = self.translation.replace('?', ' ')
|
65 |
+
|
66 |
def __str__(self) -> str:
|
67 |
return f'{self.duration}\n{self.source_text}\n\n'
|
68 |
|
|
|
105 |
merge_list = [] # a list of indices that should be merged e.g. [[0], [1, 2, 3, 4], [5, 6], [7]]
|
106 |
sentence = []
|
107 |
for i, seg in enumerate(self.segments):
|
108 |
+
if seg.source_text[-1] in ['.', '!', '?']:
|
109 |
sentence.append(i)
|
110 |
merge_list.append(sentence)
|
111 |
sentence = []
|
|
|
117 |
segments.append(self.merge_segs(idx_list))
|
118 |
|
119 |
self.segments = segments # need memory release?
|
120 |
+
|
121 |
+
def remove_trans_punctuation(self):
|
122 |
+
# Post-process: remove all punctuation after translating and splitting
|
123 |
+
for i, seg in enumerate(self.segments):
|
124 |
+
seg.remove_trans_punc()
|
125 |
|
126 |
def set_translation(self, translate:str, id_range:tuple, model, video_name, video_link=None):
|
127 |
start_seg_id = id_range[0]
|
|
|
210 |
for i, seg in enumerate(self.segments[start_seg_id-1:end_seg_id]):
|
211 |
# naive way to deal with the merge translation problem
|
212 |
# TODO: need a smarter solution
|
213 |
+
max_num = len(lines)
|
214 |
+
if i < max_num:
|
215 |
if "(Note:" in lines[i]: # to avoid note
|
216 |
+
print(f'remove: {lines[i]}')
|
217 |
lines.remove(lines[i])
|
218 |
+
max_num -= 1
|
219 |
if i == len(lines) - 1:
|
220 |
break
|
221 |
try:
|
|
|
224 |
seg.translation = lines[i]
|
225 |
|
226 |
|
227 |
+
def split_seg(self, seg, text_threshold, time_threshold):
|
228 |
# evenly split seg to 2 parts and add new seg into self.segments
|
229 |
|
230 |
# ignore the initial comma to solve the recursion problem
|
|
|
236 |
|
237 |
source_text = seg.source_text
|
238 |
translation = seg.translation
|
239 |
+
|
240 |
+
# split the text based on commas
|
241 |
src_commas = [m.start() for m in re.finditer(',', source_text)]
|
242 |
trans_commas = [m.start() for m in re.finditer(',', translation)]
|
243 |
if len(src_commas) != 0:
|
|
|
254 |
else:
|
255 |
trans_split_idx = len(translation)//2
|
256 |
|
257 |
+
# split the time duration based on text length
|
258 |
+
time_split_ratio = trans_split_idx/(len(seg.translation) - 1)
|
259 |
+
|
260 |
src_seg1 = source_text[:src_split_idx]
|
261 |
src_seg2 = source_text[src_split_idx:]
|
262 |
trans_seg1 = translation[:trans_split_idx]
|
263 |
trans_seg2 = translation[trans_split_idx:]
|
264 |
+
|
265 |
start_seg1 = seg.start
|
266 |
+
end_seg1 = start_seg2 = seg.start + (seg.end - seg.start)*time_split_ratio
|
267 |
end_seg2 = seg.end
|
268 |
+
|
269 |
seg1_dict = {}
|
270 |
seg1_dict['text'] = src_seg1
|
271 |
seg1_dict['start'] = start_seg1
|
|
|
281 |
seg2.translation = trans_seg2
|
282 |
|
283 |
result_list = []
|
284 |
+
if len(seg1.translation) > text_threshold and (seg1.end - seg1.start) > time_threshold:
|
285 |
+
result_list += self.split_seg(seg1, text_threshold, time_threshold)
|
286 |
else:
|
287 |
result_list.append(seg1)
|
288 |
|
289 |
+
if len(seg2.translation) > text_threshold and (seg2.end - seg2.start) > time_threshold:
|
290 |
+
result_list += self.split_seg(seg2, text_threshold, time_threshold)
|
291 |
else:
|
292 |
result_list.append(seg2)
|
293 |
|
294 |
return result_list
|
295 |
|
296 |
|
297 |
+
def check_len_and_split(self, text_threshold=30, time_threshold=1.0):
|
298 |
# DEPRECATED
|
299 |
+
# if sentence length > text_threshold and sentence duration > time_threshold, split this segment into two
|
300 |
segments = []
|
301 |
for seg in self.segments:
|
302 |
+
if len(seg.translation) > text_threshold and (seg.end - seg.start) > time_threshold:
|
303 |
+
seg_list = self.split_seg(seg, text_threshold, time_threshold)
|
304 |
segments += seg_list
|
305 |
else:
|
306 |
segments.append(seg)
|
|
|
309 |
|
310 |
pass
|
311 |
|
312 |
+
def check_len_and_split_range(self, range, text_threshold=30, time_threshold=1.0):
|
313 |
+
# if sentence length > text_threshold and sentence duration > time_threshold, split this segment into two
|
314 |
start_seg_id = range[0]
|
315 |
end_seg_id = range[1]
|
316 |
extra_len = 0
|
317 |
segments = []
|
318 |
for i, seg in enumerate(self.segments[start_seg_id-1:end_seg_id]):
|
319 |
+
if len(seg.translation) > text_threshold and (seg.end - seg.start) > time_threshold:
|
320 |
+
seg_list = self.split_seg(seg, text_threshold, time_threshold)
|
321 |
segments += seg_list
|
322 |
extra_len += len(seg_list) - 1
|
323 |
else:
|
324 |
segments.append(seg)
|
325 |
|
326 |
self.segments[start_seg_id-1:end_seg_id] = segments
|
|
|
327 |
return extra_len
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
328 |
|
329 |
def correct_with_force_term(self):
|
330 |
## force term correction
|
|
|
|
|
331 |
|
332 |
# load term dictionary
|
333 |
with open("./finetune_data/dict_enzh.csv",'r', encoding='utf-8') as f:
|
|
|
416 |
real_word = word.lower()
|
417 |
n = 0
|
418 |
return real_word, len(word)+n
|
419 |
+
|
420 |
+
|
421 |
+
## WRITE AND READ FUNCTIONS ##
|
422 |
+
|
423 |
+
def get_source_only(self):
|
424 |
+
# return a string with pure source text
|
425 |
+
result = ""
|
426 |
+
for i, seg in enumerate(self.segments):
|
427 |
+
result+=f'SENTENCE {i+1}: {seg.source_text}\n\n\n'
|
428 |
+
|
429 |
+
return result
|
430 |
|
431 |
+
def reform_src_str(self):
|
432 |
+
result = ""
|
433 |
+
for i, seg in enumerate(self.segments):
|
434 |
+
result += f'{i+1}\n'
|
435 |
+
result += str(seg)
|
436 |
+
return result
|
437 |
+
|
438 |
+
def reform_trans_str(self):
|
439 |
+
result = ""
|
440 |
+
for i, seg in enumerate(self.segments):
|
441 |
+
result += f'{i+1}\n'
|
442 |
+
result += seg.get_trans_str()
|
443 |
+
return result
|
444 |
+
|
445 |
+
def form_bilingual_str(self):
|
446 |
+
result = ""
|
447 |
+
for i, seg in enumerate(self.segments):
|
448 |
+
result += f'{i+1}\n'
|
449 |
+
result += seg.get_bilingual_str()
|
450 |
+
return result
|
451 |
+
|
452 |
+
def write_srt_file_src(self, path:str):
|
453 |
+
# write srt file to path
|
454 |
+
with open(path, "w", encoding='utf-8') as f:
|
455 |
+
f.write(self.reform_src_str())
|
456 |
+
pass
|
457 |
+
|
458 |
+
def write_srt_file_translate(self, path:str):
|
459 |
+
with open(path, "w", encoding='utf-8') as f:
|
460 |
+
f.write(self.reform_trans_str())
|
461 |
+
pass
|
462 |
+
|
463 |
+
def write_srt_file_bilingual(self, path:str):
|
464 |
+
with open(path, "w", encoding='utf-8') as f:
|
465 |
+
f.write(self.form_bilingual_str())
|
466 |
+
pass
|
467 |
+
|
468 |
def realtime_write_srt(self,path,range,length, idx):
|
469 |
+
# DEPRECATED
|
470 |
start_seg_id = range[0]
|
471 |
end_seg_id = range[1]
|
472 |
with open(path, "a", encoding='utf-8') as f:
|
|
|
481 |
pass
|
482 |
|
483 |
def realtime_bilingual_write_srt(self,path,range, length,idx):
|
484 |
+
# DEPRECATED
|
485 |
start_seg_id = range[0]
|
486 |
end_seg_id = range[1]
|
487 |
with open(path, "a", encoding='utf-8') as f:
|
|
|
490 |
if i>=range[1] + length :break
|
491 |
f.write(f'{i+idx}\n')
|
492 |
f.write(seg.get_bilingual_str())
|
493 |
+
pass
|
494 |
+
|
495 |
+
|
pipeline.py
CHANGED
@@ -49,7 +49,7 @@ if args.video_name == 'placeholder' :
|
|
49 |
elif args.audio_file is not None:
|
50 |
VIDEO_NAME = args.audio_file.split('/')[-1].split('.')[0]
|
51 |
elif args.srt_file is not None:
|
52 |
-
VIDEO_NAME = args.srt_file.split('/')[-1].split('.')[0].split("_")[
|
53 |
else:
|
54 |
VIDEO_NAME = args.video_name
|
55 |
else:
|
@@ -257,6 +257,7 @@ for sentence, range in tqdm(zip(script_arr, range_arr)):
|
|
257 |
# srt.realtime_bilingual_write_srt(f"{RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}_bi.srt",range, add_length,segidx)
|
258 |
|
259 |
srt.check_len_and_split()
|
|
|
260 |
srt.write_srt_file_translate(f"{RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}_zh.srt")
|
261 |
srt.write_srt_file_bilingual(f"{RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}_bi.srt")
|
262 |
|
|
|
49 |
elif args.audio_file is not None:
|
50 |
VIDEO_NAME = args.audio_file.split('/')[-1].split('.')[0]
|
51 |
elif args.srt_file is not None:
|
52 |
+
VIDEO_NAME = args.srt_file.split('/')[-1].split('.')[0].split("_")[0]
|
53 |
else:
|
54 |
VIDEO_NAME = args.video_name
|
55 |
else:
|
|
|
257 |
# srt.realtime_bilingual_write_srt(f"{RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}_bi.srt",range, add_length,segidx)
|
258 |
|
259 |
srt.check_len_and_split()
|
260 |
+
srt.remove_trans_punctuation()
|
261 |
srt.write_srt_file_translate(f"{RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}_zh.srt")
|
262 |
srt.write_srt_file_bilingual(f"{RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}_bi.srt")
|
263 |
|