jupyterjazz commited on
Commit
3f2b684
1 Parent(s): 215a6e1

fix: remove prompt length from args

Browse files

Signed-off-by: saba.sturua@jina.ai <saba.sturua@jina.ai>

Files changed (1) hide show
  1. custom_st.py +1 -0
custom_st.py CHANGED
@@ -139,6 +139,7 @@ class Transformer(nn.Module):
139
  lora_arguments = (
140
  {"adapter_mask": adapter_mask} if adapter_mask is not None else {}
141
  )
 
142
  output_states = self.auto_model.forward(**features, **lora_arguments, return_dict=False)
143
  output_tokens = output_states[0]
144
  features.update({"token_embeddings": output_tokens, "attention_mask": features["attention_mask"]})
 
139
  lora_arguments = (
140
  {"adapter_mask": adapter_mask} if adapter_mask is not None else {}
141
  )
142
+ features.pop('prompt_length', None)
143
  output_states = self.auto_model.forward(**features, **lora_arguments, return_dict=False)
144
  output_tokens = output_states[0]
145
  features.update({"token_embeddings": output_tokens, "attention_mask": features["attention_mask"]})