ElmehdiSMILI committed on
Commit
8ae731b
·
verified ·
1 Parent(s): 26e9f97

Update modeling_jais.py

Browse files
Files changed (1) hide show
  1. modeling_jais.py +6 -0
modeling_jais.py CHANGED
@@ -535,6 +535,11 @@ class JAISPreTrainedModel(PreTrainedModel):
535
  stddev = self.config.initializer_range * mup_init_scale / math.sqrt(2 * self.config.n_layer)
536
  p.data.normal_(mean=0.0, std=stddev)
537
 
 
 
 
 
 
538
  JAIS_START_DOCSTRING = r"""
539
 
540
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
@@ -861,6 +866,7 @@ class JAISModel(JAISPreTrainedModel):
861
  hidden_states = inputs_embeds + position_embeds
862
  else:
863
  hidden_states = inputs_embeds
 
864
  scale_factor_hidden = torch.tensor(float(self.embeddings_scale), dtype=hidden_states.dtype, device=hidden_states.device)
865
  hidden_states = hidden_states * scale_factor_hidden
866
 
 
535
  stddev = self.config.initializer_range * mup_init_scale / math.sqrt(2 * self.config.n_layer)
536
  p.data.normal_(mean=0.0, std=stddev)
537
 
538
+ def _set_gradient_checkpointing(self, module, value=False):
539
+ if isinstance(module, JAISModel):
540
+ module.gradient_checkpointing = value
541
+
542
+
543
  JAIS_START_DOCSTRING = r"""
544
 
545
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
 
866
  hidden_states = inputs_embeds + position_embeds
867
  else:
868
  hidden_states = inputs_embeds
869
+ '''hidden_states *= torch.tensor(float(self.embeddings_scale), dtype=hidden_states.dtype, device=hidden_states.device) '''
870
  scale_factor_hidden = torch.tensor(float(self.embeddings_scale), dtype=hidden_states.dtype, device=hidden_states.device)
871
  hidden_states = hidden_states * scale_factor_hidden
872