Porjaz committed
Commit 49b259b
1 Parent(s): 89424ee

Update custom_interface_app.py

Files changed (1)
  1. custom_interface_app.py +42 -25
custom_interface_app.py CHANGED
@@ -231,48 +231,65 @@ class ASR(Pretrained):
         return outputs


-    def classify_file_whisper_mkd_streaming(self, waveform, device):
-        # Load the audio file
-        # waveform, sr = librosa.load(path, sr=16000)
-
         # Get audio length in seconds
-        audio_length = len(waveform) / 16000

         if audio_length >= 20:
-            # split audio every 20 seconds
             segments = []
-            max_duration = 20 * 16000  # Maximum segment duration in samples (20 seconds)
-            num_segments = int(np.ceil(len(waveform) / max_duration))
-            start = 0
-            for i in range(num_segments):
-                end = start + max_duration
-                if end > len(waveform):
-                    end = len(waveform)
                 segment_part = waveform[start:end]
-                segment_len = len(segment_part) / 16000
-                if segment_len < 1:
-                    continue
-                segments.append(segment_part)
-                start = end

-            for segment in segments:
                 segment_tensor = torch.tensor(segment).to(device)

                 # Fake a batch for the segment
                 batch = segment_tensor.unsqueeze(0).to(device)
-                rel_length = torch.tensor([1.0]).to(device)

                 # Pass the segment through the ASR model
-                segment_output = self.encode_batch_whisper(device, batch, rel_length)
-                yield segment_output
         else:
             waveform = torch.tensor(waveform).to(device)
             waveform = waveform.to(device)
             batch = waveform.unsqueeze(0)
             rel_length = torch.tensor([1.0]).to(device)
-            outputs = self.encode_batch_whisper(device, batch, rel_length)
-            yield outputs
-

     def classify_file_whisper(self, waveform, pipe, device):
         # waveform, sr = librosa.load(path, sr=16000)
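
The removed method was a generator: each `yield` handed back the output for one fixed 20-second chunk as soon as it was decoded, so callers could stream partial results. A minimal consumption sketch, assuming a loaded `asr` instance of this class and an illustrative audio path (both are assumptions, not part of the repo):

import librosa

waveform, sr = librosa.load("speech.wav", sr=16000)  # hypothetical input file
# `asr` is assumed to be a loaded instance of the ASR class above
for chunk_output in asr.classify_file_whisper_mkd_streaming(waveform, device="cpu"):
    print(chunk_output)  # one transcription per ~20-second chunk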
 
         return outputs


+    def classify_file_w2v2(self, waveform, device):
         # Get audio length in seconds
+        sr = 16000
+        audio_length = len(waveform) / sr

         if audio_length >= 20:
+            print(f"Audio is too long ({audio_length:.2f} seconds), splitting into segments")
+            # Detect non-silent segments
+            non_silent_intervals = librosa.effects.split(waveform, top_db=20)  # Adjust top_db for sensitivity
+
             segments = []
+            current_segment = []
+            current_length = 0
+            max_duration = 20 * sr  # Maximum segment duration in samples (20 seconds)
+
+            for interval in non_silent_intervals:
+                start, end = interval
                 segment_part = waveform[start:end]

+                # If adding the next part exceeds the max duration, store the segment and start a new one
+                if current_length + len(segment_part) > max_duration:
+                    segments.append(np.concatenate(current_segment))
+                    current_segment = []
+                    current_length = 0
+
+                current_segment.append(segment_part)
+                current_length += len(segment_part)
+
+            # Append the last segment if it's not empty
+            if current_segment:
+                segments.append(np.concatenate(current_segment))
+
+            # Process each segment
+            outputs = []
+            for i, segment in enumerate(segments):
+                print(f"Processing segment {i + 1}/{len(segments)}, length: {len(segment) / sr:.2f} seconds")
+
+                # import soundfile as sf
+                # sf.write(f"outputs/segment_{i}.wav", segment, sr)
+
                 segment_tensor = torch.tensor(segment).to(device)

                 # Fake a batch for the segment
                 batch = segment_tensor.unsqueeze(0).to(device)
+                rel_length = torch.tensor([1.0]).to(device)  # Adjust if necessary
+
                 # Pass the segment through the ASR model
+                outputs.append(self.encode_batch_w2v2(device, batch, rel_length))
+            return outputs
         else:
             waveform = torch.tensor(waveform).to(device)
             waveform = waveform.to(device)
+            # Fake a batch:
             batch = waveform.unsqueeze(0)
             rel_length = torch.tensor([1.0]).to(device)
+            outputs = self.encode_batch_w2v2(device, batch, rel_length)
+            return [outputs]

     def classify_file_whisper(self, waveform, pipe, device):
         # waveform, sr = librosa.load(path, sr=16000)
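
The core of the new `classify_file_w2v2` is silence-aware chunking: `librosa.effects.split` returns `(start, end)` sample indices of non-silent intervals, which are then packed greedily into segments of at most 20 seconds so each fits the model's input budget. Below is a self-contained sketch of that packing step; the function name and defaults are illustrative, not from the committed file, and it adds a `current` guard the commit lacks, avoiding `np.concatenate` on an empty list when the first interval alone exceeds the budget.

import librosa
import numpy as np

def split_on_silence(waveform, sr=16000, max_seconds=20, top_db=20):
    # Hypothetical helper: pack non-silent intervals into chunks of at most max_seconds
    max_samples = max_seconds * sr
    intervals = librosa.effects.split(waveform, top_db=top_db)  # shape (m, 2) sample indices
    segments, current, current_len = [], [], 0
    for start, end in intervals:
        part = waveform[start:end]
        # Flush the running chunk before it would exceed the budget
        if current and current_len + len(part) > max_samples:
            segments.append(np.concatenate(current))
            current, current_len = [], 0
        current.append(part)
        current_len += len(part)
    if current:
        segments.append(np.concatenate(current))
    return segments

Note that a single non-silent interval longer than max_seconds still produces an over-long chunk under this greedy scheme; splitting such intervals at a fixed stride would be an obvious refinement.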