gpt-omni committed
Commit 58c8b03
Parent: 6683bb4

Update inference.py

Files changed (1)
  1. inference.py +14 -14
inference.py CHANGED
@@ -147,8 +147,8 @@ def load_audio(path):
 
 def A1_A2_batch(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
                 snacmodel, out_dir=None):
-    with fabric.init_tensor():
-        model.set_kv_cache(batch_size=2)
+
+    model.set_kv_cache(batch_size=2)
     tokenlist = generate_TA_BATCH(
         model,
         audio_feature,
@@ -191,8 +191,8 @@ def A1_A2_batch(fabric, audio_feature, input_ids, leng, model, text_tokenizer, s
 
 
 def A1_T2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
-    with fabric.init_tensor():
-        model.set_kv_cache(batch_size=1)
+
+    model.set_kv_cache(batch_size=1)
     tokenlist = generate_AT(
         model,
         audio_feature,
@@ -214,8 +214,8 @@ def A1_T2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
 
 def A1_A2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
           snacmodel, out_dir=None):
-    with fabric.init_tensor():
-        model.set_kv_cache(batch_size=1)
+
+    model.set_kv_cache(batch_size=1)
     tokenlist = generate_AA(
         model,
         audio_feature,
@@ -256,8 +256,8 @@ def A1_A2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
 
 
 def A1_T1(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
-    with fabric.init_tensor():
-        model.set_kv_cache(batch_size=1)
+
+    model.set_kv_cache(batch_size=1)
     tokenlist = generate_ASR(
         model,
         audio_feature,
@@ -280,8 +280,8 @@ def A1_T1(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
 
 def T1_A2(fabric, input_ids, model, text_tokenizer, step,
           snacmodel, out_dir=None):
-    with fabric.init_tensor():
-        model.set_kv_cache(batch_size=1)
+
+    model.set_kv_cache(batch_size=1)
     tokenlist = generate_TA(
         model,
         None,
@@ -325,8 +325,8 @@ def T1_A2(fabric, input_ids, model, text_tokenizer, step,
 
 def T1_T2(fabric, input_ids, model, text_tokenizer, step):
 
-    with fabric.init_tensor():
-        model.set_kv_cache(batch_size=1)
+
+    model.set_kv_cache(batch_size=1)
     tokenlist = generate_TT(
         model,
         None,
@@ -386,6 +386,7 @@ class OmniInference:
         pass
 
     @torch.inference_mode()
+    @spaces.GPU
     def run_AT_batch_stream(self,
                             audio_path,
                             stream_stride=4,
@@ -400,8 +401,7 @@ class OmniInference:
         assert os.path.exists(audio_path), f"audio file {audio_path} not found"
         model = self.model
 
-        with self.fabric.init_tensor():
-            model.set_kv_cache(batch_size=2)
+        model.set_kv_cache(batch_size=2)
 
         mel, leng = load_audio(audio_path)
         audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, self.whispermodel, self.device)
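
Every hunk makes the same change: the `with fabric.init_tensor():` wrapper is dropped so `model.set_kv_cache(...)` is called directly, and the streaming entry point `run_AT_batch_stream` additionally gains a `@spaces.GPU` decorator. `spaces.GPU` is the Hugging Face ZeroGPU decorator that attaches a GPU to the process only for the duration of the decorated call, which is presumably why the KV-cache allocation is no longer pinned to Fabric's device: the buffers should land on whatever device is live when the call runs. A minimal sketch of the resulting pattern (`DummyModel`, `set_kv_cache`, and `run_stream` below are illustrative stand-ins, not the repo's actual classes):

```python
# Sketch of the pattern this commit adopts, assuming a Hugging Face
# Space with the ZeroGPU `spaces` package installed. DummyModel and
# its set_kv_cache are stand-ins for the litgpt-style model used in
# inference.py, not the actual implementation.
import torch
import spaces


class DummyModel(torch.nn.Module):
    def set_kv_cache(self, batch_size: int = 1) -> None:
        # The real model pre-allocates key/value buffers here; the
        # stand-in just records the requested batch size.
        self.kv_batch_size = batch_size


model = DummyModel()


@torch.inference_mode()
@spaces.GPU  # GPU is attached only while this function runs
def run_stream(audio_path: str):
    # Allocate the KV cache inside the decorated call so it lands on
    # the device ZeroGPU just attached. Lightning Fabric's
    # init_tensor() context (removed in this commit) would instead
    # have placed new tensors on Fabric's configured device.
    model.set_kv_cache(batch_size=2)
    ...
```

The decorator order mirrors the commit, with `@torch.inference_mode()` outermost; under ZeroGPU the decorated function is the unit of GPU scheduling, so per-call state such as the KV cache is set up inside it rather than once at import time.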