chiyo123 committed
Commit a39f217 • 1 Parent(s): 044d7a0

Update app.py

Files changed (1)
  1. app.py +150 -182
app.py CHANGED
@@ -1,182 +1,150 @@
- /* eslint-disable camelcase */
- import { pipeline, env } from "@xenova/transformers";
-
- // Disable local models
- env.allowLocalModels = false;
-
- // Define model factories
- // Ensures only one model is created of each type
- class PipelineFactory {
-     static task = null;
-     static model = null;
-     static quantized = null;
-     static instance = null;
-
-     constructor(tokenizer, model, quantized) {
-         this.tokenizer = tokenizer;
-         this.model = model;
-         this.quantized = quantized;
-     }
-
-     static async getInstance(progress_callback = null) {
-         if (this.instance === null) {
-             this.instance = pipeline(this.task, this.model, {
-                 quantized: this.quantized,
-                 progress_callback,
-
-                 // For medium models, we need to load the `no_attentions` revision to avoid running out of memory
-                 revision: this.model.includes("/whisper-medium") ? "no_attentions" : "main"
-             });
-         }
-
-         return this.instance;
-     }
- }
-
- self.addEventListener("message", async (event) => {
-     const message = event.data;
-
-     // Do some work...
-     // TODO use message data
-     let transcript = await transcribe(
-         message.audio,
-         message.model,
-         message.multilingual,
-         message.quantized,
-         message.subtask,
-         message.language,
-     );
-     if (transcript === null) return;
-
-     // Send the result back to the main thread
-     self.postMessage({
-         status: "complete",
-         task: "automatic-speech-recognition",
-         data: transcript,
-     });
- });
-
- class AutomaticSpeechRecognitionPipelineFactory extends PipelineFactory {
-     static task = "automatic-speech-recognition";
-     static model = null;
-     static quantized = null;
- }
-
- const transcribe = async (
-     audio,
-     model,
-     multilingual,
-     quantized,
-     subtask,
-     language,
- ) => {
-     const isDistilWhisper = model.startsWith("distil-whisper/");
-
-     let modelName = model;
-     if (!isDistilWhisper && !multilingual) {
-         modelName += ".en";
-     }
-
-     const p = AutomaticSpeechRecognitionPipelineFactory;
-     if (p.model !== modelName || p.quantized !== quantized) {
-         // Invalidate model if different
-         p.model = modelName;
-         p.quantized = quantized;
-
-         if (p.instance !== null) {
-             (await p.getInstance()).dispose();
-             p.instance = null;
-         }
-     }
-
-     // Load transcriber model
-     let transcriber = await p.getInstance((data) => {
-         self.postMessage(data);
-     });
-
-     const time_precision =
-         transcriber.processor.feature_extractor.config.chunk_length /
-         transcriber.model.config.max_source_positions;
-
-     // Storage for chunks to be processed. Initialise with an empty chunk.
-     let chunks_to_process = [
-         {
-             tokens: [],
-             finalised: false,
-         },
-     ];
-
-     // TODO: Storage for fully-processed and merged chunks
-     // let decoded_chunks = [];
-
-     function chunk_callback(chunk) {
-         let last = chunks_to_process[chunks_to_process.length - 1];
-
-         // Overwrite last chunk with new info
-         Object.assign(last, chunk);
-         last.finalised = true;
-
-         // Create an empty chunk after, if it is not the last chunk
-         if (!chunk.is_last) {
-             chunks_to_process.push({
-                 tokens: [],
-                 finalised: false,
-             });
-         }
-     }
-
-     // Inject custom callback function to handle merging of chunks
-     function callback_function(item) {
-         let last = chunks_to_process[chunks_to_process.length - 1];
-
-         // Update tokens of last chunk
-         last.tokens = [...item[0].output_token_ids];
-
-         // Merge text chunks
-         // TODO optimise so we don't have to decode all chunks every time
-         let data = transcriber.tokenizer._decode_asr(chunks_to_process, {
-             time_precision: time_precision,
-             return_timestamps: true,
-             force_full_sequences: false,
-         });
-
-         self.postMessage({
-             status: "update",
-             task: "automatic-speech-recognition",
-             data: data,
-         });
-     }
-
-     // Actually run transcription
-     let output = await transcriber(audio, {
-         // Greedy
-         top_k: 0,
-         do_sample: false,
-
-         // Sliding window
-         chunk_length_s: isDistilWhisper ? 20 : 30,
-         stride_length_s: isDistilWhisper ? 3 : 5,
-
-         // Language and task
-         language: language,
-         task: subtask,
-
-         // Return timestamps
-         return_timestamps: true,
-         force_full_sequences: false,
-
-         // Callback functions
-         callback_function: callback_function, // after each generation step
-         chunk_callback: chunk_callback, // after each chunk is processed
-     }).catch((error) => {
-         self.postMessage({
-             status: "error",
-             task: "automatic-speech-recognition",
-             data: error,
-         });
-         return null;
-     });
-
-     return output;
- };
 
+ import torch
+
+ import gradio as gr
+ import yt_dlp as youtube_dl
+ from transformers import pipeline
+ from transformers.pipelines.audio_utils import ffmpeg_read
+
+ import tempfile
+ import os
+ import time  # needed for time.strftime / time.gmtime in download_yt_audio
+
+ MODEL_NAME = "chiyo123/whisper-small-tonga"
+ BATCH_SIZE = 8
+ FILE_LIMIT_MB = 1000
+ YT_LENGTH_LIMIT_S = 3600  # limit YouTube files to 1 hour
+
+ # Use the first GPU if available, otherwise fall back to CPU
+ device = 0 if torch.cuda.is_available() else "cpu"
+
+ pipe = pipeline(
+     task="automatic-speech-recognition",
+     model=MODEL_NAME,
+     chunk_length_s=30,
+     device=device,
+ )
+
+
+ def transcribe(inputs, task):
+     if inputs is None:
+         raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")
+
+     text = pipe(inputs, batch_size=BATCH_SIZE, generate_kwargs={"task": task}, return_timestamps=True)["text"]
+     return text
+
+
+ def _return_yt_html_embed(yt_url):
+     video_id = yt_url.split("?v=")[-1]
+     HTML_str = (
+         f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{video_id}"> </iframe>'
+         " </center>"
+     )
+     return HTML_str
+
+
+ def download_yt_audio(yt_url, filename):
+     info_loader = youtube_dl.YoutubeDL()
+
+     try:
+         info = info_loader.extract_info(yt_url, download=False)
+     except youtube_dl.utils.DownloadError as err:
+         raise gr.Error(str(err))
+
+     file_length = info["duration_string"]
+     file_h_m_s = file_length.split(":")
+     file_h_m_s = [int(sub_length) for sub_length in file_h_m_s]
+
+     # Pad the duration to three fields (H, M, S) so the indexing below is safe
+     if len(file_h_m_s) == 1:
+         file_h_m_s.insert(0, 0)
+     if len(file_h_m_s) == 2:
+         file_h_m_s.insert(0, 0)
+     file_length_s = file_h_m_s[0] * 3600 + file_h_m_s[1] * 60 + file_h_m_s[2]
+
+     if file_length_s > YT_LENGTH_LIMIT_S:
+         yt_length_limit_hms = time.strftime("%HH:%MM:%SS", time.gmtime(YT_LENGTH_LIMIT_S))
+         file_length_hms = time.strftime("%HH:%MM:%SS", time.gmtime(file_length_s))
+         raise gr.Error(f"Maximum YouTube length is {yt_length_limit_hms}, got a video of length {file_length_hms}.")
+
+     ydl_opts = {"outtmpl": filename, "format": "worstvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best"}
+
+     with youtube_dl.YoutubeDL(ydl_opts) as ydl:
+         try:
+             ydl.download([yt_url])
+         except youtube_dl.utils.ExtractorError as err:
+             raise gr.Error(str(err))
+
+
+ def yt_transcribe(yt_url, task, max_filesize=75.0):
+     html_embed_str = _return_yt_html_embed(yt_url)
+
+     with tempfile.TemporaryDirectory() as tmpdirname:
+         filepath = os.path.join(tmpdirname, "video.mp4")
+         download_yt_audio(yt_url, filepath)
+         with open(filepath, "rb") as f:
+             inputs = f.read()
+
+     inputs = ffmpeg_read(inputs, pipe.feature_extractor.sampling_rate)
+     inputs = {"array": inputs, "sampling_rate": pipe.feature_extractor.sampling_rate}
+
+     text = pipe(inputs, batch_size=BATCH_SIZE, generate_kwargs={"task": task}, return_timestamps=True)["text"]
+
+     return html_embed_str, text
+
+
+ demo = gr.Blocks()
+
+ mf_transcribe = gr.Interface(
+     fn=transcribe,
+     inputs=[
+         gr.inputs.Audio(source="microphone", type="filepath", optional=True),
+         gr.inputs.Radio(["transcribe", "translate"], label="Task", default="transcribe"),
+     ],
+     outputs="text",
+     layout="horizontal",
+     theme="huggingface",
+     title="Whisper Small Tonga: Transcribe Audio",
+     description=(
+         "Transcribe long-form microphone or audio inputs with the click of a button! Demo uses the"
+         f" checkpoint [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers to transcribe audio files"
+         " of arbitrary length."
+     ),
+     allow_flagging="never",
+ )
+
+ file_transcribe = gr.Interface(
+     fn=transcribe,
+     inputs=[
+         gr.inputs.Audio(source="upload", type="filepath", optional=True, label="Audio file"),
+         gr.inputs.Radio(["transcribe", "translate"], label="Task", default="transcribe"),
+     ],
+     outputs="text",
+     layout="horizontal",
+     theme="huggingface",
+     title="Whisper Small Tonga: Transcribe Audio",
+     description=(
+         "Transcribe long-form microphone or audio inputs with the click of a button! Demo uses the"
+         f" checkpoint [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers to transcribe audio files"
+         " of arbitrary length."
+     ),
+     allow_flagging="never",
+ )
+
+ yt_transcribe = gr.Interface(
+     fn=yt_transcribe,
+     inputs=[
+         gr.inputs.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL"),
+         gr.inputs.Radio(["transcribe", "translate"], label="Task", default="transcribe"),
+     ],
+     outputs=["html", "text"],
+     layout="horizontal",
+     theme="huggingface",
+     title="Whisper Small Tonga: Transcribe YouTube",
+     description=(
+         "Transcribe long-form YouTube videos with the click of a button! Demo uses the checkpoint"
+         f" [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers to transcribe video files of"
+         " arbitrary length."
+     ),
+     allow_flagging="never",
+ )
+
+ with demo:
+     gr.TabbedInterface([mf_transcribe, file_transcribe, yt_transcribe], ["Microphone", "Audio file", "YouTube"])
+
+ demo.launch(enable_queue=True)
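
For reference, a minimal sketch of exercising the same checkpoint outside the Space. The setup mirrors the constants in the new app.py; "sample.wav" is a hypothetical local file, not something shipped with the commit:

import torch
from transformers import pipeline

# Same checkpoint and chunking as app.py; pick GPU 0 when available.
pipe = pipeline(
    task="automatic-speech-recognition",
    model="chiyo123/whisper-small-tonga",
    chunk_length_s=30,
    device=0 if torch.cuda.is_available() else "cpu",
)

# "sample.wav" is a placeholder; the pipeline decodes and resamples it via ffmpeg.
print(pipe("sample.wav", batch_size=8, return_timestamps=True)["text"])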