srt_check / app.py
sam-ezai's picture
Upload app.py
d7cb368 verified
import io
import sys
import gradio as gr
import srt
import jiwer
from dataclasses import dataclass
from dataclasses_json import dataclass_json
from datetime import timedelta
@dataclass_json
@dataclass
class ZHTW_Sub:
    """One subtitle cue carrying a paired ``zh`` / ``tw`` text and its time span."""

    # Cue boundaries; filter_sub copies them from the source subtitle.
    start: timedelta
    end: timedelta
    # filter_sub stores the second half of the subtitle line here.
    zh: str
    # filter_sub stores the first half of the subtitle line here.
    tw: str
def read_srt(p):
    """Parse the SRT file at path ``p`` and return its subtitles as a list.

    The file is decoded explicitly as UTF-8, and a leading byte-order mark is
    dropped, instead of relying on the platform's default encoding, so the
    same file reads identically on every host.

    Args:
        p: Path of the ``.srt`` file to read.

    Returns:
        A list with every subtitle produced by ``srt.parse``.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    with open(p, encoding="utf-8-sig") as f:
        return list(srt.parse(f.read()))
def merge_sub(subs):
    """Merge back-to-back subtitles in place and return the same list.

    A subtitle whose ``start`` equals the ``end`` of the entry kept before it
    is folded into that entry: the kept entry takes the later ``end`` and the
    ``zh`` / ``tw`` texts are appended, separated by one space.  A run of
    several contiguous subtitles therefore collapses into its first entry.

    The merge is done in a single pass; the input list is rewritten once at
    the end instead of popping elements one by one.

    Args:
        subs: Mutable list of objects with ``start``, ``end``, ``zh`` and
            ``tw`` attributes.  Surviving entries are modified in place.

    Returns:
        ``subs`` itself, now holding only the merged entries.
    """
    if not subs:
        return subs

    merged = [subs[0]]
    for cur in subs[1:]:
        prev = merged[-1]
        if prev.end != cur.start:
            # Not contiguous: ``cur`` starts a new entry.
            merged.append(cur)
            continue
        prev.end = cur.end
        prev.zh += f" {cur.zh}"
        prev.tw += f" {cur.tw}"

    # Keep the caller's list object, only its contents change.
    subs[:] = merged
    return subs
def merge_sub2(subs, delta):
    """Merge subtitles separated by at most ``delta``, in place.

    A subtitle is folded into the entry kept before it when the gap
    ``start - previous.end`` is not greater than ``delta`` (overlapping
    subtitles have a negative gap and are merged as well).  The kept entry's
    ``zh`` / ``tw`` texts get the later texts appended after one space.

    The merged ``end`` is the later of the two ends, so folding in a subtitle
    that finishes earlier than the kept entry never shortens the span.

    Args:
        subs: Mutable list of objects with ``start``, ``end``, ``zh`` and
            ``tw`` attributes.  Surviving entries are modified in place.
        delta: Largest gap that still counts as "adjacent"; must be
            comparable with ``start - end`` (a ``timedelta`` for cue times).

    Returns:
        ``subs`` itself, now holding only the merged entries.
    """
    if not subs:
        return subs

    merged = [subs[0]]
    for cur in subs[1:]:
        prev = merged[-1]
        if cur.start - prev.end > delta:
            # Gap too large: ``cur`` starts a new entry.
            merged.append(cur)
            continue
        prev.end = max(prev.end, cur.end)
        prev.zh += f" {cur.zh}"
        prev.tw += f" {cur.tw}"

    # Keep the caller's list object, only its contents change.
    subs[:] = merged
    return subs
def filter_sub(subs):
    """Turn raw subtitles into ``ZHTW_Sub`` entries and collect a review log.

    Each subtitle's ``content`` is expected to hold two texts on one line,
    either separated by exactly one ``|`` or, when there is no ``|``, as an
    even number of whitespace-separated tokens that are cut in the middle.
    The first half becomes ``tw`` and the second half ``zh``.

    A subtitle is skipped, and reported in the log, when its content
    contains ``#``, contains a line break, does not split into exactly two
    parts on ``|``, or has an odd token count.  After a subtitle with a line
    break has been skipped, the next accepted subtitle is appended to the
    previously accepted entry (extending its ``end``) instead of starting a
    new one.  A pair whose ``jiwer.cer(tw, zh)`` exceeds 1 is kept but noted
    in the log.

    Args:
        subs: Iterable of subtitles exposing ``start``, ``end`` and
            ``content``.

    Returns:
        ``(new_subs, buffer)``: the list of accepted ``ZHTW_Sub`` entries and
        the ``io.StringIO`` that holds the log lines.
    """
    buffer = io.StringIO()

    def log(*args):
        # Write straight into the buffer rather than swapping sys.stdout:
        # no process-wide state is touched, so an exception cannot leave
        # stdout redirected and unrelated prints never end up in this log.
        print(*args, file=buffer)

    new_subs = []
    carry_next = False
    for s in subs:
        content = s.content
        if '#' in content:
            # Note: subtitle carries a '#' marker.
            log('註:標記', s.start, s.end, content)
            continue
        if '\n' in content:
            # Needs fixing: content spans several lines.
            log('修:分行', '\\n', s.start, content)
            carry_next = True
            continue

        if '|' in content:
            parts = content.split('|')
            if len(parts) != 2:
                # Needs fixing: more than one '|' separator.  Anything other
                # than two parts cannot be unpacked into a (tw, zh) pair.
                log('修:多槓', parts)
                continue
            tw, zh = (part.strip() for part in parts)
        else:
            words = content.split()
            if len(words) % 2 != 0:
                # Needs fixing: odd token count, no middle to cut at.
                log('修:不均', s.start, s.end, words)
                continue
            mid = len(words) // 2
            tw, zh = ' '.join(words[:mid]), ' '.join(words[mid:])

        if jiwer.cer(tw, zh) > 1:
            # Note: the two halves differ a lot; keep the pair anyway.
            log('註:差距', s.start, s.end, 'tw:', tw, 'zh:', zh)

        if carry_next and new_subs:
            prev = new_subs[-1]
            prev.zh += f" {zh}"
            prev.tw += f" {tw}"
            prev.end = s.end
        else:
            # Also reached when a carry was requested but nothing has been
            # accepted yet, in which case there is no entry to extend.
            new_subs.append(ZHTW_Sub(s.start, s.end, zh, tw))
        carry_next = False

    return new_subs, buffer
def update_yield():
    """Create a logger that accumulates messages and returns the running text.

    Returns:
        A one-argument callable.  Each call converts its argument with
        ``str``, adds it to a history private to this logger, and returns
        every message received so far joined with newlines.
    """
    history = []

    def update_print(inp):
        history.append(str(inp))
        transcript = '\n'.join(history)
        return transcript

    return update_print
def parse_srt(file):
    """Validate an uploaded SRT file, yielding the growing log after each step.

    The steps are: read the file, filter its subtitles with ``filter_sub``,
    merge contiguous entries with ``merge_sub``, list every merged entry
    longer than 30 seconds, and finally report the summed duration of all
    merged entries.

    Args:
        file: The uploaded file, or ``None`` when nothing was uploaded.

    Yields:
        The complete log text so far (one more line or block per step), or
        the single message ``"No file uploaded."`` when ``file`` is ``None``.
    """
    if file is None:
        # This function is a generator, so a returned value would never
        # reach the consumer; the message has to be yielded.
        yield "No file uploaded."
        return

    # NOTE(review): depending on the Gradio version the upload may arrive as
    # an object exposing ``.name`` or as a plain path string; accept both.
    path = getattr(file, "name", file)

    upd = update_yield()
    yield upd(path)
    subs = read_srt(path)
    yield upd(len(subs))
    new_subs, logs = filter_sub(subs)
    yield upd(logs.getvalue())
    yield upd(len(new_subs))
    new_subs = merge_sub(new_subs)
    yield upd(len(new_subs))

    total_dur = 0
    for i, it in enumerate(new_subs):
        duration = (it.end - it.start).total_seconds()
        if duration > 30:
            yield upd(i)
            # The logger takes a single value, so the fields are joined into
            # one space-separated line first.
            yield upd(f"{it.end.total_seconds()} {duration} {it.tw}")
        total_dur += duration
    yield upd("可用時長 " + str(timedelta(seconds=int(total_dur))))
# Build the UI: a file picker restricted to .srt files and a textbox for the log.
with gr.Blocks() as demo:
    gr.Markdown("## SRT File Validator")
    with gr.Column():
        file_input = gr.File(label="Upload .srt File", file_types=[".srt"])
        output_log = gr.Textbox(label="Parsing Log", lines=10, max_lines=120)
    # Re-run the validation whenever the file input changes; parse_srt is a
    # generator whose yielded values are the full log text so far.
    file_input.change(fn=parse_srt, inputs=file_input, outputs=output_log)
# Start the app when this module is executed.
demo.launch()