datxy commited on
Commit
caa6c38
·
verified ·
1 Parent(s): 616414f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -40
app.py CHANGED
@@ -2,11 +2,9 @@ import spaces
2
  import torch
3
  import gradio as gr
4
  from transformers import pipeline
5
- from transformers.pipelines.audio_utils import ffmpeg_read
6
 
7
  import tempfile
8
  import os
9
- import time
10
  from datetime import timedelta
11
 
12
  # ===== 配置 =====
@@ -23,44 +21,103 @@ pipe = pipeline(
23
  chunk_length_s=30,
24
  device=device,
25
  torch_dtype=dtype,
26
- return_timestamps=True,
27
  )
28
 
29
  # ===== 工具函数:时间戳/SRT =====
30
- def _srt_timestamp(seconds: float | None) -> str:
31
- """秒 -> SRT 时间戳 00:00:00,000。None 时用 0."""
32
- if seconds is None:
33
  seconds = 0.0
34
- if seconds < 0:
35
- seconds = 0.0
36
- td = timedelta(seconds=float(seconds))
37
- total_ms = int(td.total_seconds() * 1000)
38
- hours = total_ms // 3_600_000
39
- minutes = (total_ms % 3_600_000) // 60_000
40
- secs = (total_ms % 60_000) // 1000
41
- ms = total_ms % 1000
42
- return f"{hours:02d}:{minutes:02d}:{secs:02d},{ms:03d}"
43
-
44
- def chunks_to_srt(chunks: list[dict]) -> str:
45
- """将 Whisper 返回的 chunks 转为 SRT 字符串。"""
46
- lines = []
47
- idx = 1
48
- for ch in chunks:
49
- ts = ch.get("timestamp") or ch.get("timestamps")
50
- text = (ch.get("text") or "").strip()
51
- if not text:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  continue
53
- if isinstance(ts, (list, tuple)) and len(ts) == 2:
54
- start, end = ts
55
- else:
56
- start, end = 0.0, 2.0
57
- start_srt = _srt_timestamp(start)
58
- end_srt = _srt_timestamp(end)
59
- lines.append(str(idx))
60
- lines.append(f"{start_srt} --> {end_srt}")
61
- lines.append(text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  lines.append("")
63
- idx += 1
64
  return "\n".join(lines).strip() + ("\n" if lines else "")
65
 
66
  # ===== 上传音频 -> SRT 导出 =====
@@ -76,20 +133,21 @@ def transcribe_file_to_srt(audio_path: str, task: str):
76
  pass
77
 
78
  result = pipe(audio_path, batch_size=BATCH_SIZE, generate_kwargs={"task": task})
79
- text = result.get("text", "")
80
  chunks = result.get("chunks") or []
81
- srt_str = chunks_to_srt(chunks)
82
 
83
- if not srt_str and text.strip():
84
- srt_str = "1\n00:00:00,000 --> 00:00:02,000\n" + text.strip() + "\n"
85
 
 
86
  tmpdir = tempfile.mkdtemp(prefix="srt_")
87
  base = os.path.splitext(os.path.basename(audio_path))[0] or "subtitle"
88
  srt_path = os.path.join(tmpdir, f"{base}.srt")
89
  with open(srt_path, "w", encoding="utf-8") as f:
90
  f.write(srt_str)
91
 
92
- return text, srt_path
 
93
 
94
  # ===== Gradio 界面 =====
95
  demo = gr.Interface(
@@ -99,7 +157,7 @@ demo = gr.Interface(
99
  gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"),
100
  ],
101
  outputs=[
102
- gr.Textbox(label="Transcript (preview)"),
103
  gr.File(label="Download SRT"),
104
  ],
105
  title="Upload Audio → SRT Subtitle",
 
2
  import torch
3
  import gradio as gr
4
  from transformers import pipeline
 
5
 
6
  import tempfile
7
  import os
 
8
  from datetime import timedelta
9
 
10
  # ===== 配置 =====
 
21
  chunk_length_s=30,
22
  device=device,
23
  torch_dtype=dtype,
24
+ return_timestamps="word", # 关键:逐词时间戳,便于细分
25
  )
26
 
27
  # ===== 工具函数:时间戳/SRT =====
28
+ def _srt_timestamp(seconds):
29
+ """秒 -> SRT 时间戳 00:00:00,000。None/负数时归零。"""
30
+ if seconds is None or seconds < 0:
31
  seconds = 0.0
32
+ ms = int(float(seconds) * 1000 + 0.5)
33
+ h, ms = divmod(ms, 3600000)
34
+ m, ms = divmod(ms, 60000)
35
+ s, ms = divmod(ms, 1000)
36
+ return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
37
+
38
+ def chunks_to_srt(chunks, text_fallback="", max_seg_dur=6.0, max_seg_chars=42):
39
+ """
40
+ 用逐词时间戳把长 chunk 细分成更短的 SRT 行:
41
+ - 每行最长持续 max_seg_dur 秒
42
+ - 或字符数约 max_seg_chars
43
+ - 遇到句末标点(。!?.!?)优先断句
44
+ """
45
+ segs = []
46
+ cur_words = []
47
+ cur_start = None
48
+ cur_len = 0
49
+
50
+ def flush_seg():
51
+ nonlocal cur_words, cur_start, cur_len
52
+ if not cur_words:
53
+ return
54
+ # 兼容多种时间戳字段
55
+ st = cur_start if cur_start is not None else cur_words[0].get("start", 0.0)
56
+ en = cur_words[-1].get("end", cur_words[-1].get("timestamp", [0.0, 0.0])[-1] if isinstance(cur_words[-1].get("timestamp"), (list, tuple)) else 0.0)
57
+ if isinstance(st, (list, tuple)): st = st[0]
58
+ if isinstance(en, (list, tuple)): en = en[-1]
59
+ text = "".join(w.get("word", "").strip() for w in cur_words).strip()
60
+ if text:
61
+ segs.append((float(st or 0.0), float(en or 0.0), text))
62
+ cur_words = []
63
+ cur_start = None
64
+ cur_len = 0
65
+
66
+ def maybe_flush(force=False, strong_punct=False):
67
+ if not cur_words:
68
+ return
69
+ st = cur_start if cur_start is not None else cur_words[0].get("start", 0.0)
70
+ en = cur_words[-1].get("end", cur_words[-1].get("timestamp", [0.0, 0.0])[-1] if isinstance(cur_words[-1].get("timestamp"), (list, tuple)) else 0.0)
71
+ if isinstance(st, (list, tuple)): st = st[0]
72
+ if isinstance(en, (list, tuple)): en = en[-1]
73
+ dur = float((en or 0.0) - (st or 0.0))
74
+ if force or strong_punct or dur >= max_seg_dur or cur_len >= max_seg_chars:
75
+ flush_seg()
76
+
77
+ # 汇总所有词
78
+ all_words = []
79
+ for ch in chunks or []:
80
+ words = ch.get("words") or []
81
+ if not words and ch.get("text"):
82
+ ts = ch.get("timestamp") or ch.get("timestamps") or [0.0, 2.0]
83
+ if isinstance(ts, (list, tuple)) and len(ts) == 2:
84
+ all_words.append({"word": ch["text"], "start": ts[0], "end": ts[1]})
85
+ else:
86
+ all_words.append({"word": ch["text"], "start": 0.0, "end": 2.0})
87
  continue
88
+ for w in words:
89
+ token = (w.get("word") or "").replace("\n", " ")
90
+ start = w.get("start")
91
+ end = w.get("end")
92
+ if (start is None or end is None) and isinstance(w.get("timestamp"), (list, tuple)) and len(w["timestamp"]) == 2:
93
+ start, end = w["timestamp"]
94
+ all_words.append({"word": token, "start": start, "end": end})
95
+
96
+ # 若依旧拿不到逐词,回退整段文本
97
+ if not all_words and text_fallback.strip():
98
+ all_words = [{"word": text_fallback.strip(), "start": 0.0, "end": max_seg_dur}]
99
+
100
+ # 按规则切分
101
+ for w in all_words:
102
+ token = w.get("word", "")
103
+ if not token:
104
+ continue
105
+ if cur_start is None:
106
+ cur_start = w.get("start", 0.0)
107
+ cur_words.append(w)
108
+ cur_len += len(token)
109
+ strong = token.endswith(("。", "!", "?", ".", "!", "?"))
110
+ maybe_flush(force=False, strong_punct=strong)
111
+
112
+ maybe_flush(force=True)
113
+
114
+ # 生成 SRT
115
+ lines = []
116
+ for i, (st, en, txt) in enumerate(segs, 1):
117
+ lines.append(str(i))
118
+ lines.append(f"{_srt_timestamp(st)} --> {_srt_timestamp(en)}")
119
+ lines.append(txt)
120
  lines.append("")
 
121
  return "\n".join(lines).strip() + ("\n" if lines else "")
122
 
123
  # ===== 上传音频 -> SRT 导出 =====
 
133
  pass
134
 
135
  result = pipe(audio_path, batch_size=BATCH_SIZE, generate_kwargs={"task": task})
136
+ text = result.get("text", "") or ""
137
  chunks = result.get("chunks") or []
 
138
 
139
+ # SRT(预览即为 SRT)
140
+ srt_str = chunks_to_srt(chunks, text_fallback=text)
141
 
142
+ # 写入临时文件供下载
143
  tmpdir = tempfile.mkdtemp(prefix="srt_")
144
  base = os.path.splitext(os.path.basename(audio_path))[0] or "subtitle"
145
  srt_path = os.path.join(tmpdir, f"{base}.srt")
146
  with open(srt_path, "w", encoding="utf-8") as f:
147
  f.write(srt_str)
148
 
149
+ # 第一个输出显示 SRT 字符串,第二个输出提供下载
150
+ return srt_str, srt_path
151
 
152
  # ===== Gradio 界面 =====
153
  demo = gr.Interface(
 
157
  gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"),
158
  ],
159
  outputs=[
160
+ gr.Textbox(label="Transcript (SRT Preview)", lines=18),
161
  gr.File(label="Download SRT"),
162
  ],
163
  title="Upload Audio → SRT Subtitle",