ACloudCenter committed on
Commit
e96a4b0
·
1 Parent(s): 1bfebf7

fix: Use batch loading with DynamicCutSampler to avoid audio shape errors. Borrowed from the NVIDIA example for testing.

Browse files
Files changed (1) hide show
  1. app.py +8 -6
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
2
  import torch
3
  import spaces
4
  from lhotse import Recording
 
5
  from nemo.collections.speechlm2 import SALM
6
 
7
  # Set device to use cuda if available and sample rate to 16000 for Nvidia NeMo
@@ -26,15 +27,16 @@ def transcribe_audio(audio_filepath):
26
  cut = cut.to_mono(mono_downmix=True)
27
 
28
  # Load audio data
29
- audio = cut.load_audio()
30
- audio_lens = audio.shape[0]
31
-
 
32
  # Generate transcription
33
  with torch.inference_mode():
34
  output_ids = model.generate(
35
- prompts=[[{"role": "user", "content": f"Transcribe the following: {model.audio_locator_tag}"}]], # torch.as_tensor is used to convert the audio data to a tensor for model input
36
- audios=torch.as_tensor(audio).unsqueeze(0).to(device),
37
- audio_lens=torch.as_tensor([audio_lens]).to(device), # torch.as_tensor is used to convert the audio length to a tensor for model input
38
  max_new_tokens=256, # Maximum number of tokens to generate
39
  )
40
 
 
2
  import torch
3
  import spaces
4
  from lhotse import Recording
5
+ from lhotse.dataset import DynamicCutSampler
6
  from nemo.collections.speechlm2 import SALM
7
 
8
  # Set device to use cuda if available and sample rate to 16000 for Nvidia NeMo
 
27
  cut = cut.to_mono(mono_downmix=True)
28
 
29
  # Load audio data
30
+ batch = DynamicCutSampler([cut], max_cuts=1)
31
+ for b in batch:
32
+ audio, audio_lens = b.load_audio(collate=True)
33
+
34
  # Generate transcription
35
  with torch.inference_mode():
36
  output_ids = model.generate(
37
+ prompts=[[{"role": "user", "content": f"Transcribe the following: {model.audio_locator_tag}"}]],
38
+ audios=torch.as_tensor(audio).to(device),
39
+ audio_lens=torch.as_tensor(audio_lens).to(device),
40
  max_new_tokens=256, # Maximum number of tokens to generate
41
  )
42