duzx16 commited on
Commit
189e5df
1 Parent(s): f2191d0

Add get_input_embeddings

Browse files
Files changed (1) hide show
  1. modeling_chatglm.py +3 -0
modeling_chatglm.py CHANGED
@@ -702,6 +702,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
702
  dtype=config.torch_dtype, **init_kwargs)
703
  self.gradient_checkpointing = False
704
 
 
 
 
705
  def forward(
706
  self,
707
  input_ids,
 
702
  dtype=config.torch_dtype, **init_kwargs)
703
  self.gradient_checkpointing = False
704
 
705
+ def get_input_embeddings(self):
706
+ return self.embedding.word_embeddings
707
+
708
  def forward(
709
  self,
710
  input_ids,