import os
import re
from copy import deepcopy
from csv import reader
from datetime import timedelta
import logging

import openai
from tqdm import tqdm


class SrtSegment(object):
    def __init__(self, *args) -> None:
        if isinstance(args[0], dict):
            segment = args[0]
            self.start = segment['start']
            self.end = segment['end']
            self.start_ms = int((segment['start'] * 100) % 100 * 10)
            self.end_ms = int((segment['end'] * 100) % 100 * 10)
            if self.start_ms == self.end_ms and int(segment['start']) == int(segment['end']):
                # avoid an empty time stamp
                self.end_ms += 500
            self.start_time = timedelta(seconds=int(segment['start']), milliseconds=self.start_ms)
            self.end_time = timedelta(seconds=int(segment['end']), milliseconds=self.end_ms)
            if self.start_ms == 0:
                self.start_time_str = str(0) + str(self.start_time).split('.')[0] + ',000'
            else:
                self.start_time_str = str(0) + str(self.start_time).split('.')[0] + ',' + \
                                      str(self.start_time).split('.')[1][:3]
            if self.end_ms == 0:
                self.end_time_str = str(0) + str(self.end_time).split('.')[0] + ',000'
            else:
                self.end_time_str = str(0) + str(self.end_time).split('.')[0] + ',' + \
                                    str(self.end_time).split('.')[1][:3]
            self.source_text = segment['text'].lstrip()
            self.duration = f"{self.start_time_str} --> {self.end_time_str}"
            self.translation = ""
        elif isinstance(args[0], list):
            self.source_text = args[0][2]
            self.duration = args[0][1]
            self.start_time_str = self.duration.split(" --> ")[0]
            self.end_time_str = self.duration.split(" --> ")[1]

            # parse the time strings to floats
            self.start_ms = int(self.start_time_str.split(',')[1]) / 10
            self.end_ms = int(self.end_time_str.split(',')[1]) / 10
            start_list = self.start_time_str.split(',')[0].split(':')
            self.start = int(start_list[0]) * 3600 + int(start_list[1]) * 60 + int(start_list[2]) + self.start_ms / 100
            end_list = self.end_time_str.split(',')[0].split(':')
            self.end = int(end_list[0]) * 3600 + int(end_list[1]) * 60 + int(end_list[2]) + self.end_ms / 100
            self.translation = ""

    def merge_seg(self, seg):
        """
        Merge the segment seg into the current segment in place.
        :param seg: another segment that is strictly next to the current one
        :return: None
        """
        # assert seg.start_ms == self.end_ms, "cannot merge discontinuous segments."
        self.source_text += f' {seg.source_text}'
        self.translation += f' {seg.translation}'
        self.end_time_str = seg.end_time_str
        self.end = seg.end
        self.end_ms = seg.end_ms
        self.duration = f"{self.start_time_str} --> {self.end_time_str}"

    def __add__(self, other):
        """
        Merge the segment other with the current segment and return the newly
        constructed segment, without in-place modification. Used for the '+' operator.
        :param other: another segment that is strictly next to the current one
        :return: a new segment combining the two sub-segments
        """
        result = deepcopy(self)
        result.merge_seg(other)
        return result

    def remove_trans_punc(self) -> None:
        """
        Replace Chinese punctuation in the translation text with spaces.
        :return: None
        """
        punc_cn = ",。!?"
        translator = str.maketrans(punc_cn, ' ' * len(punc_cn))
        self.translation = self.translation.translate(translator)

    def __str__(self) -> str:
        return f'{self.duration}\n{self.source_text}\n\n'

    def get_trans_str(self) -> str:
        return f'{self.duration}\n{self.translation}\n\n'

    def get_bilingual_str(self) -> str:
        return f'{self.duration}\n{self.source_text}\n{self.translation}\n\n'
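
# Illustrative sketch of SrtSegment's two constructors (all values below are
# made up): the dict form mirrors an openai-whisper result segment with keys
# 'start', 'end', and 'text'; the list form mirrors one 4-line .srt block.
# This helper is not called anywhere in the pipeline.
def _demo_srt_segment():
    seg = SrtSegment({'start': 1.5, 'end': 3.25, 'text': ' Hello world.'})
    assert seg.duration == '00:00:01,500 --> 00:00:03,250'

    # list form: [index, duration, text, blank] as read from an .srt file
    seg2 = SrtSegment(['2', '00:00:03,250 --> 00:00:05,000', 'Second line.', ''])
    merged = seg + seg2  # __add__ builds a new segment, no in-place change
    assert merged.source_text == 'Hello world. Second line.'
    assert merged.duration == '00:00:01,500 --> 00:00:05,000'
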
class SrtScript(object):
    def __init__(self, segments) -> None:
        self.segments = [SrtSegment(seg) for seg in segments]

    @classmethod
    def parse_from_srt_file(cls, path: str):
        with open(path, 'r', encoding="utf-8") as f:
            script_lines = [line.rstrip() for line in f.readlines()]

        # each SRT block is 4 lines: index, duration, text, blank
        segments = []
        for i in range(0, len(script_lines), 4):
            segments.append(list(script_lines[i:i + 4]))

        return cls(segments)

    def merge_segs(self, idx_list) -> SrtSegment:
        """
        Merge a list of segments into a single segment.
        :param idx_list: list of indices to merge
        :return: the merged segment
        """
        if not idx_list:
            raise ValueError('Empty idx_list')
        seg_result = deepcopy(self.segments[idx_list[0]])
        if len(idx_list) == 1:
            return seg_result

        for idx in range(1, len(idx_list)):
            seg_result += self.segments[idx_list[idx]]

        return seg_result

    def form_whole_sentence(self):
        """
        Concatenate or strip sentences and reconstruct the segments list.
        This compensates for improper segmentation from openai-whisper.
        :return: None
        """
        logging.info("Forming whole sentences...")
        merge_list = []  # a list of index lists to merge, e.g. [[0], [1, 2, 3, 4], [5, 6], [7]]
        sentence = []
        # collect the indices belonging to each complete sentence into merge_list
        for i, seg in enumerate(self.segments):
            if seg.source_text[-1] in ['.', '!', '?'] and len(seg.source_text) > 10 and 'vs.' not in seg.source_text:
                sentence.append(i)
                merge_list.append(sentence)
                sentence = []
            else:
                sentence.append(i)

        # reconstruct the segments, each now holding an entire sentence
        segments = []
        for idx_list in merge_list:
            if len(idx_list) > 1:
                logging.info("merging segments: %s", idx_list)
            segments.append(self.merge_segs(idx_list))

        self.segments = segments

    def remove_trans_punctuation(self):
        """
        Post-process: remove punctuation from every translation after translate and split.
        :return: None
        """
        for i, seg in enumerate(self.segments):
            seg.remove_trans_punc()
        logging.info("Removed punctuation in translation.")

    def set_translation(self, translate: str, id_range: tuple, model, video_name, video_link=None):
        start_seg_id = id_range[0]
        end_seg_id = id_range[1]

        src_text = ""
        for i, seg in enumerate(self.segments[start_seg_id - 1:end_seg_id]):
            src_text += seg.source_text
            src_text += '\n\n'

        def inner_func(target, input_str):
            # The Chinese prompts ask the model to merge or split the input
            # into exactly `target` lines while preserving the meaning, and
            # to output only the processed Chinese sentences.
            response = openai.ChatCompletion.create(
                # model=model,
                model="gpt-4",
                messages=[
                    {"role": "system",
                     "content": "你的任务是按照要求合并或拆分句子到指定行数,你需要尽可能保证句意,但必要时可以将一句话分为两行输出"},
                    {"role": "system",
                     "content": "注意:你只需要输出处理过的中文句子,如果你要输出序号,请使用冒号隔开"},
                    {"role": "user",
                     "content": '请将下面的句子拆分或组合为{}句:\n{}'.format(target, input_str)}
                ],
                temperature=0.15
            )
            return response['choices'][0]['message']['content'].strip()

        lines = translate.split('\n\n')
        if len(lines) < (end_seg_id - start_seg_id + 1):
            count = 0
            solved = True
            while count < 5 and len(lines) != (end_seg_id - start_seg_id + 1):
                count += 1
                print("Solving Unmatched Lines|iteration {}".format(count))

                flag = True
                while flag:
                    flag = False
                    try:
                        translate = inner_func(end_seg_id - start_seg_id + 1, translate)
                    except Exception as e:
                        print("An error occurred while solving unmatched lines:", e)
                        print("Retrying...")
                        flag = True
                lines = translate.split('\n')

            if len(lines) < (end_seg_id - start_seg_id + 1):
                solved = False
                print("Failed to solve unmatched lines; manual parsing needed.")

            # log the outcome for later inspection
            if not os.path.exists("./logs"):
                os.mkdir("./logs")
            if video_link:
                log_file = "./logs/log_link.csv"
                log_exist = os.path.exists(log_file)
                with open(log_file, "a") as log:
                    if not log_exist:
                        log.write("range_of_text,iterations_solving,solved,file_length,video_link" + "\n")
                    log.write(str(id_range) + ',' + str(count) + ',' + str(solved) + ',' + str(
                        len(self.segments)) + ',' + video_link + "\n")
            else:
                log_file = "./logs/log_name.csv"
                log_exist = os.path.exists(log_file)
                with open(log_file, "a") as log:
                    if not log_exist:
                        log.write("range_of_text,iterations_solving,solved,file_length,video_name" + "\n")
                    log.write(str(id_range) + ',' + str(count) + ',' + str(solved) + ',' + str(
                        len(self.segments)) + ',' + video_name + "\n")

        print(lines)

        for i, seg in enumerate(self.segments[start_seg_id - 1:end_seg_id]):
            # naive way to deal with the merged-translation problem
            # TODO: need a smarter solution
            if i < len(lines):
                if "Note:" in lines[i]:  # skip "Note:" lines produced by the model
                    lines.remove(lines[i])
                    if i == len(lines) - 1:
                        break
                if lines[i][0] in [' ', '\n']:
                    lines[i] = lines[i][1:]
                seg.translation = lines[i]
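
    # Illustrative sketch of set_translation's contract (all values made up):
    # id_range is 1-based and inclusive, and `translate` is the model output
    # carrying one translated line per segment, separated by blank lines, e.g.
    #
    #   script.set_translation("第一句\n\n第二句", (1, 2),
    #                          model="gpt-4", video_name="demo")
    #
    # assigns "第一句" to segment 1 and "第二句" to segment 2. When the line
    # count does not match the segment count, inner_func re-prompts the model
    # (up to 5 iterations) to re-split the text into the right number of lines.
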
    def split_seg(self, seg, text_threshold, time_threshold):
        # evenly split seg into 2 parts and add the new segments into self.segments;
        # drop a leading comma first to avoid infinite recursion
        if len(seg.source_text) > 2:
            if seg.source_text[:2] == ', ':
                seg.source_text = seg.source_text[2:]
            if seg.translation[0] == ',':
                seg.translation = seg.translation[1:]

        source_text = seg.source_text
        translation = seg.translation

        # split the text at the middle comma
        src_commas = [m.start() for m in re.finditer(',', source_text)]
        trans_commas = [m.start() for m in re.finditer(',', translation)]
        if len(src_commas) != 0:
            src_split_idx = src_commas[len(src_commas) // 2] if len(src_commas) % 2 == 1 else src_commas[
                len(src_commas) // 2 - 1]
        else:
            # no comma in the source: fall back to the middle space, then index 0
            src_space = [m.start() for m in re.finditer(' ', source_text)]
            if len(src_space) > 0:
                src_split_idx = src_space[len(src_space) // 2] if len(src_space) % 2 == 1 else src_space[
                    len(src_space) // 2 - 1]
            else:
                src_split_idx = 0

        if len(trans_commas) != 0:
            trans_split_idx = trans_commas[len(trans_commas) // 2] if len(trans_commas) % 2 == 1 else trans_commas[
                len(trans_commas) // 2 - 1]
        else:
            trans_split_idx = len(translation) // 2

        # avoid splitting inside an English word
        for i in range(trans_split_idx, len(translation)):
            if not translation[i].encode('utf-8').isalpha():
                trans_split_idx = i
                break

        # split the time duration proportionally to the translation split point
        time_split_ratio = trans_split_idx / (len(seg.translation) - 1)

        src_seg1 = source_text[:src_split_idx]
        src_seg2 = source_text[src_split_idx:]
        trans_seg1 = translation[:trans_split_idx]
        trans_seg2 = translation[trans_split_idx:]

        start_seg1 = seg.start
        end_seg1 = start_seg2 = seg.start + (seg.end - seg.start) * time_split_ratio
        end_seg2 = seg.end

        seg1_dict = {}
        seg1_dict['text'] = src_seg1
        seg1_dict['start'] = start_seg1
        seg1_dict['end'] = end_seg1
        seg1 = SrtSegment(seg1_dict)
        seg1.translation = trans_seg1

        seg2_dict = {}
        seg2_dict['text'] = src_seg2
        seg2_dict['start'] = start_seg2
        seg2_dict['end'] = end_seg2
        seg2 = SrtSegment(seg2_dict)
        seg2.translation = trans_seg2

        # recurse while a half is still too long and covers enough time
        result_list = []
        if len(seg1.translation) > text_threshold and (seg1.end - seg1.start) > time_threshold:
            result_list += self.split_seg(seg1, text_threshold, time_threshold)
        else:
            result_list.append(seg1)

        if len(seg2.translation) > text_threshold and (seg2.end - seg2.start) > time_threshold:
            result_list += self.split_seg(seg2, text_threshold, time_threshold)
        else:
            result_list.append(seg2)

        return result_list
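
    # Illustrative sketch of split_seg on a made-up segment (`script` is any
    # SrtScript instance): with an ASCII comma in both texts, the split lands
    # at the middle comma and the timestamps divide proportionally to the
    # translation split point.
    #
    #   seg = SrtSegment({'start': 0.0, 'end': 4.0, 'text': 'first half, second half'})
    #   seg.translation = '前半句,后半句'
    #   parts = script.split_seg(seg, text_threshold=4, time_threshold=0.5)
    #   # parts[0]: 'first half' / '前半句', parts[1]: ', second half' / ',后半句',
    #   # and parts[0].end == parts[1].start == 2.0 (= 4 * 3/6)
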
    def check_len_and_split(self, text_threshold=30, time_threshold=1.0):
        # if translation length > text_threshold and duration > time_threshold, split the segment in two
        logging.info("performing check_len_and_split")
        segments = []
        for i, seg in enumerate(self.segments):
            if len(seg.translation) > text_threshold and (seg.end - seg.start) > time_threshold:
                seg_list = self.split_seg(seg, text_threshold, time_threshold)
                logging.info("splitting segment {} into {} parts".format(i + 1, len(seg_list)))
                segments += seg_list
            else:
                segments.append(seg)

        self.segments = segments
        logging.info("check_len_and_split finished")

    def check_len_and_split_range(self, range, text_threshold=30, time_threshold=1.0):
        # DEPRECATED
        # if translation length > text_threshold, split the segment in two
        start_seg_id = range[0]
        end_seg_id = range[1]
        extra_len = 0
        segments = []
        for i, seg in enumerate(self.segments[start_seg_id - 1:end_seg_id]):
            if len(seg.translation) > text_threshold and (seg.end - seg.start) > time_threshold:
                seg_list = self.split_seg(seg, text_threshold, time_threshold)
                segments += seg_list
                extra_len += len(seg_list) - 1
            else:
                segments.append(seg)

        self.segments[start_seg_id - 1:end_seg_id] = segments
        return extra_len

    def correct_with_force_term(self):
        ## force term correction
        logging.info("performing force term correction")

        # load the term dictionary
        with open("finetune_data/dict_enzh.csv", 'r', encoding='utf-8') as f:
            term_enzh_dict = {rows[0]: rows[1] for rows in reader(f)}

        # replace longer terms first so they are not clobbered by their substrings
        keywords = list(term_enzh_dict.keys())
        keywords.sort(key=lambda x: len(x), reverse=True)

        for word in keywords:
            for i, seg in enumerate(self.segments):
                if word in seg.source_text.lower():
                    seg.source_text = re.sub(fr"({word}es|{word}s?)\b", "{}".format(term_enzh_dict.get(word)),
                                             seg.source_text, flags=re.IGNORECASE)
                    logging.info(
                        "replace term: " + word + " --> " + term_enzh_dict.get(word) + " in time stamp {}".format(
                            i + 1))
                    logging.info("source text becomes: " + seg.source_text)

    comp_dict = []

    def fetchfunc(self, word, threshold):
        import enchant
        result = word
        distance = 0
        threshold = threshold * len(word)
        if len(self.comp_dict) == 0:
            with open("./finetune_data/dict_freq.txt", 'r', encoding='utf-8') as f:
                self.comp_dict = {rows[0]: 1 for rows in reader(f)}

        # find the candidate term with the smallest Levenshtein distance
        temp = ""
        for matched in self.comp_dict:
            if (" " in matched and " " in word) or (" " not in matched and " " not in word):
                if enchant.utils.levenshtein(word, matched) < enchant.utils.levenshtein(word, temp):
                    temp = matched
        if enchant.utils.levenshtein(word, temp) < threshold:
            distance = enchant.utils.levenshtein(word, temp)
            result = temp

        return distance, result
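
    # Sketch of the expected dictionary formats (contents here are assumptions
    # inferred from how the files are read above):
    #
    #   finetune_data/dict_enzh.csv   one "english,中文" pair per row, e.g.
    #       transformer,变压器
    #   finetune_data/dict_freq.txt   one known term per line; it serves both
    #       as the spell-check personal word list and as fetchfunc's candidate
    #       pool, e.g.
    #       zerg
    #       protoss
    #
    # With that pool, fetchfunc("protos", 0.3) returns (1, "protoss"): the
    # closest candidate within a Levenshtein distance of 0.3 * len(word).
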
sentence: "this, is a sentence", n = 2 # result: ["this,", "is", "a", ["sentence"], ["this,", "is"], "is a", "a sentence"] words = sentence.split() res = [] for j in range(n, 0, -1): res += [words[i:i + j] for i in range(len(words) - j + 1)] return res def spell_check_term(self): logging.info("performing spell check") import enchant dict = enchant.Dict('en_US') term_spellDict = enchant.PyPWL('./finetune_data/dict_freq.txt') for seg in tqdm(self.segments): ready_words = self.extract_words(seg.source_text, 2) for i in range(len(ready_words)): word_list = ready_words[i] word, real_word, pos = self.get_real_word(word_list) if not dict.check(real_word) and not term_spellDict.check(real_word): distance, correct_term = self.fetchfunc(real_word, 0.3) if distance != 0: seg.source_text = re.sub(word[:pos], correct_term, seg.source_text, flags=re.IGNORECASE) logging.info( "replace: " + word[:pos] + " to " + correct_term + "\t distance = " + str(distance)) def get_real_word(self, word_list: list): word = "" for w in word_list: word += f"{w} " word = word[:-1] # "this, is" if word[-2:] == ".\n": real_word = word[:-2].lower() n = -2 elif word[-1:] in [".", "\n", ",", "!", "?"]: real_word = word[:-1].lower() n = -1 else: real_word = word.lower() n = 0 return word, real_word, len(word) + n ## WRITE AND READ FUNCTIONS ## def get_source_only(self): # return a string with pure source text result = "" for i, seg in enumerate(self.segments): result += f'{seg.source_text}\n\n\n' # f'SENTENCE {i+1}: {seg.source_text}\n\n\n' return result def reform_src_str(self): result = "" for i, seg in enumerate(self.segments): result += f'{i + 1}\n' result += str(seg) return result def reform_trans_str(self): result = "" for i, seg in enumerate(self.segments): result += f'{i + 1}\n' result += seg.get_trans_str() return result def form_bilingual_str(self): result = "" for i, seg in enumerate(self.segments): result += f'{i + 1}\n' result += seg.get_bilingual_str() return result def write_srt_file_src(self, path: str): # write srt file to path with open(path, "w", encoding='utf-8') as f: f.write(self.reform_src_str()) pass def write_srt_file_translate(self, path: str): logging.info("writing to " + path) with open(path, "w", encoding='utf-8') as f: f.write(self.reform_trans_str()) pass def write_srt_file_bilingual(self, path: str): logging.info("writing to " + path) with open(path, "w", encoding='utf-8') as f: f.write(self.form_bilingual_str()) pass def realtime_write_srt(self, path, range, length, idx): # DEPRECATED start_seg_id = range[0] end_seg_id = range[1] with open(path, "a", encoding='utf-8') as f: # for i, seg in enumerate(self.segments[start_seg_id-1:end_seg_id+length]): # f.write(f'{i+idx}\n') # f.write(seg.get_trans_str()) for i, seg in enumerate(self.segments): if i < range[0] - 1: continue if i >= range[1] + length: break f.write(f'{i + idx}\n') f.write(seg.get_trans_str()) pass def realtime_bilingual_write_srt(self, path, range, length, idx): # DEPRECATED start_seg_id = range[0] end_seg_id = range[1] with open(path, "a", encoding='utf-8') as f: for i, seg in enumerate(self.segments): if i < range[0] - 1: continue if i >= range[1] + length: break f.write(f'{i + idx}\n') f.write(seg.get_bilingual_str()) pass