ankush13r committed
Commit b7fa1b5
1 Parent(s): 8bdc93a

adding new version of whisper

Files changed (3)
  1. app.py +9 -7
  2. requirements.txt +1 -0
  3. whisper.py +205 -0
app.py CHANGED
@@ -1,18 +1,18 @@
 
 import gradio as gr
-from whisper2 import generate
+from whisper import generate
 from AinaTheme import theme
 
 MODEL_NAME = "/whisper-large-v3"
+USE_V4 = False
 
 
-
-def transcribe(inputs):
+def transcribe(inputs, use_v4):
     if inputs is None:
         raise gr.Error("Cap fitxer d'àudio introduit! Si us plau pengeu un fitxer "\
                        "o enregistreu un àudio abans d'enviar la vostra sol·licitud")
 
-    return generate(audio=inputs)
+    return generate(audio_path=inputs, use_v4=use_v4)
 
 
 description_string = "Transcripció automàtica de micròfon o de fitxers d'àudio.\n Aquest demostrador s'ha desenvolupat per"\
@@ -22,7 +22,8 @@ description_string = "Transcripció automàtica de micròfon o de fitxers d'àud
 
 def clear():
     return (
-        None
+        None,
+        USE_V4
     )
 
 
@@ -30,6 +31,7 @@ with gr.Blocks(theme=theme) as demo:
     gr.Markdown(description_string)
     with gr.Row():
         with gr.Column(scale=1):
+            use_v4 = gr.Checkbox(label="Use v4", value=USE_V4)
            input = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Audio")
 
        with gr.Column(scale=1):
@@ -40,8 +42,8 @@ with gr.Blocks(theme=theme) as demo:
            submit_btn = gr.Button("Submit", variant="primary")
 
 
-    submit_btn.click(fn=transcribe, inputs=[input], outputs=[output])
-    clear_btn.click(fn=clear, inputs=[], outputs=[input], queue=False)
+    submit_btn.click(fn=transcribe, inputs=[input, use_v4], outputs=[output])
+    clear_btn.click(fn=clear, inputs=[], outputs=[input, use_v4], queue=False)
 
 
 if __name__ == "__main__":
requirements.txt CHANGED
@@ -1,5 +1,6 @@
 git+https://github.com/huggingface/transformers
 torch
+pyannote.audio
 yt-dlp
 gradio==4.20.0
 torchaudio==2.2.1
whisper.py ADDED
@@ -0,0 +1,205 @@
+from pyannote.audio import Pipeline
+from pydub import AudioSegment
+import os
+from transformers import WhisperForConditionalGeneration, WhisperProcessor
+import torchaudio
+import torch
+
+device = 0 if torch.cuda.is_available() else "cpu"
+torch_dtype = torch.float32
+
+
+MODEL_NAME = "openai/whisper-large-v3"
+model = WhisperForConditionalGeneration.from_pretrained(MODEL_NAME, torch_dtype=torch_dtype).to(device)
+processor = WhisperProcessor.from_pretrained(MODEL_NAME)
+pipeline_vad = Pipeline.from_pretrained("pyannote/voice-activity-detection", use_auth_token=os.environ.get("HF_TOKEN"))
+threshold = 15000  # adjust max duration threshold (in milliseconds)
+segments_dir = "."
+
+def clean_text(input_text):
+
+    remove_chars = ['.', ',', ';', ':', '¿', '?', '«', '»', '-', '¡', '!', '@',
+                    '*', '{', '}', '[', ']', '=', '/', '\\', '&', '#', '…']
+
+    output_text = ''.join(char if char not in remove_chars else ' ' for char in input_text)  # remove special chars
+    return ' '.join(output_text.split()).lower()  # remove extra spaces and return cleaned text
+
+def convert_forced_to_tokens(forced_decoder_ids):
+    forced_decoder_tokens = []
+    for i, (idx, token) in enumerate(forced_decoder_ids):
+        if token is not None:
+            forced_decoder_tokens.append([idx, processor.tokenizer.decode(token)])
+        else:
+            forced_decoder_tokens.append([idx, token])
+    return forced_decoder_tokens
+
+def generate_1st_chunk(audio):
+
+    input_audio, sample_rate = torchaudio.load(audio)
+    input_audio = torchaudio.transforms.Resample(sample_rate, 16000)(input_audio)
+
+    input_speech = input_audio[0]
+
+    input_features = processor(input_speech,
+                               sampling_rate=16_000,
+                               return_tensors="pt", torch_dtype=torch_dtype).input_features.to(device)
+
+    forced_decoder_ids = []
+    forced_decoder_ids.append([1, 50270])  # [1, '<|ca|>']
+    forced_decoder_ids.append([2, 50262])  # [2, '<|es|>']
+    forced_decoder_ids.append([3, 50360])  # [3, '<|transcribe|>']
+
+    forced_decoder_ids_modified = forced_decoder_ids
+
+    # we need to force these tokens
+    forced_decoder_ids = []
+
+    # now we need to append the prefix tokens (lang, task, timestamps)
+    offset = len(forced_decoder_ids)
+    for idx, token in forced_decoder_ids_modified:
+        forced_decoder_ids.append([idx + offset, token])
+
+    model.generation_config.forced_decoder_ids = forced_decoder_ids
+
+    pred_ids = model.generate(input_features,
+                              return_timestamps=True,
+                              max_new_tokens=128)
+    # exclude prompt from output
+    forced_decoder_tokens = convert_forced_to_tokens(forced_decoder_ids)
+    output = processor.decode(pred_ids[0][len(forced_decoder_tokens) + 1:], skip_special_tokens=True)
+    output_tokens = processor.batch_decode(pred_ids, skip_special_tokens=False)
+
+    return output[1:]
+
+def generate_from_2nd_chunk(audio, prev_prompt):
+
+    input_audio, sample_rate = torchaudio.load(audio)
+    input_audio = torchaudio.transforms.Resample(sample_rate, 16000)(input_audio)
+
+    input_speech = input_audio[0]
+
+    input_features = processor(input_speech,
+                               sampling_rate=16_000,
+                               return_tensors="pt", torch_dtype=torch_dtype).input_features.to(device)
+    forced_decoder_ids = []
+
+    forced_decoder_ids.append([1, 50270])  # [1, '<|ca|>']
+    forced_decoder_ids.append([2, 50262])  # [2, '<|es|>']
+    forced_decoder_ids.append([3, 50360])  # [3, '<|transcribe|>']
+
+    forced_decoder_ids_modified = forced_decoder_ids
+    idx = processor.tokenizer.all_special_tokens.index("<|startofprev|>")
+    forced_bos_token_id = processor.tokenizer.all_special_ids[idx]
+
+    prompt_tokens = processor.tokenizer(prev_prompt, add_special_tokens=False).input_ids
+
+    # we need to force these tokens
+    forced_decoder_ids = []
+    for idx, token in enumerate(prompt_tokens):
+        # indexing starts from 1 for forced tokens (token at position 0 is the SOS token)
+        forced_decoder_ids.append([idx + 1, token])
+
+    # now we add the SOS token at the end
+    offset = len(forced_decoder_ids)
+    forced_decoder_ids.append([offset + 1, model.generation_config.decoder_start_token_id])
+
+    # now we need to append the rest of the prefix tokens (lang, task, timestamps)
+    offset = len(forced_decoder_ids)
+    for idx, token in forced_decoder_ids_modified:
+        forced_decoder_ids.append([idx + offset, token])
+
+    model.generation_config.forced_decoder_ids = forced_decoder_ids
+
+    pred_ids = model.generate(input_features,
+                              return_timestamps=True,
+                              max_new_tokens=128,
+                              decoder_start_token_id=forced_bos_token_id)
+    # exclude prompt from output
+    forced_decoder_tokens = convert_forced_to_tokens(forced_decoder_ids)
+    output = processor.decode(pred_ids[0][len(forced_decoder_tokens) + 1:], skip_special_tokens=True)
+    output_tokens = processor.batch_decode(pred_ids, skip_special_tokens=False)
+    return output[1:]
+
+def processing_vad_v3(audio, output_vad, prev_prompt):
+    transcription_audio = ""
+    first_chunk = True
+    for speech in output_vad.get_timeline().support():
+        start, end = speech.start, speech.end
+        segment_audio = audio[start * 1000:end * 1000]
+        segment_audio.export(os.path.join(segments_dir, f"temp_segment.wav"), format="wav")
+        filename = os.path.join(segments_dir, f"temp_segment.wav")
+        if first_chunk:
+            output = generate_1st_chunk(filename)
+            first_chunk = False
+        else:
+            output = generate_from_2nd_chunk(filename, prev_prompt)
+
+        prev_prompt = output
+        transcription_audio = transcription_audio + " " + output
+
+    return transcription_audio
+
+
+def processing_vad_v4(audio, output_vad, threshold, max_duration, prev_prompt, concatenated_segment):
+    transcription_audio = ""
+    is_first_chunk = True
+    for speech in output_vad.get_timeline().support():
+        start, end = speech.start, speech.end
+        segment_duration = (end - start) * 1000
+        segment_audio = audio[start * 1000:end * 1000]
+
+        if max_duration + segment_duration < threshold:
+            concatenated_segment += audio[start * 1000:end * 1000]
+            max_duration += segment_duration
+        else:
+            if len(concatenated_segment) > 0:
+                temp_segment_path = os.path.join(segments_dir, f"temp_segment.wav")
+                concatenated_segment.export(temp_segment_path, format="wav")
+
+                if is_first_chunk:
+                    output = generate_1st_chunk(temp_segment_path)
+                    is_first_chunk = False
+                else:
+                    output = generate_from_2nd_chunk(temp_segment_path, prev_prompt)
+
+                prev_prompt = output
+                transcription_audio = transcription_audio + output
+
+            max_duration = segment_duration
+            concatenated_segment = segment_audio
+
+    # Process any remaining audio in the concatenated_segment
+    if len(concatenated_segment) > 0:
+        temp_segment_path = os.path.join(segments_dir, f"temp_segment.wav")
+        concatenated_segment.export(temp_segment_path, format="wav")
+
+        output = generate_from_2nd_chunk(temp_segment_path, prev_prompt)
+
+        prev_prompt = output
+        transcription_audio = transcription_audio + output
+
+    return transcription_audio
+
+
+def generate(audio_path, use_v4):
+    # check audio length
+    audio = AudioSegment.from_wav(audio_path)
+    duration_seconds = len(audio) / 1000.0
+
+    # apply VAD only if the duration is >= 30 s
+    if duration_seconds >= 30:
+
+        output_vad = pipeline_vad(audio_path)
+        concatenated_segment = AudioSegment.empty()
+        max_duration = 0
+        prev_prompt = ""
+        if use_v4:
+            return processing_vad_v4(audio, output_vad, threshold, max_duration, prev_prompt, concatenated_segment)
+        else:
+            return processing_vad_v3(audio, output_vad, prev_prompt)
+    else:
+        # if duration is < 30 s, process directly with generate
+        return generate_1st_chunk(audio_path)
+
+
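
Usage note: a minimal sketch of calling the new module directly, outside the Gradio app. It assumes a local WAV file at "sample.wav" (hypothetical path) and an HF_TOKEN environment variable so the pyannote VAD pipeline can load; app.py invokes the same entry point via generate(audio_path=inputs, use_v4=use_v4).

    from whisper import generate

    # use_v4=False takes the per-segment VAD path (processing_vad_v3);
    # use_v4=True concatenates VAD segments up to the 15 s threshold before transcribing (processing_vad_v4).
    # Files shorter than 30 s skip VAD entirely and go straight to generate_1st_chunk.
    text = generate(audio_path="sample.wav", use_v4=False)  # "sample.wav" is a placeholder path
    print(text)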