SmerkyG commited on
Commit
d420145
·
verified ·
1 Parent(s): 2b07ca9

Update modeling_rwkv5.py

Browse files
Files changed (1) hide show
  1. modeling_rwkv5.py +4 -1
modeling_rwkv5.py CHANGED
@@ -789,7 +789,10 @@ class Rwkv5ForCausalLM(Rwkv5PreTrainedModel):
789
  # only last token for inputs_ids if the state is passed along.
790
  if state is not None:
791
  input_ids = input_ids[:, -1].unsqueeze(-1)
792
-
 
 
 
793
  # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
794
  if inputs_embeds is not None and state is None:
795
  model_inputs = {"inputs_embeds": inputs_embeds}
 
789
  # only last token for inputs_ids if the state is passed along.
790
  if state is not None:
791
  input_ids = input_ids[:, -1].unsqueeze(-1)
792
+ else:
793
+ # add in \n at the beginning
794
+ input_ids = torch.cat([torch.full([1,1],11,device=input_ids.device,dtype=input_ids.dtype), input_ids])
795
+
796
  # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
797
  if inputs_embeds is not None and state is None:
798
  model_inputs = {"inputs_embeds": inputs_embeds}