Update modeling_diva.py
modeling_diva.py  CHANGED  (+17 -5)
@@ -277,6 +277,8 @@ class DiVAModel(PreTrainedModel):
         do_sample=False,
         logits_processor=None,
         max_new_tokens=128,
+        return_outputs=False,
+        init_outputs=None,
     ):
         inputs = self.processor(audio, return_tensors="pt", sampling_rate=16_000)
         input_features = inputs.input_features.to(self.whisper_encoder.device)
@@ -305,7 +307,7 @@ class DiVAModel(PreTrainedModel):
             [prefix_embed, virt_tokens, suffix_embed], axis=0
         ).unsqueeze(0)
         outs = []
-        outputs = None
+        outputs = init_outputs
         greedy = 1
         i = 0
         while greedy != 128009 and len(outs) < max_new_tokens:
@@ -337,9 +339,19 @@ class DiVAModel(PreTrainedModel):
             outs.append(greedy)
             next_embed = self.llm_decoder.model.embed_tokens(greedy.reshape(1, 1))
             inputs_embeds = next_embed
-            yield self.tokenizer.decode(outs, skip_special_tokens=True).replace(
+            if not return_outputs:
+                yield self.tokenizer.decode(outs, skip_special_tokens=True).replace(
+                    "<|eot_id|>", ""
+                )
+            else:
+                yield (self.tokenizer.decode(outs, skip_special_tokens=True).replace(
+                    "<|eot_id|>", ""
+                ), outputs)
+        if not return_outputs:
+            return self.tokenizer.decode(outs, skip_special_tokens=True).replace(
                 "<|eot_id|>", ""
             )
-        return self.tokenizer.decode(outs, skip_special_tokens=True).replace(
-            "<|eot_id|>", ""
-        )
+        else:
+            return (self.tokenizer.decode(outs, skip_special_tokens=True).replace(
+                "<|eot_id|>", ""
+            ), outputs)