oweller2 commited on
Commit
cf03b9b
1 Parent(s): 4753b37
Files changed (1) hide show
  1. modeling_flexbert.py +13 -4
modeling_flexbert.py CHANGED
@@ -1537,17 +1537,26 @@ class FlexBertForCasualLM(FlexBertPreTrainedModel):
1537
  self._init_weights(reset_params=False)
1538
 
1539
  def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
 
1540
  assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified"
1541
- if module:
1542
- self._init_module_weights(module)
 
 
 
 
 
 
 
 
1543
  else:
1544
  assert isinstance(reset_params, bool)
1545
  self.bert._init_weights(reset_params=reset_params)
1546
  self.lm_head._init_weights(reset_params=reset_params)
1547
-
1548
  if not self.config.tie_word_embeddings:
1549
  init_weights(self.config, self.decoder, self.config.hidden_size, type_of_module=ModuleType.final_out)
1550
-
1551
  @classmethod
1552
  def from_composer(
1553
  cls,
 
1537
  self._init_weights(reset_params=False)
1538
 
1539
  def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
1540
+ # Handle the XOR condition
1541
  assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified"
1542
+
1543
+ if module is not None:
1544
+ # Add basic initialization for common module types
1545
+ if isinstance(module, (nn.Linear, nn.Embedding)):
1546
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
1547
+ if isinstance(module, nn.Linear) and module.bias is not None:
1548
+ module.bias.data.zero_()
1549
+ elif isinstance(module, nn.LayerNorm):
1550
+ module.bias.data.zero_()
1551
+ module.weight.data.fill_(1.0)
1552
  else:
1553
  assert isinstance(reset_params, bool)
1554
  self.bert._init_weights(reset_params=reset_params)
1555
  self.lm_head._init_weights(reset_params=reset_params)
1556
+
1557
  if not self.config.tie_word_embeddings:
1558
  init_weights(self.config, self.decoder, self.config.hidden_size, type_of_module=ModuleType.final_out)
1559
+
1560
  @classmethod
1561
  def from_composer(
1562
  cls,