"""SRT subtitle script utilities: build segments from whisper output or .srt
files, merge them into whole sentences, and attach GPT-calibrated translations."""
from datetime import timedelta
import os
import whisper
from csv import reader
import re
import openai
class SRT_segment(object):
    """One subtitle entry: a time range, the source text, and a translation.

    Constructed from either a whisper-style segment dict with ``start`` /
    ``end`` (seconds, float) and ``text`` keys, or a list of raw SRT lines
    ``[index, "HH:MM:SS,mmm --> HH:MM:SS,mmm", text, ...]``.
    """

    def __init__(self, *args) -> None:
        if isinstance(args[0], dict):
            segment = args[0]
            # Format both endpoints as SRT "HH:MM:SS,mmm" timestamps.
            self.start_time_str = self._srt_timestamp(segment['start'])
            self.end_time_str = self._srt_timestamp(segment['end'])
            self.source_text = segment['text']
            self.duration = f"{self.start_time_str} --> {self.end_time_str}"
            self.translation = ""
        elif isinstance(args[0], list):
            # Raw SRT lines: [index, "start --> end", text]
            self.source_text = args[0][2]
            self.duration = args[0][1]
            self.start_time_str, self.end_time_str = self.duration.split(" --> ")
            self.translation = ""

    @staticmethod
    def _srt_timestamp(seconds: float) -> str:
        """Convert a float second count to an SRT timestamp "HH:MM:SS,mmm".

        Uses round() on the fractional part to avoid the float-precision
        off-by-one of the old ``int((t*100)%100*10)`` computation (which
        also capped precision at 10 ms), and zero-pads hours so timestamps
        stay well-formed past 9 hours.
        """
        ms = int(round((seconds - int(seconds)) * 1000))
        whole = int(seconds)
        if ms == 1000:  # rounding can carry into the next second
            ms = 0
            whole += 1
        hours, rem = divmod(whole, 3600)
        minutes, secs = divmod(rem, 60)
        return f"{hours:02d}:{minutes:02d}:{secs:02d},{ms:03d}"

    def merge_seg(self, seg) -> None:
        """Absorb *seg* into this one: append its text and translation,
        and extend the time range to seg's end time."""
        self.source_text += seg.source_text
        self.translation += seg.translation
        self.end_time_str = seg.end_time_str
        self.duration = f"{self.start_time_str} --> {self.end_time_str}"

    def __str__(self) -> str:
        return f'{self.duration}\n{self.source_text}\n\n'

    def get_trans_str(self) -> str:
        """Render the entry with the translated text only."""
        return f'{self.duration}\n{self.translation}\n\n'

    def get_bilingual_str(self) -> str:
        """Render the entry with source text followed by its translation."""
        return f'{self.duration}\n{self.source_text}\n{self.translation}\n\n'
class SRT_script():
    """An ordered collection of SRT_segment objects with merge, translation
    alignment, term forcing, and .srt serialization helpers."""

    def __init__(self, segments) -> None:
        """Wrap each raw segment (whisper dict or raw SRT line list) in an
        SRT_segment."""
        self.segments = [SRT_segment(seg) for seg in segments]

    @classmethod
    def parse_from_srt_file(cls, path: str) -> "SRT_script":
        """Read an .srt file and build a script from its 4-line entries.

        Assumes the strict layout index / time range / text / blank line.
        Fix: the original lacked @classmethod, so its ``cls(...)`` call
        could never work when invoked as ``SRT_script.parse_from_srt_file``.
        """
        with open(path, 'r', encoding="utf-8") as f:
            script_lines = f.read().splitlines()
        segments = [list(script_lines[i:i + 4])
                    for i in range(0, len(script_lines), 4)]
        return cls(segments)

    def merge_segs(self, idx_list) -> "SRT_segment":
        """Merge the segments at the given indices into the first one.

        Mutates and returns ``self.segments[idx_list[0]]``.
        """
        final_seg = self.segments[idx_list[0]]
        for idx in idx_list[1:]:
            final_seg.merge_seg(self.segments[idx])
        return final_seg

    def form_whole_sentence(self):
        """Regroup segments so each resulting segment ends a sentence ('.').

        Fixes: the original dropped a trailing group that never reached a
        final '.', and raised IndexError on an empty source_text; the
        unterminated tail is now flushed after the loop.
        """
        merge_list = []  # groups of indices to merge, e.g. [[0], [1, 2], [3]]
        sentence = []
        for i, seg in enumerate(self.segments):
            sentence.append(i)
            if seg.source_text.endswith('.'):
                merge_list.append(sentence)
                sentence = []
        if sentence:  # flush an unterminated trailing sentence
            merge_list.append(sentence)
        self.segments = [self.merge_segs(idx_list) for idx_list in merge_list]

    def set_translation(self, translate: str, id_range: tuple):
        """Distribute a blank-line-separated translation over the 1-based,
        inclusive segment range ``id_range``.

        If the number of translated lines does not match the number of
        segments, GPT-3.5 is asked to re-align the translation against the
        numbered source sentences first.
        """
        start_seg_id, end_seg_id = id_range
        lines = translate.split('\n\n')
        if len(lines) != (end_seg_id - start_seg_id + 1):
            # Build a numbered source transcript so GPT can keep track of
            # sentence breaks, then append the mismatched translation.
            input_str = "\n"
            for i, seg in enumerate(self.segments[start_seg_id - 1:end_seg_id]):
                input_str += 'Sentence %d: ' % (i + 1) + seg.source_text + '\n'
            input_str += translate
            response = openai.ChatCompletion.create(
                model="gpt-3.5-turbo",
                messages=[
                    {"role": "system", "content": "You are a helpful assistant that help calibrates English to Chinese subtitle translations in starcraft2."},
                    {"role": "system", "content": "You are provided with a translated Chinese transcript, you need to reformat the Chinese sentence to match the meaning and sentence number as the English transcript"},
                    {"role": "system", "content": "There is no need for you to add any comments or notes, and do not modify the English transcript."},
                    {"role": "user", "content": 'Reformat the Chinese with the English transcript given: "{}"'.format(input_str)}
                ],
                temperature=0.15
            )
            # Fix: ChatCompletion responses carry the text in
            # choices[0].message.content, not choices[0]['text'].
            translate = response['choices'][0]['message']['content'].strip()
            # Fix: re-split after re-alignment; the original kept the stale
            # pre-call split, so the GPT output was never used.
            lines = translate.split('\n\n')
        # Fix: drop "(Note: ...)" lines up front instead of calling
        # lines.remove() while indexing into the same list.
        lines = [line for line in lines if "(Note:" not in line]
        # Naive way to deal with the merged-translation problem.
        # TODO: need a smarter solution.
        for i, seg in enumerate(self.segments[start_seg_id - 1:end_seg_id]):
            if i >= len(lines):
                break
            try:
                # Strip a leading "N:" sentence label if present; the old
                # `":" or ": "` expression always evaluated to ":".
                seg.translation = lines[i].split(":", 1)[1]
            except IndexError:
                seg.translation = lines[i]

    def split_seg(self, seg_id):
        """Split segment *seg_id* into two halves (not yet implemented)."""
        # TODO: evenly split seg to 2 parts and add new seg into self.segments
        pass

    def check_len_and_split(self, threshold):
        """Split any segment longer than *threshold* (not yet implemented)."""
        # TODO: if sentence length >= threshold, split this segments to two
        pass

    def get_source_only(self):
        """Return a numbered plain-text dump of the source sentences."""
        result = ""
        for i, seg in enumerate(self.segments):
            result += f'SENTENCE {i+1}: {seg.source_text}\n\n\n'
        return result

    def reform_src_str(self):
        """Serialize the script as a source-language .srt string."""
        result = ""
        for i, seg in enumerate(self.segments):
            result += f'{i+1}\n'
            result += str(seg)
        return result

    def reform_trans_str(self):
        """Serialize the script as a translated .srt string."""
        result = ""
        for i, seg in enumerate(self.segments):
            result += f'{i+1}\n'
            result += seg.get_trans_str()
        return result

    def form_bilingual_str(self):
        """Serialize the script as a bilingual .srt string."""
        result = ""
        for i, seg in enumerate(self.segments):
            result += f'{i+1}\n'
            result += seg.get_bilingual_str()
        return result

    def write_srt_file_src(self, path: str):
        """Write the source-language .srt file to *path*."""
        with open(path, "w", encoding='utf-8') as f:
            f.write(self.reform_src_str())

    def write_srt_file_translate(self, path: str):
        """Write the translated .srt file to *path*."""
        with open(path, "w", encoding='utf-8') as f:
            f.write(self.reform_trans_str())

    def write_srt_file_bilingual(self, path: str):
        """Write the bilingual .srt file to *path*."""
        with open(path, "w", encoding='utf-8') as f:
            f.write(self.form_bilingual_str())

    def correct_with_force_term(self):
        """Force-replace known terms using finetune_data/dict.csv.

        Matching is case-insensitive on whole whitespace-delimited words;
        a trailing ".\\n" (sentence end at a line boundary) is peeled off
        so the punctuation survives the replacement.
        Fix: removed the per-segment debug ``print(self)`` left in the loop.
        """
        # TODO: shortcut translation i.e. VA, ob
        # TODO: variety of translation
        with open("finetune_data/dict.csv", 'r', encoding='utf-8') as f:
            term_dict = {row[0]: row[1] for row in reader(f)}
        for seg in self.segments:
            # Pad newlines with a space so split(" ") keeps line boundaries.
            ready_words = re.sub('\n', '\n ', seg.source_text).split(" ")
            for i, word in enumerate(ready_words):
                if word[-2:] == ".\n":
                    stem = word[:-2]
                    if stem.lower() in term_dict:
                        word = word.replace(stem, term_dict[stem.lower()])
                elif word.lower() in term_dict:
                    word = term_dict[word.lower()]
                ready_words[i] = word + ' '
            seg.source_text = re.sub('\n ', '\n', "".join(ready_words))