Files changed (1) hide show
  1. bert_layers.py +4 -1
bert_layers.py CHANGED
@@ -25,6 +25,8 @@ from .bert_padding import (index_first_axis,
25
  index_put_first_axis, pad_input,
26
  unpad_input, unpad_input_only)
27
 
 
 
28
  try:
29
  from .flash_attn_triton import flash_attn_qkvpacked_func
30
  except ImportError as e:
@@ -564,7 +566,8 @@ class BertModel(BertPreTrainedModel):
564
  all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
565
  ```
566
  """
567
-
 
568
  def __init__(self, config, add_pooling_layer=True):
569
  super(BertModel, self).__init__(config)
570
  self.embeddings = BertEmbeddings(config)
 
25
  index_put_first_axis, pad_input,
26
  unpad_input, unpad_input_only)
27
 
28
+ from .configuration_bert import BertConfig
29
+
30
  try:
31
  from .flash_attn_triton import flash_attn_qkvpacked_func
32
  except ImportError as e:
 
566
  all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
567
  ```
568
  """
569
+ config_class = BertConfig
570
+
571
  def __init__(self, config, add_pooling_layer=True):
572
  super(BertModel, self).__init__(config)
573
  self.embeddings = BertEmbeddings(config)