Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -2,11 +2,9 @@ import spaces
|
|
2 |
import torch
|
3 |
import gradio as gr
|
4 |
from transformers import pipeline
|
5 |
-
from transformers.pipelines.audio_utils import ffmpeg_read
|
6 |
|
7 |
import tempfile
|
8 |
import os
|
9 |
-
import time
|
10 |
from datetime import timedelta
|
11 |
|
12 |
# ===== 配置 =====
|
@@ -23,44 +21,103 @@ pipe = pipeline(
|
|
23 |
chunk_length_s=30,
|
24 |
device=device,
|
25 |
torch_dtype=dtype,
|
26 |
-
return_timestamps=
|
27 |
)
|
28 |
|
29 |
# ===== 工具函数:时间戳/SRT =====
|
30 |
-
def _srt_timestamp(seconds
|
31 |
-
"""秒 -> SRT 时间戳 00:00:00,000。None
|
32 |
-
if seconds is None:
|
33 |
seconds = 0.0
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
continue
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
lines.append("")
|
63 |
-
idx += 1
|
64 |
return "\n".join(lines).strip() + ("\n" if lines else "")
|
65 |
|
66 |
# ===== 上传音频 -> SRT 导出 =====
|
@@ -76,20 +133,21 @@ def transcribe_file_to_srt(audio_path: str, task: str):
|
|
76 |
pass
|
77 |
|
78 |
result = pipe(audio_path, batch_size=BATCH_SIZE, generate_kwargs={"task": task})
|
79 |
-
text = result.get("text", "")
|
80 |
chunks = result.get("chunks") or []
|
81 |
-
srt_str = chunks_to_srt(chunks)
|
82 |
|
83 |
-
|
84 |
-
|
85 |
|
|
|
86 |
tmpdir = tempfile.mkdtemp(prefix="srt_")
|
87 |
base = os.path.splitext(os.path.basename(audio_path))[0] or "subtitle"
|
88 |
srt_path = os.path.join(tmpdir, f"{base}.srt")
|
89 |
with open(srt_path, "w", encoding="utf-8") as f:
|
90 |
f.write(srt_str)
|
91 |
|
92 |
-
|
|
|
93 |
|
94 |
# ===== Gradio 界面 =====
|
95 |
demo = gr.Interface(
|
@@ -99,7 +157,7 @@ demo = gr.Interface(
|
|
99 |
gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"),
|
100 |
],
|
101 |
outputs=[
|
102 |
-
gr.Textbox(label="Transcript (
|
103 |
gr.File(label="Download SRT"),
|
104 |
],
|
105 |
title="Upload Audio → SRT Subtitle",
|
|
|
2 |
import torch
|
3 |
import gradio as gr
|
4 |
from transformers import pipeline
|
|
|
5 |
|
6 |
import tempfile
|
7 |
import os
|
|
|
8 |
from datetime import timedelta
|
9 |
|
10 |
# ===== 配置 =====
|
|
|
21 |
chunk_length_s=30,
|
22 |
device=device,
|
23 |
torch_dtype=dtype,
|
24 |
+
return_timestamps="word", # 关键:逐词时间戳,便于细分
|
25 |
)
|
26 |
|
27 |
# ===== 工具函数:时间戳/SRT =====
|
28 |
+
# Characters that mark the end of a sentence (CJK full-width and ASCII forms).
_SENTENCE_END = ("。", "！", "？", ".", "!", "?")


def _srt_timestamp(seconds):
    """Format a time offset in seconds as an SRT timestamp ``HH:MM:SS,mmm``.

    ``None``, negative, NaN and infinite values are all rendered as
    ``00:00:00,000`` so that a bad timestamp never aborts SRT generation.
    """
    value = 0.0 if seconds is None else float(seconds)
    try:
        # Round half up to the nearest millisecond.
        total_ms = int(value * 1000 + 0.5)
    except (ValueError, OverflowError):
        # int() rejects NaN (ValueError) and infinity (OverflowError).
        total_ms = 0
    total_ms = max(total_ms, 0)
    h, rest = divmod(total_ms, 3600000)
    m, rest = divmod(rest, 60000)
    s, ms = divmod(rest, 1000)
    return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"


def _to_seconds(value):
    """Return *value* as a float, or ``None`` if it is missing or not a number."""
    if value is None:
        return None
    try:
        seconds = float(value)
    except (TypeError, ValueError):
        return None
    # NaN is the only float that is not equal to itself.
    return None if seconds != seconds else seconds


def _collect_words(chunks):
    """Flatten pipeline chunks into a list of ``(token, start, end)`` triples.

    Two chunk layouts are accepted:
    - a chunk holding a ``"words"`` list whose items carry ``"word"`` plus
      either ``"start"``/``"end"`` or a two-item ``"timestamp"``;
    - a chunk holding ``"text"`` plus a two-item ``"timestamp"`` (or
      ``"timestamps"``), which is treated as a single token.

    ``start``/``end`` are floats, or ``None`` when the source had no usable value.
    """
    collected = []
    for chunk in chunks or []:
        nested = chunk.get("words") or []
        if nested:
            for word in nested:
                token = (word.get("word") or "").replace("\n", " ")
                start, end = word.get("start"), word.get("end")
                if start is None or end is None:
                    stamp = word.get("timestamp")
                    if isinstance(stamp, (list, tuple)) and len(stamp) == 2:
                        start, end = stamp
                collected.append((token, _to_seconds(start), _to_seconds(end)))
        elif chunk.get("text"):
            stamp = chunk.get("timestamp") or chunk.get("timestamps")
            if isinstance(stamp, (list, tuple)) and len(stamp) == 2:
                start, end = stamp
            else:
                # No usable timestamp at all: keep the historical 0-2 s default.
                start, end = 0.0, 2.0
            collected.append((chunk["text"], _to_seconds(start), _to_seconds(end)))
    return collected


def chunks_to_srt(chunks, text_fallback="", max_seg_dur=6.0, max_seg_chars=42):
    """Split word-level timestamped chunks into short SRT cues.

    A cue is closed as soon as, after adding a token, any of these holds:
    - the token ends with sentence-final punctuation;
    - the cue spans at least ``max_seg_dur`` seconds;
    - the cue holds at least ``max_seg_chars`` characters.

    Args:
        chunks: Iterable of chunk dicts (see ``_collect_words``); may be
            ``None`` or empty.
        text_fallback: Plain transcript used as a single cue when ``chunks``
            yields no tokens; ``None`` is treated as empty.
        max_seg_dur: Target maximum cue duration in seconds.
        max_seg_chars: Target maximum cue length in characters.

    Returns:
        The SRT document as a string ending in a newline, or ``""`` when
        there is nothing to write.
    """
    words = _collect_words(chunks)
    fallback = (text_fallback or "").strip()
    if not words and fallback:
        words = [(fallback, 0.0, float(max_seg_dur))]

    segments = []
    pending = []
    pending_len = 0
    prev_end = 0.0

    def bounds():
        """Return (start, end) of the pending cue, with end never before start."""
        # A missing start continues from where the previous cue ended.
        start = next((s for _, s, _ in pending if s is not None), prev_end)
        end = pending[-1][2]
        if end is None:
            # The recogniser may omit the final end time; use the latest
            # timestamp known inside this cue instead of falling back to 0.
            known = [t for _, s, e in pending for t in (s, e) if t is not None]
            end = max(known, default=start)
        return start, max(end, start)

    def flush():
        """Close the pending cue and append it to ``segments`` if it has text."""
        nonlocal pending, pending_len, prev_end
        if not pending:
            return
        start, end = bounds()
        # Tokens of space-delimited languages carry their own leading space,
        # so concatenate them untouched and then collapse whitespace runs.
        # Tokens without spaces (e.g. Chinese) simply end up adjacent.
        text = " ".join("".join(tok for tok, _, _ in pending).split())
        if text:
            segments.append((start, end, text))
        prev_end = end
        pending = []
        pending_len = 0

    for word in words:
        token = word[0]
        if not token:
            continue
        pending.append(word)
        pending_len += len(token)
        start, end = bounds()
        if (
            token.rstrip().endswith(_SENTENCE_END)
            or end - start >= max_seg_dur
            or pending_len >= max_seg_chars
        ):
            flush()
    flush()

    # Render: index line, time range line, text line, blank separator.
    lines = []
    for index, (start, end, text) in enumerate(segments, 1):
        lines.append(str(index))
        lines.append(f"{_srt_timestamp(start)} --> {_srt_timestamp(end)}")
        lines.append(text)
        lines.append("")
    return "\n".join(lines).strip() + ("\n" if lines else "")
|
122 |
|
123 |
# ===== 上传音频 -> SRT 导出 =====
|
|
|
133 |
pass
|
134 |
|
135 |
result = pipe(audio_path, batch_size=BATCH_SIZE, generate_kwargs={"task": task})
|
136 |
+
text = result.get("text", "") or ""
|
137 |
chunks = result.get("chunks") or []
|
|
|
138 |
|
139 |
+
# 转 SRT(预览即为 SRT)
|
140 |
+
srt_str = chunks_to_srt(chunks, text_fallback=text)
|
141 |
|
142 |
+
# 写入临时文件供下载
|
143 |
tmpdir = tempfile.mkdtemp(prefix="srt_")
|
144 |
base = os.path.splitext(os.path.basename(audio_path))[0] or "subtitle"
|
145 |
srt_path = os.path.join(tmpdir, f"{base}.srt")
|
146 |
with open(srt_path, "w", encoding="utf-8") as f:
|
147 |
f.write(srt_str)
|
148 |
|
149 |
+
# 第一个输出显示 SRT 字符串,第二个输出提供下载
|
150 |
+
return srt_str, srt_path
|
151 |
|
152 |
# ===== Gradio 界面 =====
|
153 |
demo = gr.Interface(
|
|
|
157 |
gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"),
|
158 |
],
|
159 |
outputs=[
|
160 |
+
gr.Textbox(label="Transcript (SRT Preview)", lines=18),
|
161 |
gr.File(label="Download SRT"),
|
162 |
],
|
163 |
title="Upload Audio → SRT Subtitle",
|