Feature Extraction · Transformers · Safetensors · diva · custom_code
WillHeld committed
Commit f7f3973 · 1 parent: 48b107a

Update modeling_diva.py

Files changed (1):
1. modeling_diva.py +17 -5
modeling_diva.py CHANGED
@@ -277,6 +277,8 @@ class DiVAModel(PreTrainedModel):
         do_sample=False,
         logits_processor=None,
         max_new_tokens=128,
+        return_outputs=False,
+        init_outputs=None,
     ):
         inputs = self.processor(audio, return_tensors="pt", sampling_rate=16_000)
         input_features = inputs.input_features.to(self.whisper_encoder.device)
@@ -305,7 +307,7 @@ class DiVAModel(PreTrainedModel):
             [prefix_embed, virt_tokens, suffix_embed], axis=0
         ).unsqueeze(0)
         outs = []
-        outputs = None
+        outputs = init_outputs
         greedy = 1
         i = 0
         while greedy != 128009 and len(outs) < max_new_tokens:
@@ -337,9 +339,19 @@
             outs.append(greedy)
            next_embed = self.llm_decoder.model.embed_tokens(greedy.reshape(1, 1))
            inputs_embeds = next_embed
-            yield self.tokenizer.decode(outs, skip_special_tokens=True).replace(
-                "<|eot_id|>", ""
-            )
-        return self.tokenizer.decode(outs, skip_special_tokens=True).replace(
-            "<|eot_id|>", ""
-        )
+            if not return_outputs:
+                yield self.tokenizer.decode(outs, skip_special_tokens=True).replace(
+                    "<|eot_id|>", ""
+                )
+            else:
+                yield (self.tokenizer.decode(outs, skip_special_tokens=True).replace(
+                    "<|eot_id|>", ""
+                ), outputs)
+        if not return_outputs:
+            return self.tokenizer.decode(outs, skip_special_tokens=True).replace(
+                "<|eot_id|>", ""
+            )
+        else:
+            return (self.tokenizer.decode(outs, skip_special_tokens=True).replace(
+                "<|eot_id|>", ""
+            ), outputs)
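
For context, a minimal usage sketch of the updated streaming generator from the caller's side, not part of the commit. The method name lies outside the hunks above, so generate_stream is an assumption here, as are the checkpoint id and the variable names. With return_outputs=False (the default) each yield is the decoded text so far, exactly as before; with return_outputs=True each yield (and the generator's final return value) becomes a (text, outputs) tuple whose second element can seed a later call through init_outputs. The stop token 128009 in the loop condition is <|eot_id|> in the Llama 3 tokenizer, which is why the decoded strings also strip that literal.

# Minimal usage sketch (not part of the commit). The repo id, the
# generate_stream name, and the variable names below are assumptions.
import numpy as np
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "WillHeld/DiVA-llama-3-v0-8b",  # assumed checkpoint id
    trust_remote_code=True,         # required for custom_code repos
)

# 16 kHz mono waveform; one second of silence as a stand-in for real audio.
audio = np.zeros(16_000, dtype=np.float32)

# Default path: each yield is the decoded text so far, as before this commit.
for partial_text in model.generate_stream(audio, max_new_tokens=64):
    print(partial_text)

# New path: each yield is a (text, outputs) tuple, and the decoder state
# captured in `outputs` can seed a follow-up call via init_outputs.
last_outputs = None
for partial_text, outputs in model.generate_stream(
    audio, return_outputs=True, max_new_tokens=64
):
    last_outputs = outputs

Note that iterating a generator with a for loop discards its return value, so the tuple-returning branch after the while loop only matters to callers that drive the generator manually and read StopIteration.value.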