amazingvince commited on
Commit
4d0ea8d
1 Parent(s): c6d61ef

Update modeling_diff_llama.py

Browse files
Files changed (1) hide show
  1. modeling_diff_llama.py +21 -0
modeling_diff_llama.py CHANGED
@@ -506,5 +506,26 @@ class DiffLLaMAForCausalLM(PreTrainedModel):
506
  hidden_states=outputs.hidden_states,
507
  attentions=outputs.attentions,
508
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
509
 
510
 
 
506
  hidden_states=outputs.hidden_states,
507
  attentions=outputs.attentions,
508
  )
509
+ def prepare_inputs_for_generation(
510
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
511
+ ):
512
+ if past_key_values:
513
+ input_ids = input_ids[:, -1:]
514
+
515
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
516
+ if inputs_embeds is not None and past_key_values is None:
517
+ model_inputs = {"inputs_embeds": inputs_embeds}
518
+ else:
519
+ model_inputs = {"input_ids": input_ids}
520
+
521
+ model_inputs.update(
522
+ {
523
+ "past_key_values": past_key_values,
524
+ "use_cache": kwargs.get("use_cache"),
525
+ "attention_mask": attention_mask,
526
+ "cache_position": kwargs.get("cache_position"),
527
+ }
528
+ )
529
+ return model_inputs
530
 
531