ViDove / SRT.py
DWizard
rewrite forceTerm replacement
66e606c
raw
history blame
4.7 kB
from datetime import timedelta
import os
import whisper
from csv import reader
import re
class SRT_segment(object):
    """A single subtitle cue: index, time range, source text and translation.

    Two construction forms are accepted (only the first positional argument
    is used):

    * a dict with ``'id'``, ``'start'``/``'end'`` (offsets in seconds) and
      ``'text'`` keys; the id is shifted by +1 to become the cue index.
      Presumably a transcription segment dict -- TODO confirm against caller.
    * a list or tuple whose first three items are the cue index, the
      ``"start --> end"`` time-range line and the cue text, as read from an
      ``.srt`` file.

    Attributes:
        segment_id: cue index (int for dict input, kept as given for list input).
        start_time_str, end_time_str: ``HH:MM:SS,mmm`` timestamps, no padding spaces.
        duration: the ``"start --> end"`` line written back into SRT output.
        source_text: original-language text of the cue.
        translation: translated text, initially empty.

    Raises:
        TypeError: if no argument, or an argument of an unsupported type, is given.
        ValueError: if a list-form time range does not contain ``"-->"``.
    """

    def __init__(self, *args) -> None:
        if not args:
            raise TypeError("SRT_segment() requires a segment dict or list")
        data = args[0]
        if isinstance(data, dict):
            self.start_time_str = self._format_timestamp(data['start'])
            self.end_time_str = self._format_timestamp(data['end'])
            # dict ids start at 0, SRT cue indices start at 1
            self.segment_id = data['id'] + 1
            self.source_text = data['text']
            self.duration = f"{self.start_time_str} --> {self.end_time_str}"
        elif isinstance(data, (list, tuple)):
            self.segment_id = data[0]
            self.source_text = data[2]
            self.duration = data[1]
            if "-->" not in self.duration:
                raise ValueError(f"Malformed SRT time range: {self.duration!r}")
            start, _, end = self.duration.partition("-->")
            # standard SRT puts spaces around "-->"; keep only the timestamps
            self.start_time_str = start.strip()
            self.end_time_str = end.strip()
        else:
            raise TypeError(
                f"SRT_segment() expects a dict or list, got {type(data).__name__}"
            )
        self.translation = ""

    @staticmethod
    def _format_timestamp(seconds) -> str:
        """Format an offset in seconds as an SRT timestamp ``HH:MM:SS,mmm``.

        Milliseconds are rounded, hours are not capped at 24, and negative
        input is clamped to zero.
        """
        total_ms = max(0, int(round(float(seconds) * 1000)))
        hours, rem = divmod(total_ms, 3_600_000)
        minutes, rem = divmod(rem, 60_000)
        secs, millis = divmod(rem, 1000)
        return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"

    def __str__(self) -> str:
        """Render the cue with its source text, in SRT format."""
        return f'{self.segment_id}\n{self.duration}\n{self.source_text}\n\n'

    def get_trans_str(self) -> str:
        """Render the cue with its translated text, in SRT format."""
        return f'{self.segment_id}\n{self.duration}\n{self.translation}\n\n'

    def get_bilingual_str(self) -> str:
        """Render the cue with source text followed by translation, in SRT format."""
        return f'{self.segment_id}\n{self.duration}\n{self.source_text}\n{self.translation}\n\n'
class SRT_script:
    """An ordered list of SRT_segment cues plus SRT read/write helpers."""

    # one "word" for term replacement: any run of non-whitespace characters
    _WORD_RE = re.compile(r"\S+")
    # punctuation that may trail a term without preventing a match
    _TRAILING_PUNCT = ".,!?;:"

    def __init__(self, segments) -> None:
        """Wrap each raw segment (dict or [id, duration, text] list) in an SRT_segment."""
        self.segments = [SRT_segment(seg) for seg in segments]

    @classmethod
    def parse_from_srt_file(cls, path: str):
        """Build a script from the ``.srt`` file at ``path``.

        Cues are separated by blank lines. Each cue needs an index line and a
        ``start --> end`` line; any further lines are its text and are joined
        with a newline, so multi-line cues are kept intact. A cue with no text
        lines gets empty text.

        Raises:
            ValueError: if a cue has fewer than two lines.
        """
        # utf-8-sig drops a leading BOM, which would otherwise be glued to the
        # first cue index; files without a BOM decode exactly as with utf-8.
        with open(path, 'r', encoding="utf-8-sig") as f:
            script_lines = f.read().splitlines()
        segments = []
        cue_lines = []
        for line in script_lines + [""]:  # the extra "" flushes the last cue
            if line.strip():
                cue_lines.append(line)
            elif cue_lines:
                segments.append(cls._cue_from_lines(cue_lines))
                cue_lines = []
        return cls(segments)

    @staticmethod
    def _cue_from_lines(lines):
        """Turn one cue's non-blank lines into an ``[id, duration, text]`` list."""
        if len(lines) < 2:
            raise ValueError(f"Malformed SRT cue (needs index and time range): {lines!r}")
        return [lines[0], lines[1], "\n".join(lines[2:])]

    def set_translation(self, translate: str, id_range: tuple):
        """Assign translated text to a contiguous range of segments.

        Args:
            translate: translations separated by blank lines (``"\\n\\n"``),
                one chunk per segment, in order. Extra chunks are ignored.
            id_range: ``(start_id, end_id)``, 1-based and inclusive.

        Raises:
            IndexError: if ``translate`` has fewer chunks than the range has
                segments. This is checked before any segment is modified.
        """
        start_seg_id = id_range[0]
        end_seg_id = id_range[1]
        chunks = translate.split('\n\n')
        targets = self.segments[start_seg_id - 1:end_seg_id]
        if len(chunks) < len(targets):
            raise IndexError(
                f"translation has {len(chunks)} chunk(s) but segments "
                f"{start_seg_id}-{end_seg_id} need {len(targets)}"
            )
        for seg, chunk in zip(targets, chunks):
            seg.translation = chunk

    def get_source_only(self):
        """Return all source texts, each followed by a blank line, with no ids or times."""
        return "".join(f'{seg.source_text}\n\n' for seg in self.segments)

    def reform_src_str(self):
        """Return the whole script as SRT text using the source texts."""
        return "".join(str(seg) for seg in self.segments)

    def reform_trans_str(self):
        """Return the whole script as SRT text using the translations."""
        return "".join(seg.get_trans_str() for seg in self.segments)

    def form_bilingual_str(self):
        """Return the whole script as SRT text with source and translation per cue."""
        return "".join(seg.get_bilingual_str() for seg in self.segments)

    def write_srt_file_src(self, path: str):
        """Write the source-language SRT to ``path`` (UTF-8)."""
        with open(path, "w", encoding='utf-8') as f:
            f.write(self.reform_src_str())

    def write_srt_file_translate(self, path: str):
        """Write the translated SRT to ``path`` (UTF-8)."""
        with open(path, "w", encoding='utf-8') as f:
            f.write(self.reform_trans_str())

    def write_srt_file_bilingual(self, path: str):
        """Write the bilingual SRT to ``path`` (UTF-8)."""
        with open(path, "w", encoding='utf-8') as f:
            f.write(self.form_bilingual_str())

    def correct_with_force_term(self, dict_path: str = "finetune_data/dict.csv"):
        """Force known terminology into every segment's source text.

        Args:
            dict_path: CSV file whose first column is a term to look for and
                whose second column is its replacement. Rows with fewer than
                two columns or an empty term (e.g. blank lines) are skipped.
                Surrounding whitespace is stripped from both columns.

        Matching is per whitespace-separated word and case-insensitive. A term
        followed by any of ``.,!?;:`` still matches, and the punctuation is
        kept. Multi-word terms are not matched. Whitespace in the text,
        including newlines, is left untouched.
        """
        # TODO: shortcut translation i.e. VA, ob
        # TODO: variety of translation
        # newline='' is what the csv module requires for correct quoting handling
        with open(dict_path, 'r', encoding='utf-8', newline='') as f:
            term_dict = {
                row[0].strip().lower(): row[1].strip()
                for row in reader(f)
                if len(row) >= 2 and row[0].strip()
            }
        for seg in self.segments:
            seg.source_text = self._replace_terms(seg.source_text, term_dict)

    @classmethod
    def _replace_terms(cls, text: str, term_dict: dict) -> str:
        """Return ``text`` with every word found in ``term_dict`` replaced.

        ``term_dict`` keys must be lower-case. Each word is first looked up
        as-is. If that fails, it is looked up again with trailing
        ``_TRAILING_PUNCT`` characters removed, and those characters are
        re-appended to the replacement.
        """
        def substitute(match):
            word = match.group(0)
            replacement = term_dict.get(word.lower())
            if replacement is not None:
                return replacement
            core = word.rstrip(cls._TRAILING_PUNCT)
            if core and core != word:
                replacement = term_dict.get(core.lower())
                if replacement is not None:
                    return replacement + word[len(core):]
            return word

        return cls._WORD_RE.sub(substitute, text)