ViDove / SRT.py
JiaenLiu
in progress
9ad8b62
raw
history blame
17.2 kB
from datetime import timedelta
from csv import reader
from datetime import datetime
import re
import openai
from collections import deque
class SRT_segment(object):
    """A single subtitle cue: a time range, its source text and its translation.

    Two construction forms are supported:

    * ``SRT_segment({'start': float, 'end': float, 'text': str})`` -- a
      transcription segment whose times are in seconds.
    * ``SRT_segment([index, "HH:MM:SS,mmm --> HH:MM:SS,mmm", text, ...])`` --
      the raw lines of one cue read from an ``.srt`` file (a tuple works too).

    Both forms set: ``start``/``end`` (seconds), ``start_ms``/``end_ms``
    (millisecond component, 0-999), ``start_time``/``end_time``
    (``timedelta``), ``start_time_str``/``end_time_str`` (SRT timestamps),
    ``duration`` (the SRT timing line), ``source_text`` and ``translation``.
    """

    # Display time given to a cue whose start and end round to the same millisecond.
    _MIN_DURATION_MS = 500

    def __init__(self, *args) -> None:
        """Build a cue from a segment dict or a list of raw SRT lines.

        Raises:
            TypeError: if the first argument is neither a dict nor a list/tuple.
        """
        source = args[0]
        if isinstance(source, dict):
            self._init_from_dict(source)
        elif isinstance(source, (list, tuple)):
            self._init_from_srt_lines(source)
        else:
            raise TypeError(
                "SRT_segment expects a dict or a list of SRT lines, "
                f"got {type(source).__name__}"
            )
        self.translation = ""

    def _init_from_dict(self, segment: dict) -> None:
        """Populate the cue from a ``{'start', 'end', 'text'}`` transcription segment."""
        self.start = segment['start']
        self.end = segment['end']
        # Work in whole milliseconds. Rounding (rather than truncating a float)
        # avoids artefacts such as 0.29 s becoming 289 ms.
        start_total_ms = round(self.start * 1000)
        end_total_ms = round(self.end * 1000)
        if start_total_ms == end_total_ms:
            # Avoid an empty (zero-length) cue; keep `end` consistent with the shown time.
            end_total_ms += self._MIN_DURATION_MS
            self.end = end_total_ms / 1000
        self.start_ms = start_total_ms % 1000
        self.end_ms = end_total_ms % 1000
        self.start_time = timedelta(milliseconds=start_total_ms)
        self.end_time = timedelta(milliseconds=end_total_ms)
        self.start_time_str = self._format_srt_time(start_total_ms)
        self.end_time_str = self._format_srt_time(end_total_ms)
        self.source_text = segment['text'].lstrip()
        self.duration = f"{self.start_time_str} --> {self.end_time_str}"

    def _init_from_srt_lines(self, lines) -> None:
        """Populate the cue from ``[index, timing_line, text, ...]`` SRT lines."""
        self.source_text = lines[2]
        self.duration = lines[1]
        self.start_time_str = self.duration.split(" --> ")[0]
        self.end_time_str = self.duration.split(" --> ")[1]
        start_secs, self.start_ms = self._parse_srt_time(self.start_time_str)
        end_secs, self.end_ms = self._parse_srt_time(self.end_time_str)
        self.start = start_secs + self.start_ms / 1000
        self.end = end_secs + self.end_ms / 1000
        self.start_time = timedelta(seconds=start_secs, milliseconds=self.start_ms)
        self.end_time = timedelta(seconds=end_secs, milliseconds=self.end_ms)

    @staticmethod
    def _format_srt_time(total_ms: int) -> str:
        """Format a non-negative millisecond count as ``HH:MM:SS,mmm``."""
        hours, rem = divmod(total_ms, 3_600_000)
        minutes, rem = divmod(rem, 60_000)
        seconds, millis = divmod(rem, 1000)
        return f"{hours:02d}:{minutes:02d}:{seconds:02d},{millis:03d}"

    @staticmethod
    def _parse_srt_time(time_str: str) -> tuple:
        """Split ``HH:MM:SS,mmm`` (``.`` also accepted) into ``(whole_seconds, millis)``."""
        clock, _, millis = time_str.strip().replace('.', ',').partition(',')
        parts = clock.split(':')
        seconds = int(parts[0]) * 3600 + int(parts[1]) * 60 + int(parts[2])
        return seconds, int(millis or 0)

    def merge_seg(self, seg) -> None:
        """Append *seg* to this cue in place.

        Source texts are joined with a single space, because transcription
        text has its leading whitespace stripped and plain concatenation
        would glue words together. Translations are concatenated directly.
        The end time, in every representation, is taken from *seg*.
        """
        self.source_text = " ".join(
            text for text in (self.source_text, seg.source_text) if text
        )
        self.translation += seg.translation
        self.end = seg.end
        self.end_ms = seg.end_ms
        self.end_time = seg.end_time
        self.end_time_str = seg.end_time_str
        self.duration = f"{self.start_time_str} --> {self.end_time_str}"

    def __str__(self) -> str:
        """Return the timing line and source text (an SRT cue without its index)."""
        return f'{self.duration}\n{self.source_text}\n\n'

    def get_trans_str(self) -> str:
        """Return the timing line and translation (an SRT cue without its index)."""
        return f'{self.duration}\n{self.translation}\n\n'

    def get_bilingual_str(self) -> str:
        """Return the timing line, source text and translation on separate lines."""
        return f'{self.duration}\n{self.source_text}\n{self.translation}\n\n'
class SRT_script():
    """An ordered list of ``SRT_segment`` cues plus helpers to edit and export them.

    ``self.segments`` holds the cues in display order. Methods that take an
    id range use 1-based, inclusive ``(start_id, end_id)`` pairs and slice
    ``self.segments[start_id - 1:end_id]``.
    """

    def __init__(self, segments) -> None:
        """Wrap every item of *segments* (segment dict or list of SRT lines) in an ``SRT_segment``."""
        self.segments = [SRT_segment(seg) for seg in segments]

    @classmethod
    def parse_from_srt_file(cls, path: str):
        """Build a script from the UTF-8 ``.srt`` file at *path*.

        Cues are delimited by blank lines instead of being assumed to be
        exactly four lines long. A cue with several text lines (joined with a
        space) or a file without a trailing blank line therefore no longer
        shifts every later cue out of alignment. Blocks with fewer than two
        lines (no timing line) are skipped.
        """
        with open(path, 'r', encoding="utf-8") as f:
            script_lines = f.read().splitlines()
        segments = []
        block = []
        # The extra blank sentinel line flushes the final block.
        for line in script_lines + [""]:
            if line.strip():
                block.append(line)
                continue
            if len(block) >= 2:
                segments.append([block[0], block[1], " ".join(block[2:]), ""])
            block = []
        return cls(segments)

    def merge_segs(self, idx_list) -> "SRT_segment":
        """Merge the cues at indices *idx_list* into the first of them (in place) and return it."""
        final_seg = self.segments[idx_list[0]]
        for idx in idx_list[1:]:
            final_seg.merge_seg(self.segments[idx])
        return final_seg

    def form_whole_sentence(self, end_marks=('.',)):
        """Merge consecutive cues so that each cue holds a whole sentence.

        A cue closes a sentence when its source text, ignoring trailing
        whitespace, ends with one of *end_marks*. Cues after the last sentence
        end are kept as one final group instead of being dropped.

        Args:
            end_marks: suffix or tuple of suffixes that end a sentence. The
                default ``('.',)`` matches the historical behaviour.
        """
        merge_list = []  # groups of indices to merge, e.g. [[0], [1, 2, 3, 4], [5, 6], [7]]
        sentence = []
        for i, seg in enumerate(self.segments):
            sentence.append(i)
            if seg.source_text.rstrip().endswith(end_marks):
                merge_list.append(sentence)
                sentence = []
        if sentence:
            merge_list.append(sentence)
        self.segments = [self.merge_segs(idx_list) for idx_list in merge_list]

    def set_translation(self, translate: str, id_range: tuple, model):
        """Assign the lines of *translate* to the cues in *id_range*.

        *translate* is split on blank lines (``"\\n\\n"``). If it has fewer
        lines than there are cues, the chat *model* is asked, up to 5 times,
        to re-split it to match the English source. Lines containing
        ``"(Note:"`` are dropped. Any ``"prefix:"`` before the first colon,
        such as ``"Sentence 3:"``, is removed. Cues beyond the available
        lines keep their current translation.
        """
        start_seg_id = id_range[0]
        end_seg_id = id_range[1]
        expected = end_seg_id - start_seg_id + 1
        lines = translate.split('\n\n')
        if len(lines) < expected:
            count = 0
            while count < 5 and len(lines) != expected:
                count += 1
                print("Solving Unmatched Lines|iteration {}".format(count))
                input_str = self._build_alignment_prompt(start_seg_id, end_seg_id, translate)
                reply = self._request_alignment(input_str, model)
                if reply is None:
                    print("Giving up on solving unmatched lines after repeated API errors")
                    break
                translate = reply
                lines = translate.split('\n\n')
            if len(lines) < expected:
                print("Failed Solving unmatched lines, Manually parse needed")
                print(lines)
        # Drop model commentary before pairing lines with cues.
        lines = [line for line in lines if "(Note:" not in line]
        for seg, line in zip(self.segments[start_seg_id-1:end_seg_id], lines):
            seg.translation = self._extract_translation(line)

    def _build_alignment_prompt(self, start_seg_id, end_seg_id, translate):
        """Number the English source sentences of the range and append the current translation."""
        parts = ["\n"]
        for i, seg in enumerate(self.segments[start_seg_id-1:end_seg_id]):
            # The sentence index lets the model keep track of sentence breaks.
            parts.append('Sentence %d: ' % (i + 1) + seg.source_text + '\n')
        parts.append(translate)
        return "".join(parts)

    @staticmethod
    def _request_alignment(input_str, model, max_attempts=5):
        """Ask the chat model to re-split the translation.

        Returns the stripped reply, or ``None`` if every attempt raised.
        Failed attempts are retried with exponential back-off (capped at 30 s).
        """
        import time
        for attempt in range(max_attempts):
            try:
                response = openai.ChatCompletion.create(
                    model=model,
                    messages=[
                        {"role": "system", "content": "You are a helpful assistant that help calibrates English to Chinese subtitle translations in starcraft2."},
                        {"role": "system", "content": "You are provided with a translated Chinese transcript; you must modify or split the Chinese sentence to match the meaning and the number of the English transcript exactly one by one. You must not merge ANY Chinese lines, you can only split them but the total Chinese lines MUST equals to number of English lines."},
                        {"role": "system", "content": "There is no need for you to add any comments or notes, and do not modify the English transcript."},
                        {"role": "user", "content": 'You are given the English transcript and line number, your task is to merge or split the Chinese to match the exact number of lines in English transcript, no more no less. For example, if there are more Chinese lines than English lines, merge some the Chinese lines to match the number of English lines. If Chinese lines is less than English lines, split some Chinese lines to match the english lines: "{}"'.format(input_str)},
                    ],
                    temperature=0.7,
                )
                return response['choices'][0]['message']['content'].strip()
            except Exception as e:  # network/API boundary: report and retry
                print("An error has occurred during solving unmatched lines:", e)
                if attempt + 1 < max_attempts:
                    print("Retrying...")
                    time.sleep(min(2 ** attempt, 30))
        return None

    @staticmethod
    def _extract_translation(line):
        """Return the text after the first ``:`` (leading whitespace removed), or *line* if it has none."""
        _, sep, rest = line.partition(":")
        return rest.lstrip() if sep else line

    @staticmethod
    def _middle(points, fallback):
        """Return the middle element of *points* (left-middle for an even count), or *fallback* if empty."""
        return points[(len(points) - 1) // 2] if points else fallback

    def split_seg(self, seg, threshold=500):
        """Split *seg* in two, recursing until each piece's translation fits *threshold*.

        Where each text is cut:

        * Source text: just after its middle comma, else at its middle space,
          else at its character midpoint.
        * Translation: just after its middle comma, else at its character
          midpoint.

        Commas at the very end of a text are ignored so neither half is
        empty. Time is divided evenly at the midpoint of
        ``seg.start``..``seg.end``. *seg* itself is not modified.

        Returns:
            list of new ``SRT_segment`` pieces in order.
        """
        source_text = seg.source_text
        translation = seg.translation

        # Cut positions just after a comma keep the comma with the first half.
        src_cuts = [m.end() for m in re.finditer(',', source_text) if m.end() < len(source_text)]
        if not src_cuts:
            src_cuts = [m.start() for m in re.finditer(' ', source_text)]
        src_split_idx = self._middle(src_cuts, len(source_text) // 2)

        trans_cuts = [m.end() for m in re.finditer(',', translation) if m.end() < len(translation)]
        trans_split_idx = self._middle(trans_cuts, len(translation) // 2)

        midpoint = seg.start + (seg.end - seg.start) / 2
        halves = (
            (source_text[:src_split_idx], translation[:trans_split_idx], seg.start, midpoint),
            (source_text[src_split_idx:], translation[trans_split_idx:], midpoint, seg.end),
        )
        result_list = []
        for text, trans, start, end in halves:
            part = SRT_segment({'text': text, 'start': start, 'end': end})
            part.translation = trans
            # Recurse only when the piece actually shrank; this guarantees termination.
            if threshold < len(trans) < len(translation):
                result_list += self.split_seg(part, threshold)
            else:
                result_list.append(part)
        return result_list

    def _split_long(self, segs, threshold):
        """Return *segs* with every cue whose translation exceeds *threshold* replaced by its split pieces."""
        result = []
        for seg in segs:
            if len(seg.translation) > threshold:
                result += self.split_seg(seg, threshold)
            else:
                result.append(seg)
        return result

    def check_len_and_split(self, threshold=30000):
        """Split every cue whose translation is longer than *threshold* characters."""
        self.segments = self._split_long(self.segments, threshold)

    def check_len_and_split_range(self, range, threshold=30000):
        """Split over-long cues within the 1-based inclusive id *range*.

        Returns:
            the number of cues that now occupy that range.
        """
        start_seg_id = range[0]
        end_seg_id = range[1]
        segments = self._split_long(self.segments[start_seg_id-1:end_seg_id], threshold)
        self.segments[start_seg_id-1:end_seg_id] = segments
        return len(segments)

    def get_source_only(self):
        """Return all source texts as ``SENTENCE n: text`` entries separated by blank lines."""
        return "".join(
            f'SENTENCE {i}: {seg.source_text}\n\n\n'
            for i, seg in enumerate(self.segments, start=1)
        )

    def reform_src_str(self):
        """Return the source-language SRT document."""
        return "".join(f'{i}\n{seg}' for i, seg in enumerate(self.segments, start=1))

    def reform_trans_str(self):
        """Return the translated SRT document."""
        return "".join(
            f'{i}\n{seg.get_trans_str()}' for i, seg in enumerate(self.segments, start=1)
        )

    def form_bilingual_str(self):
        """Return the bilingual SRT document (source line followed by translation)."""
        return "".join(
            f'{i}\n{seg.get_bilingual_str()}' for i, seg in enumerate(self.segments, start=1)
        )

    def write_srt_file_src(self, path: str):
        """Write the source-language SRT document to *path* (UTF-8, overwriting)."""
        with open(path, "w", encoding='utf-8') as f:
            f.write(self.reform_src_str())

    def write_srt_file_translate(self, path: str):
        """Write the translated SRT document to *path* (UTF-8, overwriting)."""
        with open(path, "w", encoding='utf-8') as f:
            f.write(self.reform_trans_str())

    def write_srt_file_bilingual(self, path: str):
        """Write the bilingual SRT document to *path* (UTF-8, overwriting)."""
        with open(path, "w", encoding='utf-8') as f:
            f.write(self.form_bilingual_str())

    @staticmethod
    def _load_term_dict(path):
        """Load a two-column ``english,chinese`` CSV into a dict, skipping blank/short rows."""
        with open(path, 'r', encoding='utf-8') as f:
            return {row[0]: row[1] for row in reader(f) if len(row) >= 2}

    def correct_with_force_term(self):
        """Replace source words found in ``./finetune_data/dict_enzh.csv`` with their mapped term.

        Words are matched lowercased with one trailing punctuation mark
        removed (see ``get_real_word``); that mark is kept after the term.
        """
        # TODO: shortcut translation i.e. VA, ob
        # TODO: variety of translation
        term_enzh_dict = self._load_term_dict("./finetune_data/dict_enzh.csv")
        for seg in self.segments:
            words = seg.source_text.split(" ")
            for i, word in enumerate(words):
                real_word, pos = self.get_real_word(word)
                term = term_enzh_dict.get(real_word)
                if term is not None:
                    words[i] = term + word[pos:]
            seg.source_text = " ".join(words)

    def spell_check_term(self):
        """Replace misspelled source words with the top suggestion from the term word list.

        If ``enchant``'s ``en_US`` dictionary rejects a word (checked in its
        original case, trailing punctuation excluded), the word is replaced
        by the first suggestion from ``./finetune_data/dict_freq.txt``, when
        one exists. Empty tokens, from repeated spaces or bare punctuation,
        are skipped.

        NOTE (original author): "I've" will be replaced because "i've" is not in the dict.
        """
        import enchant
        en_dict = enchant.Dict('en_US')
        term_spell_dict = enchant.PyPWL('./finetune_data/dict_freq.txt')
        for seg in self.segments:
            words = seg.source_text.split(" ")
            for i, word in enumerate(words):
                real_word, pos = self.get_real_word(word)
                if not word[:pos] or en_dict.check(word[:pos]):
                    continue
                suggest = term_spell_dict.suggest(real_word)
                if suggest:  # relax spell check: keep the word when nothing is suggested
                    words[i] = suggest[0] + word[pos:]
            seg.source_text = " ".join(words)

    def spell_correction(self, word: str, arg: int):
        """Correct a single *word* and return the result.

        Args:
            word: token, possibly with a trailing punctuation mark or newline.
            arg: the correction mode.
                * ``0``: replace the word with its term from
                  ``finetune_data/dict_enzh.csv``.
                * ``1``: if the ``en_US`` dictionary rejects the word,
                  replace it with the first suggestion from
                  ``./finetune_data/dict_freq.txt``.
                * Any other value prints a warning and returns *word*
                  unchanged.
        """
        if arg not in (0, 1):
            print('only 0 or 1 for argument')
            return word
        real_word, pos = self.get_real_word(word)
        if arg == 0:  # term translate mode
            term = self._load_term_dict("finetune_data/dict_enzh.csv").get(real_word)
            return word if term is None else term + word[pos:]
        # term spell check mode
        import enchant
        en_dict = enchant.Dict('en_US')
        term_spell_dict = enchant.PyPWL('./finetune_data/dict_freq.txt')
        if real_word and not en_dict.check(real_word):
            suggest = term_spell_dict.suggest(real_word)
            if suggest:  # relax spell check
                return suggest[0] + word[pos:]
        return word

    def get_real_word(self, word: str):
        r"""Split one trailing punctuation mark off *word*.

        Returns:
            ``(real_word, pos)``, where ``real_word`` is *word* lowercased
            without a trailing ``".\n"``, ``"."``, ``"\n"``, ``","``, ``"!"``
            or ``"?"``. ``pos`` is the length of the kept prefix, so
            ``word[:pos]`` is the original-case word and ``word[pos:]`` the
            removed suffix.
        """
        if word.endswith(".\n"):
            pos = len(word) - 2
        elif word[-1:] in (".", "\n", ",", "!", "?"):
            pos = len(word) - 1
        else:
            pos = len(word)
        return word[:pos].lower(), pos

    def _append_cues(self, path, id_range, idx, render):
        """Append the cues in the 1-based inclusive *id_range* to *path*.

        Each cue is numbered with its 0-based index plus *idx*, and its body
        is ``render(seg)``.
        """
        first = max(id_range[0] - 1, 0)
        stop = max(id_range[1], 0)
        with open(path, "a", encoding='utf-8') as f:
            for i, seg in enumerate(self.segments[first:stop], start=first):
                f.write(f'{i+idx}\n')
                f.write(render(seg))

    def realtime_write_srt(self, path, range, length, idx):
        """Append the translated cues in the 1-based inclusive *range* to *path*.

        Each cue is numbered with its 0-based position plus *idx*. *length*
        is accepted for compatibility and is unused.
        """
        self._append_cues(path, range, idx, lambda seg: seg.get_trans_str())

    def realtime_bilingual_write_srt(self, path, range, length, idx):
        """Append the bilingual cues in the 1-based inclusive *range* to *path*.

        Each cue is numbered with its 0-based position plus *idx*. *length*
        is accepted for compatibility and is unused.
        """
        self._append_cues(path, range, idx, lambda seg: seg.get_bilingual_str())