Porjaz committed
Commit 7801528
Parent(s): 8840b4b

Update custom_interface_app.py

Files changed (1)
  1. custom_interface_app.py +36 -13
custom_interface_app.py CHANGED
@@ -133,10 +133,11 @@ class ASR(Pretrained):
         # Get audio length in seconds
         audio_length = len(waveform) / 16000
 
-        if audio_length >= 20:
+        if audio_length >= 30:
             # split audio every 20 seconds
             segments = []
-            max_duration = 20 * 16000  # Maximum segment duration in samples (20 seconds)
+            all_segments = []
+            max_duration = 30 * 16000  # Maximum segment duration in samples (30 seconds)
             num_segments = int(np.ceil(len(waveform) / max_duration))
             start = 0
             for i in range(num_segments):
@@ -159,7 +160,14 @@ class ASR(Pretrained):
 
                 # Pass the segment through the ASR model
                 segment_output = self.encode_batch_w2v2(device, batch, rel_length)
-                yield segment_output
+                segment_output = [" ".join(segment) for segment in segment_output]
+                all_segments.append(segment_output)
+
+            segments = ""
+            for segment in all_segments:
+                segment = segment[0]
+                segments += segment + " "
+            return [segments]
         else:
             waveform = torch.tensor(waveform).to(device)
             waveform = waveform.to(device)
@@ -167,7 +175,7 @@ class ASR(Pretrained):
             batch = waveform.unsqueeze(0)
             rel_length = torch.tensor([1.0]).to(device)
             outputs = self.encode_batch_w2v2(device, batch, rel_length)
-            yield outputs
+            return [" ".join(out) for out in outputs]
 
 
 
@@ -179,10 +187,11 @@ class ASR(Pretrained):
         # Get audio length in seconds
         audio_length = len(waveform) / 16000
 
-        if audio_length >= 20:
+        if audio_length >= 30:
             # split audio every 20 seconds
             segments = []
-            max_duration = 20 * 16000  # Maximum segment duration in samples (20 seconds)
+            all_segments = []
+            max_duration = 30 * 16000  # Maximum segment duration in samples (30 seconds)
             num_segments = int(np.ceil(len(waveform) / max_duration))
             start = 0
             for i in range(num_segments):
@@ -205,21 +214,28 @@ class ASR(Pretrained):
 
                 # Pass the segment through the ASR model
                 segment_output = self.encode_batch_whisper(device, batch, rel_length)
-                yield segment_output
+                # segment_output = [" ".join(segment) for segment in segment_output]
+                all_segments.append(segment_output)
+
+            segments = ""
+            for segment in all_segments:
+                segment = segment[0]
+                segments += segment + " "
+            return [segments]
         else:
             waveform = torch.tensor(waveform).to(device)
             waveform = waveform.to(device)
             batch = waveform.unsqueeze(0)
             rel_length = torch.tensor([1.0]).to(device)
             outputs = self.encode_batch_whisper(device, batch, rel_length)
-            yield outputs
+            return outputs
 
 
 
     def classify_file_whisper(self, waveform, pipe, device):
         # waveform, sr = librosa.load(path, sr=16000)
         transcription = pipe(waveform, generate_kwargs={"language": "macedonian"})["text"]
-        return transcription
+        return [transcription]
 
 
     def classify_file_mms(self, waveform, processor, model, device):
@@ -229,10 +245,11 @@ class ASR(Pretrained):
         # Get audio length in seconds
         audio_length = len(waveform) / 16000
 
-        if audio_length >= 20:
+        if audio_length >= 30:
             # split audio every 20 seconds
             segments = []
-            max_duration = 20 * 16000  # Maximum segment duration in samples (20 seconds)
+            all_segments = []
+            max_duration = 30 * 16000  # Maximum segment duration in samples (30 seconds)
             num_segments = int(np.ceil(len(waveform) / max_duration))
             start = 0
             for i in range(num_segments):
@@ -255,7 +272,13 @@ class ASR(Pretrained):
                 outputs = model(**inputs).logits
                 ids = torch.argmax(outputs, dim=-1)[0]
                 segment_output = processor.decode(ids)
-                yield segment_output
+                # segment_output = [" ".join(segment) for segment in segment_output]
+                all_segments.append(segment_output)
+
+            segments = ""
+            for segment in all_segments:
+                segments += segment + " "
+            return [segments]
         else:
             waveform = torch.tensor(waveform).to(device)
             inputs = processor(waveform, sampling_rate=16_000, return_tensors="pt").to(device)
@@ -263,4 +286,4 @@ class ASR(Pretrained):
             outputs = model(**inputs).logits
             ids = torch.argmax(outputs, dim=-1)[0]
             transcription = processor.decode(ids)
-            yield transcription
+            return [transcription]
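Taken together, the hunks make the same two changes in each `classify_file_*` method: the long-audio threshold and segment size move from 20 s to 30 s, and the per-segment `yield` is replaced by accumulating segment outputs and returning the whole transcription as a one-element list. Below is a minimal sketch of the resulting pattern, assuming a 16 kHz mono waveform; `transcribe_batch` is a hypothetical stand-in for the `encode_batch_*` methods, and the fixed-stride slicing is an assumption, since the loop body that cuts each segment lies outside the hunks shown above.

import numpy as np
import torch

SAMPLE_RATE = 16000
MAX_DURATION = 30 * SAMPLE_RATE  # maximum segment duration in samples (30 seconds)

def transcribe_long_audio(waveform, transcribe_batch, device):
    # Sketch only: `transcribe_batch` is a hypothetical callable mirroring
    # encode_batch_w2v2 / encode_batch_whisper, returning one word list per
    # batch item.
    if len(waveform) / SAMPLE_RATE < 30:
        # Short clip: a single batch with relative length 1.0
        batch = torch.tensor(waveform).to(device).unsqueeze(0)
        rel_length = torch.tensor([1.0]).to(device)
        return [" ".join(out) for out in transcribe_batch(device, batch, rel_length)]

    all_segments = []
    num_segments = int(np.ceil(len(waveform) / MAX_DURATION))
    for i in range(num_segments):
        # Fixed-stride slicing; the actual segment boundaries in the file
        # are not visible in the diff.
        segment = waveform[i * MAX_DURATION : (i + 1) * MAX_DURATION]
        batch = torch.tensor(segment).to(device).unsqueeze(0)
        rel_length = torch.tensor([1.0]).to(device)
        outputs = transcribe_batch(device, batch, rel_length)
        all_segments.append(" ".join(outputs[0]))
    # Concatenate the per-segment text and return it as a one-element list,
    # matching the new return convention in the diff.
    return [" ".join(all_segments)]

With the real class this would be invoked along the lines of `transcribe_long_audio(waveform, asr.encode_batch_w2v2, device)`, assuming `asr` is the loaded `ASR` instance. Returning a list instead of yielding gives callers the complete transcription in a single call rather than a generator to consume, at the cost of no longer streaming partial results per segment.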