laubonghaudoi committed
Commit
1d7163f
1 Parent(s): c5d7b1a

Initial commit

.gitignore ADDED
@@ -0,0 +1,8 @@
+ models/*
+ !models/denoiser.onnx
+ .venv
+ __pycache__
+ .DS_Store
+ *.mp3
+ output
+ .aider*
app.py ADDED
@@ -0,0 +1,79 @@
+ import logging
+ import tempfile
+
+ import gradio as gr
+
+ from transcriber import AutoTranscriber
+ from utils import to_srt
+
+ # Configure logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+     force=True,
+ )
+ logger = logging.getLogger(__name__)
+
+
+ def transcribe_audio(audio_path):
+     """Process an audio file and return the SRT file path and preview text."""
+     try:
+         transcriber = AutoTranscriber(
+             corrector="opencc",
+             use_denoiser=False,
+             with_punct=False,
+         )
+
+         transcribe_results = transcriber.transcribe(audio_path)
+
+         if not transcribe_results:
+             # "No subtitles generated; possibly no speech was detected."
+             return None, "無字幕生成,可能係檢測唔到語音。"
+
+         # Generate SRT text for both preview and download
+         srt_text = to_srt(transcribe_results)
+
+         # Create a temporary file for download
+         with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.srt', encoding='utf-8') as tmp:
+             tmp.write(srt_text)
+         return tmp.name, srt_text
+
+     except Exception as e:
+         logger.error(f"Error during transcription: {e}")
+         return None, f"Error: {e}"
+
+
+ def create_ui():
+     with gr.Blocks() as demo:
+         gr.Markdown("# 粵文字幕生成器")  # "Cantonese Subtitle Generator"
+         # "Upload an audio file, press 'Generate subtitles', and you will get an SRT
+         # file shortly. Supported formats: .mp3, .wav, .flac, .m4a, .ogg, .opus, .webm"
+         gr.Markdown(
+             "上傳一個音頻文件,撳「生成字幕」,過一陣就會得到 SRT 文件。目前支援格式:.mp3、.wav、.flac、.m4a、.ogg、.opus、.webm")
+
+         with gr.Row():
+             audio_input = gr.Audio(type="filepath", label="上傳音頻文件或者錄音")  # "Upload an audio file or record"
+
+         with gr.Row():
+             generate_btn = gr.Button("生成字幕 SRT 文件", variant="primary", scale=2)  # "Generate SRT subtitle file"
+
+         with gr.Row():
+             with gr.Column():
+                 preview = gr.Textbox(label="預覽生成字幕", lines=10)  # "Preview generated subtitles"
+
+             with gr.Column():
+                 output = gr.File(label="下載 SRT")  # "Download SRT"
+
+         generate_btn.click(
+             fn=transcribe_audio,
+             inputs=[audio_input],
+             outputs=[output, preview],
+         )
+
+     return demo
+
+
+ def main():
+     demo = create_ui()
+     demo.launch(server_name="0.0.0.0", server_port=8081)
+
+
+ if __name__ == "__main__":
+     main()
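
For a quick smoke test without launching the UI, `transcribe_audio` can be called directly. A minimal sketch, assuming a local test clip `sample.wav` that is not part of this commit:

    # Hypothetical smoke test; sample.wav is an assumed local file.
    from app import transcribe_audio

    srt_path, preview = transcribe_audio("sample.wav")
    print(srt_path)  # path to a temporary .srt file, or None on failure
    print(preview)   # SRT text on success, otherwise the error message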
corrector/Corrector.py ADDED
@@ -0,0 +1,62 @@
+ import re
+ from typing import Literal
+
+ import opencc
+
+
+ class Corrector:
+     """
+     The SenseVoice model outputs Simplified Chinese only. This class converts the
+     output to Traditional Chinese and fixes common Cantonese spelling errors.
+     """
+
+     def __init__(self, corrector: Literal["opencc"] = "opencc"):
+         self.corrector = corrector
+         self.converter = None
+         self.bert_model = None
+
+         if corrector == "opencc":
+             self.converter = opencc.OpenCC("s2hk")
+             self.regular_errors: list[tuple[re.Pattern, str]] = [
+                 (re.compile(r"俾(?!(?:路支|斯麥|益))"), r"畀"),
+                 (re.compile(r"(?<!(?:聯))[系繫](?!(?:統))"), r"係"),
+                 (re.compile(r"噶"), r"㗎"),
+                 (re.compile(r"咁(?=[我你佢就樣就話係啊呀嘅,。])"), r"噉"),
+                 # Lookahead, so the context character is kept rather than replaced
+                 (re.compile(r"(?<![曝晾])曬(?=[衣太衫褲被命嘢相])"), r"晒"),
+                 (re.compile(r"(?<=[好])翻(?=[去到嚟])"), r"返"),
+                 # Strip SenseVoice special tokens such as <|yue|>
+                 (re.compile(r"<\|\w+\|>"), r""),
+             ]
+
+     def correct(self, text: str) -> str:
+         """
+         Correct the output text using the configured corrector.
+
+         Args:
+             text: Input text to correct
+
+         Returns:
+             Corrected text string
+         """
+         text = text.strip()
+         if not text:  # Early return for empty string
+             return text
+
+         if self.corrector == "opencc":
+             return self.opencc_correct(text)
+         else:
+             raise ValueError(f"Unsupported corrector {self.corrector!r}; only 'opencc' is implemented")
+
+     def opencc_correct(self, text: str) -> str:
+         """
+         Convert text with OpenCC, then apply the regex spelling fixes.
+
+         Args:
+             text: Input text to convert
+
+         Returns:
+             Converted text string
+         """
+         opencc_text = self.converter.convert(text)
+         for pattern, replacement in self.regular_errors:
+             opencc_text = pattern.sub(replacement, opencc_text)
+
+         return opencc_text
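
A minimal usage sketch of `Corrector`; the input sentence is an invented illustration. Note that OpenCC's s2hk conversion alone leaves 系 untouched (it is also a valid traditional character), so the regex pass is what rewrites it to 係:

    # Illustrative example; the sentence is not from the repository.
    from corrector import Corrector

    corrector = Corrector("opencc")
    print(corrector.correct("你系边个"))  # expected roughly: 你係邊個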
corrector/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .Corrector import Corrector
+
+ # Re-export at package level
+ __all__ = ['Corrector']
requirements.txt ADDED
@@ -0,0 +1,18 @@
+ OpenCC
+ datasets
+ flask
+ funasr
+ funasr_onnx
+ librosa
+ modelscope
+ onnxruntime
+ onnxruntime-gpu; sys_platform != 'darwin' and platform_machine != 'arm64' and platform_machine != 'aarch64'
+ optimum[onnxruntime]
+ psutil
+ pysrt
+ pytest
+ pytubefix
+ resampy
+ torch
+ torchaudio
+ transformers[onnx]
transcriber/AutoTranscriber.py ADDED
@@ -0,0 +1,145 @@
+ import logging
+ import time
+ from typing import List, Literal
+
+ import librosa
+ import numpy as np
+ from funasr import AutoModel
+ from resampy.core import resample
+ from tqdm.auto import tqdm
+
+ from corrector.Corrector import Corrector
+ from transcriber.TranscribeResult import TranscribeResult
+
+ logger = logging.getLogger(__name__)
+
+
+ class AutoTranscriber:
+     """
+     Transcriber that uses FunASR's AutoModel for VAD and ASR.
+     """
+
+     def __init__(
+         self,
+         corrector: Literal["opencc", "bert", None] = None,
+         use_denoiser=False,
+         with_punct=True,
+         offset_in_seconds=-0.25,
+         sr=16000,
+     ):
+         self.corrector = corrector
+         self.use_denoiser = use_denoiser
+         self.with_punct = with_punct
+         self.sr = sr
+         self.offset_in_seconds = offset_in_seconds
+
+         # Initialize models
+         self.vad_model = AutoModel(model="fsmn-vad")
+         self.asr_model = AutoModel(
+             model="iic/SenseVoiceSmall",
+             vad_model=None,  # VAD is handled separately below
+             punc_model="ct-punc" if with_punct else None,
+             ban_emo_unks=True,
+         )
+
+     def transcribe(
+         self,
+         audio_file: str,
+     ) -> List[TranscribeResult]:
+         """
+         Transcribe an audio file to text with timestamps.
+
+         Args:
+             audio_file (str): Path to audio file
+
+         Returns:
+             List[TranscribeResult]: List of transcription results
+         """
+         # Load and preprocess audio (librosa resamples to self.sr on load)
+         speech, sr = librosa.load(audio_file, sr=self.sr)
+
+         # if self.use_denoiser:
+         #     logger.info("Denoising speech...")
+         #     speech, _ = denoiser(speech, sr)
+
+         # Guard for a non-default self.sr: the models expect 16 kHz input
+         if sr != 16_000:
+             speech = resample(speech, sr, 16_000,
+                               filter="kaiser_best", parallel=True)
+
+         # Get VAD segments
+         logger.info("Segmenting speech...")
+
+         start_time = time.time()
+         vad_results = self.vad_model.generate(input=speech)
+         logger.info("VAD took %.2f seconds", time.time() - start_time)
+
+         if not vad_results or not vad_results[0]["value"]:
+             return []
+
+         vad_segments = vad_results[0]["value"]
+
+         # Process each segment
+         results = []
+
+         asr_start = time.time()  # wall-clock timer for the ASR loop
+         for segment in tqdm(vad_segments, desc="Transcribing"):
+             start_sample = int(segment[0] * 16)  # ms to samples at 16 kHz
+             end_sample = int(segment[1] * 16)
+             segment_audio = speech[start_sample:end_sample]
+
+             # Get ASR results for segment
+             asr_result = self.asr_model.generate(
+                 input=segment_audio, language="yue", use_itn=True
+             )
+
+             if not asr_result:
+                 continue
+
+             # ms to seconds, with the configured timestamp offset
+             seg_start = max(0, segment[0] / 1000.0 + self.offset_in_seconds)
+             seg_end = segment[1] / 1000.0 + self.offset_in_seconds
+
+             # Convert ASR result to TranscribeResult format
+             segment_result = TranscribeResult(
+                 text=asr_result[0]["text"],
+                 start_time=seg_start,
+                 end_time=seg_end,
+             )
+             results.append(segment_result)
+
+         logger.info("ASR took %.2f seconds", time.time() - asr_start)
+
+         # Apply Chinese conversion if needed
+         start_time = time.time()
+         results = self._convert_to_traditional_chinese(results)
+         logger.info("Conversion took %.2f seconds", time.time() - start_time)
+
+         return results
+
+     def _convert_to_traditional_chinese(
+         self, results: List[TranscribeResult]
+     ) -> List[TranscribeResult]:
+         """Convert Simplified Chinese to Traditional Chinese"""
+         if not results or not self.corrector:
+             return results
+
+         corrector = Corrector(self.corrector)
+         if self.corrector == "bert":
+             for result in tqdm(
+                 results, total=len(results), desc="Converting to Traditional Chinese"
+             ):
+                 result.text = corrector.correct(result.text)
+         elif self.corrector == "opencc":
+             # Use a delimiter that won't appear in Chinese text
+             delimiter = "|||"
+             # Concatenate all texts so the conversion runs once over the whole batch
+             combined_text = delimiter.join(result.text for result in results)
+             converted_text = corrector.correct(combined_text)
+             # Split back into individual results
+             converted_parts = converted_text.split(delimiter)
+
+             # Update results with converted text
+             for result, converted in zip(results, converted_parts):
+                 result.text = converted
+
+         return results
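
An end-to-end sketch of the transcription pipeline; `clip.wav` is an assumed local file, and the first run will download the FunASR models:

    # Hedged usage example; clip.wav is a placeholder path.
    from transcriber import AutoTranscriber
    from utils import to_srt

    transcriber = AutoTranscriber(corrector="opencc", with_punct=False)
    results = transcriber.transcribe("clip.wav")  # List[TranscribeResult]
    for r in results:
        print(f"{r.start_time:.2f}-{r.end_time:.2f}  {r.text}")
    print(to_srt(results))  # the full SRT document as a string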
transcriber/TranscribeResult.py ADDED
@@ -0,0 +1,15 @@
+ class TranscribeResult:
+     """
+     Each TranscribeResult object represents one SRT subtitle entry.
+     """
+
+     def __init__(self, text: str, start_time: float, end_time: float):
+         self.text = text
+         self.start_time = start_time
+         self.end_time = end_time
+
+     def __str__(self):
+         return f"TranscribeResult(text={self.text}, start_time={self.start_time}, end_time={self.end_time})"
+
+     def __repr__(self):
+         return str(self)
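
The same container could be expressed as a dataclass, which generates the constructor and repr automatically; an equivalent sketch (an alternative, not what this commit uses):

    from dataclasses import dataclass

    @dataclass
    class TranscribeResult:
        text: str
        start_time: float  # seconds
        end_time: float    # seconds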
transcriber/__init__.py ADDED
@@ -0,0 +1,7 @@
+ from .AutoTranscriber import AutoTranscriber
+ from .TranscribeResult import TranscribeResult
+
+ __all__ = [
+     "AutoTranscriber",
+     "TranscribeResult",
+ ]
utils.py ADDED
@@ -0,0 +1,69 @@
+ import logging
+ import os
+ import tempfile
+ from typing import Iterator, Optional
+
+ from pysrt import SubRipFile, SubRipItem, SubRipTime
+ from pytubefix import YouTube
+
+ from transcriber import TranscribeResult
+
+ logger = logging.getLogger(__name__)
+
+
+ def download_youtube_audio(video_id: str) -> Optional[str]:
+     """
+     Download audio from a YouTube video.
+
+     Args:
+         video_id (str): YouTube video ID.
+
+     Returns:
+         Optional[str]: Path to the downloaded audio file, or None on failure.
+     """
+     url = "https://www.youtube.com/watch?v={}".format(video_id)
+
+     try:
+         # https://github.com/JuanBindez/pytubefix/issues/242#issuecomment-2369067929
+         vid = YouTube(url, "MWEB")
+
+         if vid.title is None:
+             return None
+
+         audio_download = vid.streams.get_audio_only()
+         audio_download.download(
+             mp3=True,
+             filename=video_id,
+             output_path=tempfile.gettempdir(),
+             skip_existing=True,
+         )
+         audio_file = os.path.join(tempfile.gettempdir(), video_id + ".mp3")
+
+         return audio_file
+
+     except Exception as e:
+         logger.error(e)
+         return None
+
+
+ def to_srt(results: Iterator["TranscribeResult"]) -> str:
+     """
+     Convert the list of TranscribeResult objects into an SRT document string.
+     """
+     srt = SubRipFile()
+
+     for i, t in enumerate(results, start=1):  # SRT indices conventionally start at 1
+         start = SubRipTime(seconds=t.start_time)
+         end = SubRipTime(seconds=t.end_time)
+         item = SubRipItem(index=i, start=start, end=end, text=t.text)
+         srt.append(item)
+
+     temp_file = os.path.join(tempfile.gettempdir(), "output.srt")
+     srt.save(temp_file)
+
+     with open(temp_file, "r", encoding="utf-8") as f:
+         srt_text = f.read()
+
+     os.remove(temp_file)
+
+     return srt_text
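
These helpers compose into a download-and-subtitle flow; a hedged sketch in which the video ID is a placeholder:

    # Hypothetical flow; VIDEO_ID_HERE stands in for a real YouTube video ID.
    from transcriber import AutoTranscriber
    from utils import download_youtube_audio, to_srt

    audio_path = download_youtube_audio("VIDEO_ID_HERE")
    if audio_path:
        results = AutoTranscriber(corrector="opencc").transcribe(audio_path)
        with open("output.srt", "w", encoding="utf-8") as f:
            f.write(to_srt(results))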