Porjaz commited on
Commit
6a73d74
·
verified ·
1 Parent(s): 7801528

Update custom_interface_app.py

Browse files
Files changed (1) hide show
  1. custom_interface_app.py +42 -0
custom_interface_app.py CHANGED
@@ -231,6 +231,48 @@ class ASR(Pretrained):
231
  return outputs
232
 
233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
 
235
  def classify_file_whisper(self, waveform, pipe, device):
236
  # waveform, sr = librosa.load(path, sr=16000)
 
231
  return outputs
232
 
233
 
234
+ def classify_file_whisper_mkd_streaming(self, waveform, device):
235
+ # Load the audio file
236
+ # waveform, sr = librosa.load(path, sr=16000)
237
+
238
+ # Get audio length in seconds
239
+ audio_length = len(waveform) / 16000
240
+
241
+ if audio_length >= 20:
242
+ # split audio every 20 seconds
243
+ segments = []
244
+ max_duration = 20 * 16000 # Maximum segment duration in samples (20 seconds)
245
+ num_segments = int(np.ceil(len(waveform) / max_duration))
246
+ start = 0
247
+ for i in range(num_segments):
248
+ end = start + max_duration
249
+ if end > len(waveform):
250
+ end = len(waveform)
251
+ segment_part = waveform[start:end]
252
+ segment_len = len(segment_part) / 16000
253
+ if segment_len < 1:
254
+ continue
255
+ segments.append(segment_part)
256
+ start = end
257
+
258
+ for segment in segments:
259
+ segment_tensor = torch.tensor(segment).to(device)
260
+
261
+ # Fake a batch for the segment
262
+ batch = segment_tensor.unsqueeze(0).to(device)
263
+ rel_length = torch.tensor([1.0]).to(device)
264
+
265
+ # Pass the segment through the ASR model
266
+ segment_output = self.encode_batch_whisper(device, batch, rel_length)
267
+ yield segment_output
268
+ else:
269
+ waveform = torch.tensor(waveform).to(device)
270
+ waveform = waveform.to(device)
271
+ batch = waveform.unsqueeze(0)
272
+ rel_length = torch.tensor([1.0]).to(device)
273
+ outputs = self.encode_batch_whisper(device, batch, rel_length)
274
+ yield outputs
275
+
276
 
277
  def classify_file_whisper(self, waveform, pipe, device):
278
  # waveform, sr = librosa.load(path, sr=16000)