|
import io |
|
import sys |
|
import gradio as gr |
|
import srt |
|
import jiwer |
|
|
|
from dataclasses import dataclass |
|
from dataclasses_json import dataclass_json |
|
from datetime import timedelta |
|
|
|
|
|
@dataclass_json
@dataclass
class ZHTW_Sub:
    """One subtitle cue: a time span plus its paired ``zh`` and ``tw`` texts.

    The field order below defines the positional constructor
    ``ZHTW_Sub(start, end, zh, tw)``.
    """

    start: timedelta  # cue start time
    end: timedelta  # cue end time
    zh: str  # "zh" half of the cue text (presumably Mandarin; confirm)
    tw: str  # "tw" half of the cue text (presumably Taiwanese; confirm)
|
|
|
def read_srt(p):
    """Read the SRT file at path *p* and return its cues as a list.

    Args:
        p: Path of the ``.srt`` file to read.

    Returns:
        A list of the subtitle objects produced by ``srt.parse``.
    """
    # Pin the encoding instead of relying on the platform default, which is
    # not UTF-8 everywhere. "utf-8-sig" reads plain UTF-8 and also drops a
    # leading byte-order mark so it cannot end up in the first cue.
    with open(p, encoding="utf-8-sig") as f:
        return list(srt.parse(f.read()))
|
|
|
def merge_sub(subs):
    """Fuse back-to-back cues of *subs* in place and return the same list.

    A cue is folded into the one kept before it when that cue's ``end`` equals
    its ``start``. The kept cue takes over the later ``end`` and gets the later
    ``zh`` / ``tw`` texts appended after a single space. Chains of touching
    cues collapse into one entry.
    """
    if len(subs) < 2:
        return subs

    kept = [subs[0]]
    for cur in subs[1:]:
        prev = kept[-1]
        if prev.end == cur.start:
            # Touching cues: extend the previous one instead of keeping both.
            prev.end = cur.end
            prev.zh += f" {cur.zh}"
            prev.tw += f" {cur.tw}"
        else:
            kept.append(cur)

    # Write the result back so the caller's list object is updated in place.
    subs[:] = kept
    return subs
|
|
|
def merge_sub2(subs, delta):
    """Fuse cues separated by a gap of at most *delta*, in place.

    A cue is folded into the one kept before it unless
    ``cue.start - previous.end > delta``. Overlapping cues (negative gap)
    are therefore merged as well.

    Args:
        subs: List of cue objects with ``start``, ``end``, ``zh`` and ``tw``
            attributes. The list and the kept cues are mutated.
        delta: Largest gap that still counts as "adjacent"; must be
            comparable with ``start - end``.

    Returns:
        The same list object, now holding only the merged cues.
    """
    if len(subs) < 2:
        return subs

    merged = [subs[0]]
    for cur in subs[1:]:
        prev = merged[-1]
        if cur.start - prev.end > delta:
            merged.append(cur)
            continue

        # A cue that overlaps the previous one may end earlier than it does;
        # never let the merged span shrink.
        prev.end = max(prev.end, cur.end)
        prev.zh += f" {cur.zh}"
        prev.tw += f" {cur.tw}"

    # Write the result back so the caller's list object is updated in place.
    subs[:] = merged
    return subs
|
|
|
def filter_sub(subs):
    """Split each cue's text into its ``tw`` / ``zh`` halves and log problems.

    Each cue's ``content`` is handled as follows:

    * contains ``#``: logged as '註:標記' ("note: marked") and skipped;
    * contains a newline: logged as '修:分行' ("fix: line break") and
      skipped; the next accepted cue is then appended to the last kept entry
      instead of starting a new one;
    * contains ``|``: must split into exactly two fields (``tw|zh``),
      otherwise logged as '修:多槓' ("fix: extra bars") and skipped;
    * otherwise: split on whitespace; the first half of the tokens is ``tw``
      and the second half ``zh``. An odd token count is logged as '修:不均'
      ("fix: uneven") and skipped.

    A pair whose ``jiwer.cer(tw, zh)`` exceeds 1 is logged as '註:差距'
    ("note: gap") but still kept.

    Args:
        subs: Iterable of cue objects with ``content``, ``start`` and ``end``.

    Returns:
        ``(new_subs, buffer)``: the list of ``ZHTW_Sub`` entries and an
        ``io.StringIO`` holding the log lines.
    """
    buffer = io.StringIO()

    def log(*parts):
        # Print straight into the buffer instead of swapping sys.stdout, so
        # an exception cannot leave the process-wide stdout redirected.
        print(*parts, file=buffer)

    new_subs = []
    carry_next = False
    for s in subs:
        content = s.content
        if '#' in content:
            log('註:標記', s.start, s.end, content)
            continue

        if '\n' in content:
            log('修:分行', '\\n', s.start, content)
            carry_next = True
            continue

        if '|' in content:
            fields = content.split('|')
            # Exactly one bar is valid; any other even count (4, 6, ...)
            # would fail the two-way unpacking below.
            if len(fields) != 2:
                log('修:多槓', fields)
                continue
            tw, zh = (t.strip() for t in fields)
        else:
            sp = content.split()
            if len(sp) % 2 != 0:
                log('修:不均', s.start, s.end, sp)
                continue
            mid = len(sp) // 2
            tw, zh = ' '.join(sp[:mid]), ' '.join(sp[mid:])

        if not tw:
            # NOTE(review): jiwer appears to raise ValueError for an empty
            # reference string; confirm for the installed version. Such a
            # pair is unusable either way, so report and skip it.
            log('修:不均', s.start, s.end, [tw, zh])
            continue

        if jiwer.cer(tw, zh) > 1:
            log('註:差距', s.start, s.end, 'tw:', tw, 'zh:', zh)

        if carry_next and new_subs:
            # A multi-line cue was dropped before this one: extend the last
            # kept entry across it instead of starting a new entry.
            last = new_subs[-1]
            last.zh += f" {zh}"
            last.tw += f" {tw}"
            last.end = s.end
        else:
            # Also reached when a carry is pending but nothing has been kept
            # yet; there is no entry to extend, so start a fresh one.
            new_subs.append(ZHTW_Sub(s.start, s.end, zh, tw))
        carry_next = False

    return new_subs, buffer
|
|
|
def update_yield():
    """Create a growing-log helper.

    Returns:
        A function ``update_print(inp, *more)`` that appends one line to a
        private log and returns the whole log joined with newlines. Each call
        to ``update_yield`` starts a fresh, independent log.
    """
    lines = []

    def update_print(inp, *more):
        """Append one line built from the arguments and return the full log.

        Arguments are converted with ``str`` and joined by single spaces, as
        ``print`` would do, so callers may pass several values at once.
        """
        lines.append(' '.join(str(part) for part in (inp, *more)))
        return '\n'.join(lines)

    return update_print
|
|
|
def parse_srt(file):
    """Validate an uploaded SRT file, yielding the growing log after each step.

    Steps, each followed by a yield of the full log so far: file name, number
    of parsed cues, the ``filter_sub`` log, number of kept cues, number of
    cues after ``merge_sub``, every merged cue longer than 30 seconds (index,
    then end time / duration / ``tw`` text), and finally the summed duration
    of all merged cues.

    Args:
        file: The uploaded file object (its ``name`` attribute is used as the
            path), or ``None`` when nothing is uploaded.

    Yields:
        The accumulated log text.
    """
    if file is None:
        # This function is a generator, so the message must be yielded; a
        # value passed to ``return`` would never reach the consumer.
        yield "No file uploaded."
        return

    upd = update_yield()
    yield upd(file.name)
    subs = read_srt(file.name)
    yield upd(len(subs))
    new_subs, logs = filter_sub(subs)
    yield upd(logs.getvalue())
    yield upd(len(new_subs))
    new_subs = merge_sub(new_subs)
    yield upd(len(new_subs))

    total_dur = 0
    for i, it in enumerate(new_subs):
        seconds = (it.end - it.start).total_seconds()
        if seconds > 30:
            # Flag unusually long cues: index first, then the details.
            yield upd(i)
            # Format into one string: the log helper is called with a single
            # argument everywhere else.
            yield upd(f"{it.end.total_seconds()} {seconds} {it.tw}")
        total_dur += seconds
    yield upd("可用時長 "+str(timedelta(seconds=int(total_dur))))
|
|
|
# Web UI: a file picker restricted to .srt files and a textbox for the log.
with gr.Blocks() as demo:
    gr.Markdown("## SRT File Validator")

    with gr.Column():
        file_input = gr.File(label="Upload .srt File", file_types=[".srt"])
        output_log = gr.Textbox(label="Parsing Log", lines=10, max_lines=120)

    # Run parse_srt whenever the selected file changes and send its output
    # to the log textbox.
    file_input.change(fn=parse_srt, inputs=file_input, outputs=output_log)

# Start the app as soon as this module is executed.
demo.launch()
|
|