mansaripo committed on
Commit f2ef64c · verified · 1 Parent(s): 4781880

Update modeling_cloverlm.py

Files changed (1)
  1. modeling_cloverlm.py (+4, -1)
modeling_cloverlm.py CHANGED
@@ -209,11 +209,13 @@ class CloverLMForCausalLM(PreTrainedModel, GenerationMixin):
     config_class = CloverLMConfig
     supports_gradient_checkpointing = False
     _no_split_modules = ["_Block"]
-    _tied_weights_keys = {"transformer.linear.weight": "transformer.emb.weight"}
+    _tied_weights_keys = ["transformer.linear.weight"]
     _tp_plan = {}
 
     def __init__(self, config: CloverLMConfig):
         super().__init__(config)
+        self.all_tied_weights_keys = {k: "transformer.emb.weight"
+                                      for k in (self._tied_weights_keys or [])}
         self.transformer = _Transformer(
             vocab_size=config.vocab_size,
             num_blocks=config.num_blocks,
@@ -226,6 +228,7 @@ class CloverLMForCausalLM(PreTrainedModel, GenerationMixin):
             weight_tying=config.weight_tying,
             attn_backend=config.attn_backend,
         )
+        self.post_init()
 
     def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
         logits = self.transformer(input_ids)
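
A reading of the change (editorial note, not text from the commit): transformers has long expected the `_tied_weights_keys` class attribute to be a list of parameter names, while the dict form (target key mapped to source key) used before this commit appears to be understood only by newer releases. The commit keeps the class attribute in the widely supported list form, builds the explicit target-to-source mapping in `all_tied_weights_keys` at init time, and calls `post_init()`, which runs `PreTrainedModel`'s standard weight initialization and weight tying. A minimal sketch of what the tie amounts to; `vocab_size` and `hidden` below are hypothetical dimensions for illustration, not values from the model:

import torch
import torch.nn as nn

# Hypothetical dimensions, not taken from the commit.
vocab_size, hidden = 32000, 768

emb = nn.Embedding(vocab_size, hidden)              # stands in for transformer.emb
linear = nn.Linear(hidden, vocab_size, bias=False)  # stands in for transformer.linear

# What PreTrainedModel.tie_weights() effectively does for a tied pair:
# both modules end up sharing a single Parameter object.
linear.weight = emb.weight

assert linear.weight.data_ptr() == emb.weight.data_ptr()

ids = torch.tensor([[1, 2, 3]])
logits = linear(emb(ids))  # gradients flow into the one shared tensor

Because the checkpoint then only needs to store `transformer.emb.weight` once, listing `transformer.linear.weight` in `_tied_weights_keys` tells the save/load machinery which duplicate key it may drop on save and re-derive on load.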