gpt-omni commited on
Commit
e1adc1c
·
1 Parent(s): 2a8e1b5
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -79,6 +79,7 @@ if not os.path.exists(ckpt_dir):
79
 
80
  snacmodel = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval().to(device)
81
  whispermodel = whisper.load_model("small").to(device)
 
82
  text_tokenizer = Tokenizer(ckpt_dir)
83
  # fabric = L.Fabric(devices=1, strategy="auto")
84
  config = Config.from_file(ckpt_dir + "/model_config.yaml")
@@ -94,10 +95,10 @@ model.eval()
94
 
95
  @spaces.GPU
96
  def get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device):
97
- with torch.no_grad():
98
- mel = mel.unsqueeze(0).to(device)
99
- # audio_feature = whisper.decode(whispermodel,mel, options).audio_features
100
- audio_feature = whispermodel.embed_audio(mel)[0][:leng]
101
  T = audio_feature.size(0)
102
  input_ids_AA = []
103
  for i in range(7):
 
79
 
80
  snacmodel = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval().to(device)
81
  whispermodel = whisper.load_model("small").to(device)
82
+ whispermodel.eval()
83
  text_tokenizer = Tokenizer(ckpt_dir)
84
  # fabric = L.Fabric(devices=1, strategy="auto")
85
  config = Config.from_file(ckpt_dir + "/model_config.yaml")
 
95
 
96
  @spaces.GPU
97
  def get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device):
98
+ # with torch.no_grad():
99
+ mel = mel.unsqueeze(0).to(device)
100
+ # audio_feature = whisper.decode(whispermodel,mel, options).audio_features
101
+ audio_feature = whispermodel.embed_audio(mel)[0][:leng]
102
  T = audio_feature.size(0)
103
  input_ids_AA = []
104
  for i in range(7):