Spaces: gpt-omni
Commit 58227c7 · gpt-omni committed
Parent: 369b919
update

inference.py CHANGED (+4 -3)
@@ -138,6 +138,7 @@ def get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device):
     return torch.stack([audio_feature, audio_feature]), stacked_inputids
 
 
+@spaces.GPU
 def load_audio(path):
     audio = whisper.load_audio(path)
     duration_ms = (len(audio) / 16000) * 1000
@@ -357,7 +358,7 @@ def load_model(ckpt_dir, device):
     config.post_adapter = False
 
     with fabric.init_module(empty_init=False):
-        model = GPT(config)
+        model = GPT(config, device=device)
 
     # model = fabric.setup(model)
     state_dict = lazy_load(ckpt_dir + "/lit_model.pth")
@@ -401,8 +402,8 @@ class OmniInference:
         assert os.path.exists(audio_path), f"audio file {audio_path} not found"
         model = self.model
 
-        with self.fabric.init_tensor():
-            model.set_kv_cache(batch_size=2)
+        # with self.fabric.init_tensor():
+        model.set_kv_cache(batch_size=2)
 
         mel, leng = load_audio(audio_path)
         audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, self.whispermodel, self.device)
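Note on the new spaces.GPU decorator: in Hugging Face Spaces running on ZeroGPU hardware, a GPU is attached only while a function decorated with spaces.GPU executes, so GPU-touching entry points receive the decorator. A minimal, self-contained sketch of that pattern, assuming a ZeroGPU Space; everything except the spaces.GPU decorator itself (the toy model, run_on_gpu) is hypothetical and not taken from this repo:

import spaces                      # preinstalled in ZeroGPU Spaces; no-op elsewhere
import torch

toy_model = torch.nn.Linear(4, 4)  # placeholder for the real checkpoint

@spaces.GPU                        # GPU is attached only for the duration of this call
def run_on_gpu(x: torch.Tensor) -> torch.Tensor:
    # move the model and input to the GPU inside the decorated function,
    # then bring the result back to CPU before returning
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return toy_model.to(device)(x.to(device)).cpu()

if __name__ == "__main__":
    print(run_on_gpu(torch.randn(2, 4)))

The other two edits are visible in the hunks above: GPT(config, device=device) passes the target device to the model constructor, and model.set_kv_cache(batch_size=2) is now called directly rather than inside the self.fabric.init_tensor() context, which the diff comments out.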