Allow device auto map

#8
by Jackmin108 - opened
Files changed (1) hide show
  1. modeling_bert.py +1 -0
modeling_bert.py CHANGED
@@ -956,6 +956,7 @@ class JinaBertPreTrainedModel(PreTrainedModel):
956
  load_tf_weights = load_tf_weights_in_bert
957
  base_model_prefix = "bert"
958
  supports_gradient_checkpointing = True
 
959
 
960
  def _init_weights(self, module):
961
  """Initialize the weights"""
 
956
  load_tf_weights = load_tf_weights_in_bert
957
  base_model_prefix = "bert"
958
  supports_gradient_checkpointing = True
959
+ _no_split_modules = ["JinaBertLayer"]
960
 
961
  def _init_weights(self, module):
962
  """Initialize the weights"""