from csv import reader
from datetime import datetime
from datetime import timedelta
import re


def _format_timestamp(total_ms: int) -> str:
    """Render a millisecond count as an SRT ``HH:MM:SS,mmm`` timestamp."""
    hours, rest = divmod(total_ms, 3600000)
    minutes, rest = divmod(rest, 60000)
    seconds, millis = divmod(rest, 1000)
    return f"{hours:02d}:{minutes:02d}:{seconds:02d},{millis:03d}"


def _parse_timestamp(stamp: str):
    """Split an ``HH:MM:SS,mmm`` stamp into ``(whole_seconds, milliseconds)``."""
    fields = stamp.split(',')
    clock = fields[0].split(':')
    whole_seconds = int(clock[0]) * 3600 + int(clock[1]) * 60 + int(clock[2])
    return whole_seconds, int(fields[1])


def _pick_split_index(text: str, separators) -> int:
    """Return the index at which ``text`` should be cut in two.

    The separators are tried in order; for the first one that occurs in
    ``text`` the middle occurrence is chosen (the lower-middle one when the
    count is even).  An occurrence at index 0 is ignored because cutting there
    would leave an empty first half.  When no separator is usable the text is
    cut at its midpoint.
    """
    for separator in separators:
        positions = [
            m.start()
            for m in re.finditer(re.escape(separator), text)
            if m.start() > 0
        ]
        if positions:
            return positions[(len(positions) - 1) // 2]
    return len(text) // 2


def _load_term_dict(path: str) -> dict:
    """Load a two-column CSV into a ``{first column: second column}`` dict.

    Rows with fewer than two columns (for example blank lines) are skipped.
    """
    with open(path, 'r', encoding='utf-8', newline='') as f:
        return {row[0]: row[1] for row in reader(f) if len(row) >= 2}


class SRT_segment(object):
    """A single subtitle cue: its timing, source text and translation.

    Attributes set by the constructor:
        start, end: cue boundaries in seconds.
        start_ms, end_ms: millisecond part (0-999) of each boundary.
        start_time, end_time: the boundaries as ``timedelta`` objects.
        start_time_str, end_time_str: ``HH:MM:SS,mmm`` strings.
        duration: ``"<start_time_str> --> <end_time_str>"``.
        source_text: the original text of the cue.
        translation: the translated text, initially empty.
    """

    def __init__(self, *args) -> None:
        """Build a segment from ``args[0]``.

        ``args[0]`` is either a dict with ``'start'``/``'end'`` (seconds) and
        ``'text'`` keys, or a sequence ``[index, timing_line, text]`` as read
        from an SRT file.

        Raises:
            TypeError: if ``args[0]`` is neither a dict nor a list/tuple.
        """
        segment = args[0]
        if isinstance(segment, dict):
            self._init_from_dict(segment)
        elif isinstance(segment, (list, tuple)):
            self._init_from_fields(segment)
        else:
            raise TypeError(
                "SRT_segment expects a dict or a [index, timing, text] list, "
                f"got {type(segment).__name__}"
            )

    def _init_from_dict(self, segment) -> None:
        """Populate the segment from a ``{'start', 'end', 'text'}`` dict."""
        self.start = segment['start']
        self.end = segment['end']

        # Round instead of truncating so float noise (0.29 * 1000 ->
        # 289.999...) cannot shave a millisecond off the stamp.
        start_total = int(round(self.start * 1000))
        end_total = int(round(self.end * 1000))
        if start_total == end_total:
            # avoid empty time stamp: show the cue for half a second.
            # ``self.end`` keeps the value that was passed in.
            end_total += 500

        self.start_ms = start_total % 1000
        self.end_ms = end_total % 1000
        self.start_time = timedelta(milliseconds=start_total)
        self.end_time = timedelta(milliseconds=end_total)
        self.start_time_str = _format_timestamp(start_total)
        self.end_time_str = _format_timestamp(end_total)

        self.source_text = segment['text']
        self.duration = f"{self.start_time_str} --> {self.end_time_str}"
        self.translation = ""

    def _init_from_fields(self, fields) -> None:
        """Populate the segment from an ``[index, timing_line, text]`` list."""
        self.source_text = fields[2]
        self.duration = fields[1]
        stamps = self.duration.split(" --> ")
        self.start_time_str = stamps[0]
        self.end_time_str = stamps[1]

        # parse the time to float seconds; *_ms are integer milliseconds,
        # the same unit the dict constructor uses.
        start_whole, self.start_ms = _parse_timestamp(self.start_time_str)
        end_whole, self.end_ms = _parse_timestamp(self.end_time_str)
        self.start = start_whole + self.start_ms / 1000
        self.end = end_whole + self.end_ms / 1000
        self.start_time = timedelta(seconds=start_whole, milliseconds=self.start_ms)
        self.end_time = timedelta(seconds=end_whole, milliseconds=self.end_ms)
        self.translation = ""

    def merge_seg(self, seg):
        """Append ``seg`` to this segment in place.

        Texts are concatenated and this segment takes over the end time of
        ``seg`` (numeric fields included, so a later split uses the merged
        time span).  A single space is inserted between the two source texts
        when neither side already has whitespace at the join.
        """
        needs_space = (
            self.source_text
            and seg.source_text
            and not self.source_text[-1].isspace()
            and not seg.source_text[0].isspace()
        )
        if needs_space:
            self.source_text += " "
        self.source_text += seg.source_text
        self.translation += seg.translation

        self.end = seg.end
        self.end_ms = seg.end_ms
        self.end_time = seg.end_time
        self.end_time_str = seg.end_time_str
        self.duration = f"{self.start_time_str} --> {self.end_time_str}"

    def __str__(self) -> str:
        """Return the cue body (timing + source text) in SRT layout."""
        return f'{self.duration}\n{self.source_text}\n\n'

    def get_trans_str(self) -> str:
        """Return the cue body (timing + translation) in SRT layout."""
        return f'{self.duration}\n{self.translation}\n\n'

    def get_bilingual_str(self) -> str:
        """Return the cue body with source text followed by translation."""
        return f'{self.duration}\n{self.source_text}\n{self.translation}\n\n'


class SRT_script():
    """An ordered list of :class:`SRT_segment` with editing and export helpers."""

    def __init__(self, segments) -> None:
        """Wrap every item of ``segments`` in an :class:`SRT_segment`."""
        self.segments = [SRT_segment(seg) for seg in segments]

    @staticmethod
    def _group_cue_lines(script_lines):
        """Group raw SRT lines into ``[index, timing, text]`` triples.

        Cues are delimited by blank lines.  Text that spans several lines is
        joined with ``'\\n'``; a cue without text gets an empty string.  Groups
        of fewer than two lines cannot carry a timing line and are skipped.
        """
        cues = []
        current = []
        # The trailing "" flushes the last cue when the file has no final
        # blank line.
        for line in list(script_lines) + [""]:
            if line.strip():
                current.append(line)
                continue
            if len(current) >= 2:
                cues.append([current[0], current[1], "\n".join(current[2:])])
            current = []
        return cues

    @classmethod
    def parse_from_srt_file(cls, path:str):
        """Read the UTF-8 SRT file at ``path`` and return a new script."""
        with open(path, 'r', encoding="utf-8") as f:
            script_lines = f.read().splitlines()
        return cls(cls._group_cue_lines(script_lines))

    def merge_segs(self, idx_list) -> SRT_segment:
        """Merge the segments at ``idx_list`` into the first one and return it.

        The first listed segment is modified in place.
        """
        final_seg = self.segments[idx_list[0]]
        for idx in idx_list[1:]:
            final_seg.merge_seg(self.segments[idx])
        return final_seg

    def form_whole_sentence(self):
        """Merge consecutive segments until each one ends with a period.

        Segments left over after the last period are kept as one final
        merged segment instead of being discarded.
        """
        merge_list = []  # a list of indices that should be merged e.g. [[0], [1, 2, 3, 4], [5, 6], [7]]
        sentence = []
        for i, seg in enumerate(self.segments):
            sentence.append(i)
            # endswith() is also safe for empty text.
            if seg.source_text.endswith('.'):
                merge_list.append(sentence)
                sentence = []
        if sentence:
            merge_list.append(sentence)

        self.segments = [self.merge_segs(idx_list) for idx_list in merge_list]

    def set_translation(self, translate:str, id_range:tuple):
        """Assign translated lines to the segments ``id_range[0]..id_range[1]``.

        ``id_range`` is 1-based and inclusive.  ``translate`` holds one entry
        per segment, separated by blank lines.  Entries containing ``"(Note:"``
        are ignored.  Everything up to the first ``":"`` of an entry is treated
        as a label and removed; surrounding whitespace is stripped.  When
        there are fewer entries than segments, the remaining segments keep
        their current translation.
        """
        start_seg_id = id_range[0]
        end_seg_id = id_range[1]
        target_segs = self.segments[start_seg_id-1:end_seg_id]

        lines = translate.split('\n\n')
        if len(lines) != (end_seg_id - start_seg_id + 1):
            # Diagnostic dump for a count mismatch between request and reply.
            print(id_range)
            for seg in target_segs:
                print(seg.source_text)
            print(translate)

        # TODO: need a smarter solution for merged translations
        lines = [line for line in lines if "(Note:" not in line]  # to avoid note
        for seg, line in zip(target_segs, lines):
            _label, colon, body = line.partition(":")
            seg.translation = body.strip() if colon else line.strip()

    def split_seg(self, seg, threshold):
        """Recursively split ``seg`` until every translation fits ``threshold``.

        The source text is cut at its middle comma (else its middle space,
        else its midpoint) and the translation at its middle comma (else its
        midpoint).  The time span is divided evenly between the two halves.

        Returns:
            The list of resulting segments in order.  A segment whose
            translation has fewer than two characters is returned unsplit.
        """
        source_text = seg.source_text
        translation = seg.translation
        if len(translation) < 2:
            # Nothing left to cut; also guarantees the recursion terminates.
            return [seg]

        src_split_idx = _pick_split_index(source_text, (',', ' '))
        trans_split_idx = _pick_split_index(translation, (',',))

        midpoint = seg.start + (seg.end - seg.start)/2

        seg1 = SRT_segment({
            'text': source_text[:src_split_idx],
            'start': seg.start,
            'end': midpoint,
        })
        seg1.translation = translation[:trans_split_idx]

        seg2 = SRT_segment({
            'text': source_text[src_split_idx:],
            'start': midpoint,
            'end': seg.end,
        })
        seg2.translation = translation[trans_split_idx:]

        result_list = []
        for part in (seg1, seg2):
            if len(part.translation) > threshold:
                result_list += self.split_seg(part, threshold)
            else:
                result_list.append(part)
        return result_list

    def check_len_and_split(self, threshold=30):
        """Split every segment whose translation is longer than ``threshold``."""
        segments = []
        for seg in self.segments:
            if len(seg.translation) > threshold:
                segments += self.split_seg(seg, threshold)
            else:
                segments.append(seg)
        self.segments = segments

    def get_source_only(self):
        """Return the source texts as numbered ``SENTENCE n: ...`` entries."""
        return "".join(
            f'SENTENCE {i+1}: {seg.source_text}\n\n\n'
            for i, seg in enumerate(self.segments)
        )

    def reform_src_str(self):
        """Return the script as SRT text carrying the source text."""
        return "".join(
            f'{i+1}\n{seg}' for i, seg in enumerate(self.segments)
        )

    def reform_trans_str(self):
        """Return the script as SRT text carrying the translation."""
        return "".join(
            f'{i+1}\n{seg.get_trans_str()}'
            for i, seg in enumerate(self.segments)
        )

    def form_bilingual_str(self):
        """Return the script as SRT text carrying source and translation."""
        return "".join(
            f'{i+1}\n{seg.get_bilingual_str()}'
            for i, seg in enumerate(self.segments)
        )

    def write_srt_file_src(self, path:str):
        """Write the source-text SRT to ``path`` (UTF-8)."""
        with open(path, "w", encoding='utf-8') as f:
            f.write(self.reform_src_str())

    def write_srt_file_translate(self, path:str):
        """Write the translated SRT to ``path`` (UTF-8)."""
        with open(path, "w", encoding='utf-8') as f:
            f.write(self.reform_trans_str())

    def write_srt_file_bilingual(self, path:str):
        """Write the bilingual SRT to ``path`` (UTF-8)."""
        with open(path, "w", encoding='utf-8') as f:
            f.write(self.form_bilingual_str())

    def correct_with_force_term(self, dict_path="./finetune_data/dict_enzh.csv"):
        """Replace known terms in every source text using a term dictionary.

        ``dict_path`` is a two-column CSV (term, replacement).  Words are
        looked up lower-cased and without trailing punctuation; the
        punctuation is kept after the replacement.
        """
        ## force term correction
        # TODO: shortcut translation i.e. VA, ob
        # TODO: variety of translation

        # load term dictionary
        term_enzh_dict = _load_term_dict(dict_path)

        # change term
        for seg in self.segments:
            ready_words = seg.source_text.split(" ")
            for i, word in enumerate(ready_words):
                real_word, pos = self.get_real_word(word)
                if real_word in term_enzh_dict:
                    # Swap only the leading word part and keep the suffix.
                    ready_words[i] = term_enzh_dict[real_word] + word[pos:]
            seg.source_text = " ".join(ready_words)

    def spell_check_term(self, dict_path='./finetune_data/dict_freq.txt'):
        """Replace misspelled source words with the closest known term.

        A word that the ``en_US`` dictionary rejects is replaced by the first
        suggestion from the personal word list at ``dict_path``, if any.
        Requires the third-party ``enchant`` package.
        """
        ## known bug: I've will be replaced because i've is not in the dict
        import enchant
        speller = enchant.Dict('en_US')
        term_spellDict = enchant.PyPWL(dict_path)

        for seg in self.segments:
            ready_words = seg.source_text.split(" ")
            for i, word in enumerate(ready_words):
                real_word, pos = self.get_real_word(word)
                # enchant rejects empty strings, e.g. from double spaces.
                if not real_word or speller.check(real_word):
                    continue
                suggest = term_spellDict.suggest(real_word)
                if suggest:  # relax spell check
                    ready_words[i] = suggest[0] + word[pos:]
            seg.source_text = " ".join(ready_words)

    def spell_correction(self, word:str, arg:int):
        """Correct a single word.

        Args:
            word: the word, optionally followed by trailing punctuation.
            arg: ``0`` for term translation mode, ``1`` for term spell
                check mode.

        Returns:
            The corrected word, or ``word`` unchanged when no correction
            applies or ``arg`` is not 0 or 1 (a message is printed then).
        """
        if arg not in (0, 1):
            print('only 0 or 1 for argument')
            return word

        real_word, pos = self.get_real_word(word)

        if arg == 0:  # term translate mode
            term_enzh_dict = _load_term_dict("finetune_data/dict_enzh.csv")
            if real_word in term_enzh_dict:
                return term_enzh_dict[real_word] + word[pos:]
            return word

        # term spell check mode
        import enchant
        speller = enchant.Dict('en_US')
        term_spellDict = enchant.PyPWL('./finetune_data/dict_freq.txt')
        # enchant rejects empty strings, so leave such words alone.
        if real_word and not speller.check(real_word):
            suggest = term_spellDict.suggest(real_word)
            if suggest:  # relax spell check
                return suggest[0] + word[pos:]
        return word

    def get_real_word(self, word:str):
        """Strip trailing punctuation from ``word`` and lower-case it.

        A trailing ``".\\n"`` or a single trailing ``.``, newline, ``,``,
        ``!`` or ``?`` is removed.

        Returns:
            ``(real_word, pos)`` where ``pos`` is the length of the kept
            prefix, so ``word[pos:]`` is the removed suffix.
        """
        if word[-2:] == ".\n":
            removed = 2
        elif word[-1:] in [".", "\n", ",", "!", "?"]:
            removed = 1
        else:
            removed = 0
        pos = len(word) - removed
        return word[:pos].lower(), pos