Spaces:
Running
Running
Update inference.py
Browse files- inference.py +14 -14
inference.py
CHANGED
@@ -147,8 +147,8 @@ def load_audio(path):
|
|
147 |
|
148 |
def A1_A2_batch(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
|
149 |
snacmodel, out_dir=None):
|
150 |
-
|
151 |
-
|
152 |
tokenlist = generate_TA_BATCH(
|
153 |
model,
|
154 |
audio_feature,
|
@@ -191,8 +191,8 @@ def A1_A2_batch(fabric, audio_feature, input_ids, leng, model, text_tokenizer, s
|
|
191 |
|
192 |
|
193 |
def A1_T2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
|
194 |
-
|
195 |
-
|
196 |
tokenlist = generate_AT(
|
197 |
model,
|
198 |
audio_feature,
|
@@ -214,8 +214,8 @@ def A1_T2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
|
|
214 |
|
215 |
def A1_A2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
|
216 |
snacmodel, out_dir=None):
|
217 |
-
|
218 |
-
|
219 |
tokenlist = generate_AA(
|
220 |
model,
|
221 |
audio_feature,
|
@@ -256,8 +256,8 @@ def A1_A2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
|
|
256 |
|
257 |
|
258 |
def A1_T1(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
|
259 |
-
|
260 |
-
|
261 |
tokenlist = generate_ASR(
|
262 |
model,
|
263 |
audio_feature,
|
@@ -280,8 +280,8 @@ def A1_T1(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
|
|
280 |
|
281 |
def T1_A2(fabric, input_ids, model, text_tokenizer, step,
|
282 |
snacmodel, out_dir=None):
|
283 |
-
|
284 |
-
|
285 |
tokenlist = generate_TA(
|
286 |
model,
|
287 |
None,
|
@@ -325,8 +325,8 @@ def T1_A2(fabric, input_ids, model, text_tokenizer, step,
|
|
325 |
|
326 |
def T1_T2(fabric, input_ids, model, text_tokenizer, step):
|
327 |
|
328 |
-
|
329 |
-
|
330 |
tokenlist = generate_TT(
|
331 |
model,
|
332 |
None,
|
@@ -386,6 +386,7 @@ class OmniInference:
|
|
386 |
pass
|
387 |
|
388 |
@torch.inference_mode()
|
|
|
389 |
def run_AT_batch_stream(self,
|
390 |
audio_path,
|
391 |
stream_stride=4,
|
@@ -400,8 +401,7 @@ class OmniInference:
|
|
400 |
assert os.path.exists(audio_path), f"audio file {audio_path} not found"
|
401 |
model = self.model
|
402 |
|
403 |
-
|
404 |
-
model.set_kv_cache(batch_size=2)
|
405 |
|
406 |
mel, leng = load_audio(audio_path)
|
407 |
audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, self.whispermodel, self.device)
|
|
|
147 |
|
148 |
def A1_A2_batch(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
|
149 |
snacmodel, out_dir=None):
|
150 |
+
|
151 |
+
model.set_kv_cache(batch_size=2)
|
152 |
tokenlist = generate_TA_BATCH(
|
153 |
model,
|
154 |
audio_feature,
|
|
|
191 |
|
192 |
|
193 |
def A1_T2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
|
194 |
+
|
195 |
+
model.set_kv_cache(batch_size=1)
|
196 |
tokenlist = generate_AT(
|
197 |
model,
|
198 |
audio_feature,
|
|
|
214 |
|
215 |
def A1_A2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
|
216 |
snacmodel, out_dir=None):
|
217 |
+
|
218 |
+
model.set_kv_cache(batch_size=1)
|
219 |
tokenlist = generate_AA(
|
220 |
model,
|
221 |
audio_feature,
|
|
|
256 |
|
257 |
|
258 |
def A1_T1(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
|
259 |
+
|
260 |
+
model.set_kv_cache(batch_size=1)
|
261 |
tokenlist = generate_ASR(
|
262 |
model,
|
263 |
audio_feature,
|
|
|
280 |
|
281 |
def T1_A2(fabric, input_ids, model, text_tokenizer, step,
|
282 |
snacmodel, out_dir=None):
|
283 |
+
|
284 |
+
model.set_kv_cache(batch_size=1)
|
285 |
tokenlist = generate_TA(
|
286 |
model,
|
287 |
None,
|
|
|
325 |
|
326 |
def T1_T2(fabric, input_ids, model, text_tokenizer, step):
|
327 |
|
328 |
+
|
329 |
+
model.set_kv_cache(batch_size=1)
|
330 |
tokenlist = generate_TT(
|
331 |
model,
|
332 |
None,
|
|
|
386 |
pass
|
387 |
|
388 |
@torch.inference_mode()
|
389 |
+
@spaces.GPU
|
390 |
def run_AT_batch_stream(self,
|
391 |
audio_path,
|
392 |
stream_stride=4,
|
|
|
401 |
assert os.path.exists(audio_path), f"audio file {audio_path} not found"
|
402 |
model = self.model
|
403 |
|
404 |
+
model.set_kv_cache(batch_size=2)
|
|
|
405 |
|
406 |
mel, leng = load_audio(audio_path)
|
407 |
audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, self.whispermodel, self.device)
|