BucketOfFish commited on
Commit
c07c430
1 Parent(s): 455129a

Passing KV cache through iterations

Browse files
Files changed (2) hide show
  1. phi2_model.py +3 -2
  2. streaming_inference.py +2 -2
phi2_model.py CHANGED
@@ -35,10 +35,11 @@ class Phi2PreTrainedModel(PreTrainedModel):
35
  def prepare_inputs_for_generation(
36
  self,
37
  input_ids: torch.LongTensor, # dim: (batch_size, seq_len)
38
- kv_cache: KVCache | None = None,
39
  key_padding_mask: torch.LongTensor | torch.BoolTensor | None = None,
40
  **kwargs, # has to be here
41
  ) -> dict[str, Any]:
 
42
  if not kv_cache:
43
  kv_cache = KVCache(
44
  max_seqlen=self.config.initial_cos_sin_cache_len,
@@ -160,4 +161,4 @@ class Phi2ModelForCausalLM(Phi2PreTrainedModel):
160
  if labels is not None
161
  else None
162
  )
163
- return CausalLMOutputWithPast(loss=loss, logits=logits)
 
35
  def prepare_inputs_for_generation(
36
  self,
37
  input_ids: torch.LongTensor, # dim: (batch_size, seq_len)
38
+ past_key_values: KVCache | None = None, # has to be named this
39
  key_padding_mask: torch.LongTensor | torch.BoolTensor | None = None,
40
  **kwargs, # has to be here
41
  ) -> dict[str, Any]:
42
+ kv_cache = past_key_values
43
  if not kv_cache:
44
  kv_cache = KVCache(
45
  max_seqlen=self.config.initial_cos_sin_cache_len,
 
161
  if labels is not None
162
  else None
163
  )
164
+ return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=kv_cache)
streaming_inference.py CHANGED
@@ -43,11 +43,11 @@ if __name__ == "__main__":
43
  thread = Thread(
44
  target=model.generate,
45
  kwargs=dict(
46
- tokenizer( # returns a torch dictionary
47
  "Here is an essay on sea monkeys: ",
48
  return_tensors="pt",
49
  return_attention_mask=False,
50
- ).to(device),
51
  streamer=token_streamer,
52
  max_new_tokens=500,
53
  eos_token_id=tokenizer.eos_token_id,
 
43
  thread = Thread(
44
  target=model.generate,
45
  kwargs=dict(
46
+ inputs=tokenizer( # returns a torch dictionary
47
  "Here is an essay on sea monkeys: ",
48
  return_tensors="pt",
49
  return_attention_mask=False,
50
+ ).to(device),
51
  streamer=token_streamer,
52
  max_new_tokens=500,
53
  eos_token_id=tokenizer.eos_token_id,