ssolito commited on
Commit
fd449bb
·
verified ·
1 Parent(s): 57f73ef

Update whisper.py

Browse files
Files changed (1) hide show
  1. whisper.py +4 -175
whisper.py CHANGED
@@ -9,33 +9,12 @@ device = 0 if torch.cuda.is_available() else "cpu"
9
  torch_dtype = torch.float32
10
 
11
  HF_TOKEN = os.getenv("HF_TOKEN")
12
- #MODEL_NAME = "openai/whisper-large-v3"
13
  MODEL_NAME = "projecte-aina/whisper-large-v3-ca-es-synth-cs"
14
  model = WhisperForConditionalGeneration.from_pretrained(MODEL_NAME, torch_dtype=torch_dtype,token=HF_TOKEN).to(device)
15
  processor = WhisperProcessor.from_pretrained(MODEL_NAME)
16
- pipeline_vad = Pipeline.from_pretrained("./pyannote/config.yaml")
17
- threshold = 15000 # adjust max duration threshold
18
- segments_dir = "."
19
 
20
- def clean_text(input_text):
21
-
22
- remove_chars = ['.', ',', ';', ':', '¿', '?', '«', '»', '-', '¡', '!', '@',
23
- '*', '{', '}', '[', ']', '=', '/', '\\', '&', '#', '…']
24
-
25
- output_text = ''.join(char if char not in remove_chars else ' ' for char in input_text) #removing special chars
26
- return (' '.join(output_text.split()).lower()) #remove extra spaces and return cleaned text
27
-
28
- def convert_forced_to_tokens(forced_decoder_ids):
29
- forced_decoder_tokens = []
30
- for i, (idx, token) in enumerate(forced_decoder_ids):
31
- if token is not None:
32
- forced_decoder_tokens.append([idx, processor.tokenizer.decode(token)])
33
- else:
34
- forced_decoder_tokens.append([idx, token])
35
- return forced_decoder_tokens
36
-
37
- def generate_1st_chunk(audio):
38
 
 
39
  input_audio, sample_rate = torchaudio.load(audio)
40
  input_audio = torchaudio.transforms.Resample(sample_rate, 16000)(input_audio)
41
 
@@ -44,161 +23,11 @@ def generate_1st_chunk(audio):
44
  input_features = processor(input_speech,
45
  sampling_rate=16_000,
46
  return_tensors="pt", torch_dtype=torch_dtype).input_features.to(device)
47
-
48
- forced_decoder_ids = []
49
- forced_decoder_ids.append([1,50270]) #[1, '<|ca|>']
50
- forced_decoder_ids.append([2,50262]) #[2, '<|es|>']
51
- forced_decoder_ids.append([3,50360]) #[3, '<|transcribe|>']
52
-
53
- forced_decoder_ids_modified = forced_decoder_ids
54
-
55
- # we need to force these tokens
56
- forced_decoder_ids = []
57
-
58
- # now we need to append the prefix tokens (lang, task, timestamps)
59
- offset = len(forced_decoder_ids)
60
- for idx, token in forced_decoder_ids_modified:
61
- forced_decoder_ids.append([idx + offset , token])
62
 
63
- model.generation_config.forced_decoder_ids = forced_decoder_ids
64
-
65
  pred_ids = model.generate(input_features,
66
  return_timestamps=True,
67
  max_new_tokens=128)
68
- #exclude prompt from output
69
- forced_decoder_tokens = convert_forced_to_tokens(forced_decoder_ids)
70
- output = processor.decode(pred_ids[0][len(forced_decoder_tokens) + 1:], skip_special_tokens=True)
71
-
72
- return output[1:]
73
-
74
- def generate_from_2nd_chunk(audio, prev_prompt):
75
-
76
- input_audio, sample_rate = torchaudio.load(audio)
77
- input_audio = torchaudio.transforms.Resample(sample_rate, 16000)(input_audio)
78
-
79
- input_speech = input_audio[0]
80
-
81
- input_features = processor(input_speech,
82
- sampling_rate=16_000,
83
- return_tensors="pt", torch_dtype=torch_dtype).input_features.to(device)
84
- forced_decoder_ids = []
85
-
86
- forced_decoder_ids.append([1,50270]) #[1, '<|ca|>']
87
- forced_decoder_ids.append([2,50262]) #[2, '<|es|>']
88
- forced_decoder_ids.append([3,50360]) #[3, '<|transcribe|>']
89
-
90
- forced_decoder_ids_modified = forced_decoder_ids
91
- idx = processor.tokenizer.all_special_tokens.index("<|startofprev|>")
92
- forced_bos_token_id = processor.tokenizer.all_special_ids[idx]
93
-
94
- prompt_tokens = processor.tokenizer(prev_prompt, add_special_tokens=False).input_ids
95
-
96
- # we need to force these tokens
97
- forced_decoder_ids = []
98
- for idx, token in enumerate(prompt_tokens):
99
- # indexing starts from 1 for forced tokens (token at position 0 is the SOS token)
100
- forced_decoder_ids.append([idx + 1, token])
101
-
102
- # now we add the SOS token at the end
103
- offset = len(forced_decoder_ids)
104
- forced_decoder_ids.append([offset + 1, model.generation_config.decoder_start_token_id])
105
-
106
- # now we need to append the rest of the prefix tokens (lang, task, timestamps)
107
- offset = len(forced_decoder_ids)
108
- for idx, token in forced_decoder_ids_modified:
109
- forced_decoder_ids.append([idx + offset , token])
110
-
111
- model.generation_config.forced_decoder_ids = forced_decoder_ids
112
-
113
- pred_ids = model.generate(input_features,
114
- return_timestamps=True,
115
- max_new_tokens=128,
116
- decoder_start_token_id=forced_bos_token_id)
117
- #exclude prompt from output
118
- forced_decoder_tokens = convert_forced_to_tokens(forced_decoder_ids)
119
- output = processor.decode(pred_ids[0][len(forced_decoder_tokens) + 1:], skip_special_tokens=True)
120
- return output[1:]
121
-
122
- def processing_vad_v3(audio, output_vad, prev_prompt):
123
- transcription_audio = ""
124
- first_chunk = True
125
- for speech in output_vad.get_timeline().support():
126
- start, end = speech.start, speech.end
127
- segment_audio = audio[start * 1000:end * 1000]
128
- filename = os.path.join(segments_dir, f"temp_segment.wav")
129
- segment_audio.export(filename, format="wav")
130
- if first_chunk:
131
- output = generate_1st_chunk(filename)
132
- first_chunk = False
133
- else:
134
- output = generate_from_2nd_chunk(filename, prev_prompt)
135
-
136
- prev_prompt = output
137
- transcription_audio = transcription_audio + " " + output
138
-
139
- return transcription_audio
140
-
141
-
142
- def processing_vad_v4(audio, output_vad, threshold, max_duration, prev_prompt, concatenated_segment):
143
- transcription_audio = ""
144
- is_first_chunk = True
145
- for speech in output_vad.get_timeline().support():
146
- start, end = speech.start, speech.end
147
- segment_duration = (end - start) * 1000
148
- segment_audio = audio[start * 1000:end * 1000]
149
-
150
- if max_duration + segment_duration < threshold:
151
- concatenated_segment += audio[start * 1000:end * 1000]
152
- max_duration += segment_duration
153
- else:
154
- if len(concatenated_segment) > 0:
155
- temp_segment_path = os.path.join(segments_dir, f"temp_segment.wav")
156
- concatenated_segment.export(temp_segment_path, format="wav")
157
-
158
- if is_first_chunk:
159
- output = generate_1st_chunk(temp_segment_path)
160
- is_first_chunk = False
161
- else:
162
- output = generate_from_2nd_chunk(temp_segment_path, prev_prompt)
163
-
164
- prev_prompt = output
165
- transcription_audio = transcription_audio + output
166
-
167
- max_duration = segment_duration
168
- concatenated_segment = segment_audio
169
-
170
- # Process any remaining audio in the concatenated_segment
171
- if len(concatenated_segment) > 0:
172
- temp_segment_path = os.path.join(segments_dir, f"temp_segment.wav")
173
- concatenated_segment.export(temp_segment_path, format="wav")
174
-
175
- output = generate_from_2nd_chunk(temp_segment_path, prev_prompt)
176
-
177
- prev_prompt = output
178
- transcription_audio = transcription_audio + output
179
-
180
- return transcription_audio
181
-
182
-
183
- def generate(audio_path, use_v4):
184
- #check audio lenght
185
- audio = AudioSegment.from_wav(audio_path)
186
- duration_seconds = len(audio) / 1000.0
187
-
188
- #apply VAD only if the duration is >30s
189
- if duration_seconds >= 30:
190
-
191
- output_vad = pipeline_vad(audio_path)
192
- concatenated_segment = AudioSegment.empty()
193
- max_duration = 0
194
- prev_prompt = ""
195
- if use_v4:
196
- return processing_vad_v4(audio, output_vad, threshold, max_duration, prev_prompt, concatenated_segment)
197
- else:
198
- return processing_vad_v3(audio, output_vad, prev_prompt)
199
- else:
200
- #if duraion is <30s, process directly with generate
201
- return generate_1st_chunk(audio_path)
202
-
203
 
204
-
 
 
 
9
  torch_dtype = torch.float32
10
 
11
  HF_TOKEN = os.getenv("HF_TOKEN")
 
12
  MODEL_NAME = "projecte-aina/whisper-large-v3-ca-es-synth-cs"
13
  model = WhisperForConditionalGeneration.from_pretrained(MODEL_NAME, torch_dtype=torch_dtype,token=HF_TOKEN).to(device)
14
  processor = WhisperProcessor.from_pretrained(MODEL_NAME)
 
 
 
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
+ def generate(audio_path):
18
  input_audio, sample_rate = torchaudio.load(audio)
19
  input_audio = torchaudio.transforms.Resample(sample_rate, 16000)(input_audio)
20
 
 
23
  input_features = processor(input_speech,
24
  sampling_rate=16_000,
25
  return_tensors="pt", torch_dtype=torch_dtype).input_features.to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
 
 
27
  pred_ids = model.generate(input_features,
28
  return_timestamps=True,
29
  max_new_tokens=128)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
+ output = processor.batch_decode(pred_ids, skip_special_tokens=True)
32
+ line = output[0]
33
+ return line