Markus28 commited on
Commit
95ca1a8
·
1 Parent(s): 463061d

fix: try to skip initialization of task type embeddings

Browse files
Files changed (1) hide show
  1. modeling_bert.py +4 -2
modeling_bert.py CHANGED
@@ -145,7 +145,7 @@ def _init_weights(module, initializer_range=0.02):
145
  nn.init.normal_(module.weight, std=initializer_range)
146
  if module.bias is not None:
147
  nn.init.zeros_(module.bias)
148
- elif isinstance(module, nn.Embedding):
149
  nn.init.normal_(module.weight, std=initializer_range)
150
  if module.padding_idx is not None:
151
  nn.init.zeros_(module.weight[module.padding_idx])
@@ -346,12 +346,14 @@ class BertModel(BertPreTrainedModel):
346
  self.pooler = BertPooler(config) if add_pooling_layer else None
347
  self.task_type_embeddings = nn.Embedding(config.num_tasks, config.hidden_size)
348
 
349
- self.apply(partial(_init_weights, initializer_range=config.initializer_range))
350
  # We now initialize the task embeddings to 0; We do not use task types during
351
  # pretraining. When we start using task types during embedding training,
352
  # we want the model to behave exactly as in pretraining (i.e. task types
353
  # have no effect).
354
  nn.init.zeros_(self.task_type_embeddings.weight)
 
 
 
355
 
356
  def forward(
357
  self,
 
145
  nn.init.normal_(module.weight, std=initializer_range)
146
  if module.bias is not None:
147
  nn.init.zeros_(module.bias)
148
+ elif isinstance(module, nn.Embedding) and not module.skip_init:
149
  nn.init.normal_(module.weight, std=initializer_range)
150
  if module.padding_idx is not None:
151
  nn.init.zeros_(module.weight[module.padding_idx])
 
346
  self.pooler = BertPooler(config) if add_pooling_layer else None
347
  self.task_type_embeddings = nn.Embedding(config.num_tasks, config.hidden_size)
348
 
 
349
  # We now initialize the task embeddings to 0; We do not use task types during
350
  # pretraining. When we start using task types during embedding training,
351
  # we want the model to behave exactly as in pretraining (i.e. task types
352
  # have no effect).
353
  nn.init.zeros_(self.task_type_embeddings.weight)
354
+ self.task_type_embeddings.skip_init = True
355
+ # The following code should skip the embeddings layer
356
+ self.apply(partial(_init_weights, initializer_range=config.initializer_range))
357
 
358
  def forward(
359
  self,