ViDove / SRT.py
worldqwq
Use GPT prompt to solve sentence merging issue
b39d769
raw
history blame
8.68 kB
from datetime import timedelta
import os
import whisper
from csv import reader
import re
import openai
class SRT_segment(object):
    """One subtitle cue: its time span, source text and (optional) translation.

    It can be constructed from either:

    * a dict with numeric ``'start'``/``'end'`` times in seconds and a
      ``'text'`` entry (one transcription segment), or
    * a list ``[index, "HH:MM:SS,mmm --> HH:MM:SS,mmm", text, ...]`` holding
      one parsed SRT entry (only items 1 and 2 are used).

    Attributes:
        source_text: text in the source language.
        translation: translated text; ``""`` until it is assigned.
        start_time_str / end_time_str: SRT timestamps ``HH:MM:SS,mmm``.
        duration: ``"<start> --> <end>"``, i.e. the SRT timing line.

    Raises:
        TypeError: if the first argument is neither a dict nor a list.
    """

    def __init__(self, *args) -> None:
        source = args[0]
        if isinstance(source, dict):
            self.start_time_str = self._format_timestamp(source['start'])
            self.end_time_str = self._format_timestamp(source['end'])
            self.source_text = source['text']
            self.duration = f"{self.start_time_str} --> {self.end_time_str}"
        elif isinstance(source, list):
            self.source_text = source[2]
            self.duration = source[1]
            parts = self.duration.split(" --> ")
            self.start_time_str = parts[0]
            self.end_time_str = parts[1]
        else:
            # Previously this produced an object with no attributes, which only
            # failed later with a confusing AttributeError.
            raise TypeError(
                f"SRT_segment expects a dict or a list, got {type(source).__name__}"
            )
        self.translation = ""

    @staticmethod
    def _format_timestamp(seconds) -> str:
        """Format a time given in seconds as an SRT timestamp ``HH:MM:SS,mmm``."""
        # Round to the nearest millisecond. Truncating (t*100)%100*10 turned
        # e.g. 0.29 s into 289 ms because of binary float error.
        total_ms = int(round(seconds * 1000))
        hours, rem = divmod(total_ms, 3_600_000)
        minutes, rem = divmod(rem, 60_000)
        secs, millis = divmod(rem, 1000)
        # Zero-padded fields keep hours at two digits (str(timedelta) + "0"
        # produced "010:..." at >= 10 h and "1 day, ..." at >= 24 h).
        return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"

    def merge_seg(self, seg):
        """Append *seg* to this segment in place.

        The source texts and translations are concatenated, and this segment's
        end time becomes *seg*'s end time. A single space is inserted between the
        source texts only when neither side already has whitespace at the join,
        so texts that carry a leading space are merged unchanged.
        """
        if (self.source_text and seg.source_text
                and not self.source_text[-1].isspace()
                and not seg.source_text[0].isspace()):
            self.source_text += " "
        self.source_text += seg.source_text
        self.translation += seg.translation
        self.end_time_str = seg.end_time_str
        self.duration = f"{self.start_time_str} --> {self.end_time_str}"

    def __str__(self) -> str:
        """Return the SRT body for this cue: timing line, source text, blank line."""
        return f'{self.duration}\n{self.source_text}\n\n'

    def get_trans_str(self) -> str:
        """Return the SRT body for this cue with the translation as its text."""
        return f'{self.duration}\n{self.translation}\n\n'

    def get_bilingual_str(self) -> str:
        """Return the SRT body for this cue with the source line followed by the translation line."""
        return f'{self.duration}\n{self.source_text}\n{self.translation}\n\n'
class SRT_script():
    """An ordered list of SRT_segment cues with helpers to parse, merge,
    attach translations to, and serialize them as SRT text."""

    def __init__(self, segments) -> None:
        """Wrap each raw item of *segments* (a transcription segment dict or a
        parsed SRT entry list) in an SRT_segment."""
        self.segments = [SRT_segment(seg) for seg in segments]

    @classmethod
    def parse_from_srt_file(cls, path: str):
        """Read the SRT file at *path* and build a script from its entries.

        Entries are separated by blank (or whitespace-only) lines. In each entry,
        the first line is the cue index and the second is the timing line. All
        remaining lines are joined with newlines to form the cue text, so a
        multi-line cue no longer shifts every following entry.

        Raises:
            ValueError: if an entry has fewer than two lines (index + timing).
        """
        with open(path, 'r', encoding="utf-8") as f:
            script_lines = f.read().splitlines()

        entries = []
        current = []
        for line in script_lines:
            if line.strip():
                current.append(line)
            elif current:
                entries.append(current)
                current = []
        if current:
            entries.append(current)

        segments = []
        for entry in entries:
            if len(entry) < 2:
                raise ValueError(f"Malformed SRT entry in {path!r}: {entry!r}")
            segments.append([entry[0], entry[1], "\n".join(entry[2:])])
        return cls(segments)

    def merge_segs(self, idx_list) -> "SRT_segment":
        """Merge the segments at the indices in *idx_list* into the first one.

        The first segment is modified in place and returned. *idx_list* must be
        non-empty.
        """
        final_seg = self.segments[idx_list[0]]
        for idx in idx_list[1:]:
            final_seg.merge_seg(self.segments[idx])
        return final_seg

    def form_whole_sentence(self, end_marks=(".", "?", "!")):
        """Merge consecutive segments so that each one holds whole sentences.

        A segment closes a sentence when its source text, ignoring trailing
        whitespace, ends with one of *end_marks*. Segments after the last
        sentence end are kept as a final (unterminated) group instead of being
        dropped.

        Args:
            end_marks: suffixes that mark the end of a sentence.
        """
        # Groups of indices to merge, e.g. [[0], [1, 2, 3, 4], [5, 6], [7]].
        merge_list = []
        sentence = []
        for i, seg in enumerate(self.segments):
            sentence.append(i)
            if seg.source_text.rstrip().endswith(tuple(end_marks)):
                merge_list.append(sentence)
                sentence = []
        if sentence:
            merge_list.append(sentence)
        self.segments = [self.merge_segs(idx_list) for idx_list in merge_list]

    @staticmethod
    def _usable_entries(entries):
        """Drop blank entries and model-added "(Note: ...)" remarks."""
        return [e for e in entries if e.strip() and "(Note:" not in e]

    def _realign_with_gpt(self, translate, segs):
        """Ask gpt-3.5-turbo to reformat *translate* so that its sentences line
        up with the source sentences of *segs*, and return the reply text."""
        # Number each source sentence so the model can keep track of sentence
        # breaks, then append the translation to be reformatted.
        input_str = "\n" + "".join(
            f"Sentence {i}: {seg.source_text}\n" for i, seg in enumerate(segs, start=1)
        )
        input_str += translate
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": "You are a helpful assistant that help calibrates English to Chinese subtitle translations in starcraft2."},
                {"role": "system", "content": "You are provided with a translated Chinese transcript, you need to reformat the Chinese sentence to match the meaning and sentence number as the English transcript"},
                {"role": "system", "content": "There is no need for you to add any comments or notes, and do not modify the English transcript."},
                {"role": "user", "content": 'Reformat the Chinese with the English transcript given: "{}"'.format(input_str)}
            ],
            temperature=0.15
        )
        # Chat completions return the reply under message.content. The 'text'
        # field only exists in the legacy Completion API.
        return response['choices'][0]['message']['content'].strip()

    def set_translation(self, translate: str, id_range: tuple):
        """Distribute a batch translation over a range of segments.

        Args:
            translate: translated text for the segments in *id_range*, with one
                entry per segment separated by blank lines. Anything before the
                first ':' of an entry (e.g. a "SENTENCE N" label) is removed.
            id_range: ``(start, end)`` 1-based, inclusive segment ids.

        Blank entries and "(Note:" remarks are ignored. If the number of
        remaining entries differs from the number of segments, the text is sent
        to gpt-3.5-turbo to be realigned with the source sentences first.
        Segments without a matching entry keep their current translation.
        """
        start_seg_id = id_range[0]
        end_seg_id = id_range[1]
        target_segs = self.segments[start_seg_id - 1:end_seg_id]

        lines = self._usable_entries(translate.split('\n\n'))
        if len(lines) != (end_seg_id - start_seg_id + 1):
            reply = self._realign_with_gpt(translate, target_segs)
            # NOTE(review): the prompt does not fix the reply layout, so split
            # per non-empty line (this covers both "\n" and "\n\n" separators);
            # this assumes each sentence comes back on a single line.
            lines = self._usable_entries(reply.splitlines())

        # TODO: need a smarter solution to the merge-translation problem.
        for seg, line in zip(target_segs, lines):
            # Split only on the first colon so colons inside the translation
            # (e.g. "3:00") are kept.
            _, sep, text = line.partition(":")
            seg.translation = (text if sep else line).strip()

    def split_seg(self, seg_id):
        """Not implemented yet; currently does nothing."""
        # TODO: evenly split seg to 2 parts and add new seg into self.segments
        pass

    def check_len_and_split(self, threshold):
        """Not implemented yet; currently does nothing."""
        # TODO: if sentence length >= threshold, split this segments to two
        pass

    def get_source_only(self):
        """Return the source texts only, each as ``SENTENCE n: text`` followed by
        three newlines (n is 1-based)."""
        return "".join(
            f'SENTENCE {i}: {seg.source_text}\n\n\n'
            for i, seg in enumerate(self.segments, start=1)
        )

    def reform_src_str(self):
        """Serialize as SRT text with the source text of each cue."""
        return "".join(f'{i}\n{seg}' for i, seg in enumerate(self.segments, start=1))

    def reform_trans_str(self):
        """Serialize as SRT text with the translation of each cue."""
        return "".join(
            f'{i}\n{seg.get_trans_str()}' for i, seg in enumerate(self.segments, start=1)
        )

    def form_bilingual_str(self):
        """Serialize as SRT text with source and translation lines for each cue."""
        return "".join(
            f'{i}\n{seg.get_bilingual_str()}' for i, seg in enumerate(self.segments, start=1)
        )

    def write_srt_file_src(self, path: str):
        """Write the source-language SRT to *path* (UTF-8)."""
        with open(path, "w", encoding='utf-8') as f:
            f.write(self.reform_src_str())

    def write_srt_file_translate(self, path: str):
        """Write the translated SRT to *path* (UTF-8)."""
        with open(path, "w", encoding='utf-8') as f:
            f.write(self.reform_trans_str())

    def write_srt_file_bilingual(self, path: str):
        """Write the bilingual SRT to *path* (UTF-8)."""
        with open(path, "w", encoding='utf-8') as f:
            f.write(self.form_bilingual_str())

    def correct_with_force_term(self, dict_path="finetune_data/dict.csv"):
        """Apply forced term corrections from a CSV dictionary to every segment's
        source text.

        Each CSV row is ``term,replacement``. Rows with fewer than two cells
        (e.g. blank lines) are skipped. Terms are matched case-insensitively
        against whitespace-delimited tokens, either exactly or after removing
        trailing ``.,!?;:`` punctuation; in the latter case the punctuation is
        kept after the replacement. Whitespace in the text is left untouched.

        Args:
            dict_path: path of the term dictionary CSV (a relative path is
                resolved against the current working directory).
        """
        # TODO: shortcut translation i.e. VA, ob
        # TODO: variety of translation
        with open(dict_path, 'r', encoding='utf-8') as f:
            # Keys are lower-cased because lookups below use token.lower().
            term_dict = {row[0].lower(): row[1] for row in reader(f) if len(row) >= 2}
        trailing_punct = ".,!?;:"

        def replace_token(match):
            token = match.group(0)
            hit = term_dict.get(token.lower())
            if hit is not None:
                return hit
            core = token.rstrip(trailing_punct)
            if core and core != token:
                hit = term_dict.get(core.lower())
                if hit is not None:
                    return hit + token[len(core):]
            return token

        for seg in self.segments:
            seg.source_text = re.sub(r"\S+", replace_token, seg.source_text)