ViDove / SRT.py
Eason Lu
solve empty time stamp;working on split
f1a218d
raw
history blame
8.86 kB
from datetime import timedelta
from csv import reader
import re
class SRT_segment(object):
    """A single subtitle cue: time range, source text and translation.

    Attributes set on every segment:
        source_text: cue text in the source language.
        translation: translated text, initially empty.
        start_time_str / end_time_str: ``HH:MM:SS,mmm`` timestamps.
        duration: the ``"<start> --> <end>"`` line written to an SRT file.
        start / end: cue boundaries in seconds (``None`` when a timestamp
            read from an SRT line cannot be parsed).

    Segments built from a dict additionally carry ``start_ms`` / ``end_ms``
    (millisecond parts) and ``start_time`` / ``end_time`` (``timedelta``).
    """

    def __init__(self, *args) -> None:
        """Build a segment from a timing dict or from parsed SRT lines.

        Args:
            *args: only ``args[0]`` is read. It is either a dict with
                ``'start'`` and ``'end'`` in seconds plus ``'text'``, or a
                list shaped like ``[index, "start --> end", text, ...]``.

        Raises:
            TypeError: if ``args[0]`` is neither a dict nor a list.
        """
        segment = args[0]
        if isinstance(segment, dict):
            self._init_from_dict(segment)
        elif isinstance(segment, list):
            self._init_from_lines(segment)
        else:
            # Previously this fell through and left the object without any
            # attributes, which only failed later and far from the cause.
            raise TypeError(
                "SRT_segment expects a dict or a list, got "
                f"{type(segment).__name__}"
            )
        self.translation = ""

    def _init_from_dict(self, segment) -> None:
        """Fill the segment from a ``{'start', 'end', 'text'}`` dict."""
        self.start = segment['start']
        self.end = segment['end']
        start_sec, self.start_ms = self._split_seconds(self.start)
        end_sec, self.end_ms = self._split_seconds(self.end)
        if (start_sec, self.start_ms) == (end_sec, self.end_ms):
            # avoid empty time stamp: give a zero-length cue half a second
            self.end_ms += 500
        self.start_time = timedelta(seconds=start_sec, milliseconds=self.start_ms)
        self.end_time = timedelta(seconds=end_sec, milliseconds=self.end_ms)
        # Format from the timedelta so an end_ms that overflowed past 999
        # after the bump above is carried into the seconds field.
        self.start_time_str = self._format_timestamp(self.start_time)
        self.end_time_str = self._format_timestamp(self.end_time)
        text = segment['text']
        # Drop one leading space when present; text without it keeps its
        # first character instead of losing it.
        self.source_text = text[1:] if text.startswith(' ') else text
        self.duration = f"{self.start_time_str} --> {self.end_time_str}"

    def _init_from_lines(self, lines) -> None:
        """Fill the segment from ``[index, "start --> end", text, ...]``."""
        self.source_text = lines[2]
        self.duration = lines[1]
        stamps = self.duration.split(" --> ")
        self.start_time_str = stamps[0]
        self.end_time_str = stamps[1]
        # Expose numeric boundaries too, so code that reads ``start``/``end``
        # also works on segments loaded from a file.
        self.start = self._parse_timestamp(self.start_time_str)
        self.end = self._parse_timestamp(self.end_time_str)

    @staticmethod
    def _split_seconds(value):
        """Return ``(whole_seconds, milliseconds)`` for a time in seconds."""
        # Rounding the total avoids float artefacts such as 0.29 -> 289 ms.
        return divmod(int(round(value * 1000)), 1000)

    @staticmethod
    def _format_timestamp(moment) -> str:
        """Render a ``timedelta`` as an SRT ``HH:MM:SS,mmm`` timestamp."""
        total_ms = moment // timedelta(milliseconds=1)
        hours, remainder = divmod(total_ms, 3_600_000)
        minutes, remainder = divmod(remainder, 60_000)
        seconds, millis = divmod(remainder, 1000)
        return f"{hours:02d}:{minutes:02d}:{seconds:02d},{millis:03d}"

    @staticmethod
    def _parse_timestamp(stamp):
        """Convert ``HH:MM:SS,mmm`` to seconds; ``None`` if malformed."""
        try:
            clock, _, millis = stamp.strip().partition(',')
            hours, minutes, seconds = (int(part) for part in clock.split(':'))
            return hours * 3600 + minutes * 60 + seconds + int(millis or 0) / 1000
        except ValueError:
            return None

    def merge_seg(self, seg):
        """Append ``seg`` to this segment and extend the end time to match.

        Source texts are joined with a single space when both are non-empty;
        translations are concatenated unchanged.
        """
        if self.source_text and seg.source_text:
            self.source_text += ' ' + seg.source_text
        else:
            self.source_text += seg.source_text
        self.translation += seg.translation
        self.end_time_str = seg.end_time_str
        # Keep the numeric end in step with the string form when available.
        for attr in ("end", "end_ms", "end_time"):
            if hasattr(seg, attr):
                setattr(self, attr, getattr(seg, attr))
        self.duration = f"{self.start_time_str} --> {self.end_time_str}"

    def __str__(self) -> str:
        """Return the cue body (time range + source text) in SRT layout."""
        return f'{self.duration}\n{self.source_text}\n\n'

    def get_trans_str(self) -> str:
        """Return the cue body with the translation as its text."""
        return f'{self.duration}\n{self.translation}\n\n'

    def get_bilingual_str(self) -> str:
        """Return the cue body with source text followed by translation."""
        return f'{self.duration}\n{self.source_text}\n{self.translation}\n\n'
class SRT_script():
    """An ordered collection of ``SRT_segment`` cues.

    Provides loading from an SRT file, sentence merging, translation
    assignment, term correction and writing source / translated /
    bilingual SRT files.
    """

    def __init__(self, segments) -> None:
        """Wrap every item of ``segments`` in an ``SRT_segment``.

        Args:
            segments: iterable of dicts or lists accepted by
                ``SRT_segment``.
        """
        self.segments = [SRT_segment(seg) for seg in segments]

    @classmethod
    def parse_from_srt_file(cls, path:str):
        """Build a script from the SRT file at ``path`` (read as UTF-8)."""
        with open(path, 'r', encoding="utf-8") as f:
            script_lines = f.read().splitlines()
        return cls(cls._split_cue_blocks(script_lines))

    @staticmethod
    def _split_cue_blocks(lines):
        """Group raw SRT lines into ``[index, time range, text]`` lists.

        Cues are separated by blank lines rather than assumed to be exactly
        four lines long, so multi-line cue text is kept (joined with
        ``'\\n'``) and stray or trailing blank lines are ignored. Blocks
        with fewer than two lines cannot be a cue and are skipped.
        """
        cues = []
        current = []
        # The extra '' flushes the final block when the file lacks a
        # trailing blank line.
        for line in list(lines) + ['']:
            if line.strip():
                current.append(line)
                continue
            if len(current) >= 2:
                cues.append([current[0], current[1], '\n'.join(current[2:])])
            current = []
        return cues

    def merge_segs(self, idx_list) -> "SRT_segment":
        """Merge the segments at ``idx_list`` into the first one, in place.

        Args:
            idx_list: non-empty list of indices into ``self.segments``.

        Returns:
            The segment at ``idx_list[0]`` after absorbing the others.
        """
        final_seg = self.segments[idx_list[0]]
        for idx in idx_list[1:]:
            final_seg.merge_seg(self.segments[idx])
        return final_seg

    def form_whole_sentence(self):
        """Merge consecutive segments so each one ends on a period.

        Segments following the last period are kept as one final merged
        segment instead of being discarded; empty texts no longer raise.
        """
        merge_list = []  # indices to merge, e.g. [[0], [1, 2, 3, 4], [5, 6], [7]]
        sentence = []
        for i, seg in enumerate(self.segments):
            sentence.append(i)
            if seg.source_text.endswith('.'):
                merge_list.append(sentence)
                sentence = []
        if sentence:
            # unfinished trailing sentence
            merge_list.append(sentence)
        self.segments = [self.merge_segs(idx_list) for idx_list in merge_list]

    def set_translation(self, translate:str, id_range:tuple):
        """Distribute a translated text block over a range of segments.

        Args:
            translate: translated text, one entry per segment, entries
                separated by a blank line (``'\\n\\n'``).
            id_range: ``(first, last)`` 1-based inclusive segment ids.

        When the entry count does not match the range, the range, the
        source texts and the raw translation are printed for inspection
        and as many segments as possible are still filled. Entries
        containing ``"(Note:"`` are skipped. Text up to and including the
        first ``":"`` is treated as a label and removed.
        """
        start_seg_id = id_range[0]
        end_seg_id = id_range[1]
        targets = self.segments[start_seg_id-1:end_seg_id]
        lines = translate.split('\n\n')
        if len(lines) != (end_seg_id - start_seg_id + 1):
            print(id_range)
            for seg in targets:
                print(seg.source_text)
            print(translate)
        # naive way to deal with merge translation problem
        # TODO: need a smarter solution
        lines = [line for line in lines if "(Note:" not in line]  # to avoid note
        for seg, line in zip(targets, lines):
            # Split on the first colon only so a translation that itself
            # contains a colon is not truncated.
            _, colon, body = line.partition(":")
            seg.translation = body.strip() if colon else line

    @staticmethod
    def _split_text(text):
        """Split ``text`` in two at its middle comma, else middle space.

        The separator itself is dropped and both halves are stripped. Text
        with neither separator is cut at its middle character.
        """
        for separator in (',', ' '):
            positions = [i for i, char in enumerate(text) if char == separator]
            if positions:
                cut = positions[len(positions)//2]
                return text[:cut].strip(), text[cut+1:].strip()
        middle = len(text)//2
        return text[:middle], text[middle:]

    def split_seg(self, seg_idx):
        """Replace the segment at ``seg_idx`` with two halves.

        Source text and translation are each split by ``_split_text``; the
        time range is divided at its midpoint. The segment must expose
        numeric ``start`` and ``end`` attributes.
        """
        seg = self.segments[seg_idx]
        src_seg1, src_seg2 = self._split_text(seg.source_text)
        trans_seg1, trans_seg2 = self._split_text(seg.translation)
        midpoint = seg.start + (seg.end - seg.start)/2
        halves = []
        for text, translation, start, end in (
            (src_seg1, trans_seg1, seg.start, midpoint),
            (src_seg2, trans_seg2, midpoint, seg.end),
        ):
            half = SRT_segment({'text': text, 'start': start, 'end': end})
            # Set the text explicitly so it does not depend on how the
            # constructor trims the dict's 'text' value.
            half.source_text = text
            half.translation = translation
            halves.append(half)
        self.segments[seg_idx:seg_idx+1] = halves

    def check_len_and_split(self, threshold):
        """Placeholder; currently does nothing."""
        # TODO: if sentence length >= threshold, split this segments to two
        pass

    def get_source_only(self):
        """Return all source texts as numbered ``SENTENCE n:`` entries."""
        return "".join(
            f'SENTENCE {i+1}: {seg.source_text}\n\n\n'
            for i, seg in enumerate(self.segments)
        )

    def reform_src_str(self):
        """Return the script as SRT text using the source texts."""
        return "".join(
            f'{i+1}\n{seg}' for i, seg in enumerate(self.segments)
        )

    def reform_trans_str(self):
        """Return the script as SRT text using the translations."""
        return "".join(
            f'{i+1}\n{seg.get_trans_str()}'
            for i, seg in enumerate(self.segments)
        )

    def form_bilingual_str(self):
        """Return the script as SRT text with source and translation."""
        return "".join(
            f'{i+1}\n{seg.get_bilingual_str()}'
            for i, seg in enumerate(self.segments)
        )

    def write_srt_file_src(self, path:str):
        """Write the source-language SRT to ``path`` (UTF-8)."""
        with open(path, "w", encoding='utf-8') as f:
            f.write(self.reform_src_str())

    def write_srt_file_translate(self, path:str):
        """Write the translated SRT to ``path`` (UTF-8)."""
        with open(path, "w", encoding='utf-8') as f:
            f.write(self.reform_trans_str())

    def write_srt_file_bilingual(self, path:str):
        """Write the bilingual SRT to ``path`` (UTF-8)."""
        with open(path, "w", encoding='utf-8') as f:
            f.write(self.form_bilingual_str())

    def correct_with_force_term(self, dict_path="finetune_data/dict.csv"):
        """Replace dictionary terms in every segment's source text.

        Args:
            dict_path: CSV file whose first column is the lower-case term
                and whose second column is its replacement. Rows with fewer
                than two columns are ignored.

        Words are matched case-insensitively after splitting on spaces; a
        word ending in ``".\\n"`` is matched without that suffix, which is
        then re-attached. Text without matching terms is left unchanged.
        """
        ## force term correction
        # TODO: shortcut translation i.e. VA, ob
        # TODO: variety of translation
        # load term dictionary
        with open(dict_path, 'r', encoding='utf-8') as f:
            term_dict = {row[0]: row[1] for row in reader(f) if len(row) >= 2}
        # change term
        for seg in self.segments:
            # Pad newlines with a space so the word after a line break is
            # split off on its own; the padding is removed again below.
            words = re.sub('\n', '\n ', seg.source_text).split(" ")
            for i, word in enumerate(words):
                if word[-2:] == ".\n":
                    stem, tail = word[:-2], word[-2:]
                else:
                    stem, tail = word, ""
                replacement = term_dict.get(stem.lower())
                if replacement is not None:
                    words[i] = replacement + tail
            # Joining with single spaces restores the original spacing and
            # avoids accumulating a trailing space on every call.
            seg.source_text = re.sub('\n ', '\n', " ".join(words))