Safetensors
English
bert_hash
custom_code
davidmezzetti commited on
Commit
5682631
·
1 Parent(s): bb63cc1

Support resizing token embeddings

Browse files
Files changed (1) hide show
  1. modeling_bert_hash.py +6 -0
modeling_bert_hash.py CHANGED
@@ -154,6 +154,12 @@ class BertHashModel(BertPreTrainedModel):
154
  # Initialize weights and apply final processing
155
  self.post_init()
156
 
 
 
 
 
 
 
157
  def _prune_heads(self, heads_to_prune):
158
  """
159
  Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
 
154
  # Initialize weights and apply final processing
155
  self.post_init()
156
 
157
+ def get_input_embeddings(self):
158
+ return self.embeddings.word_embeddings.embeddings
159
+
160
+ def set_input_embeddings(self, value):
161
+ self.embeddings.word_embeddings.embeddings = value
162
+
163
  def _prune_heads(self, heads_to_prune):
164
  """
165
  Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base