Commit
·
5682631
1
Parent(s):
bb63cc1
Support resizing token embeddings
Browse files- modeling_bert_hash.py +6 -0
modeling_bert_hash.py
CHANGED
|
@@ -154,6 +154,12 @@ class BertHashModel(BertPreTrainedModel):
|
|
| 154 |
# Initialize weights and apply final processing
|
| 155 |
self.post_init()
|
| 156 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
def _prune_heads(self, heads_to_prune):
|
| 158 |
"""
|
| 159 |
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
|
|
|
| 154 |
# Initialize weights and apply final processing
|
| 155 |
self.post_init()
|
| 156 |
|
| 157 |
+
def get_input_embeddings(self):
|
| 158 |
+
return self.embeddings.word_embeddings.embeddings
|
| 159 |
+
|
| 160 |
+
def set_input_embeddings(self, value):
|
| 161 |
+
self.embeddings.word_embeddings.embeddings = value
|
| 162 |
+
|
| 163 |
def _prune_heads(self, heads_to_prune):
|
| 164 |
"""
|
| 165 |
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|