Jackmin801 commited on
Commit
f4624e0
1 Parent(s): 43f3955

feat: no splt modules for device auto map

Browse files
Files changed (1) hide show
  1. modeling_bert.py +1 -0
modeling_bert.py CHANGED
@@ -956,6 +956,7 @@ class JinaBertPreTrainedModel(PreTrainedModel):
956
  load_tf_weights = load_tf_weights_in_bert
957
  base_model_prefix = "bert"
958
  supports_gradient_checkpointing = True
 
959
 
960
  def _init_weights(self, module):
961
  """Initialize the weights"""
 
956
  load_tf_weights = load_tf_weights_in_bert
957
  base_model_prefix = "bert"
958
  supports_gradient_checkpointing = True
959
+ _no_split_modules = ["JinaBertLayer"]
960
 
961
  def _init_weights(self, module):
962
  """Initialize the weights"""