Markus28 committed
Commit f8b62b4
Parent: 87b642a

feat: added top-level docstring, made it compatible with AutoModel

Files changed (1)
  1. modeling_bert.py +10 -5
modeling_bert.py CHANGED
@@ -1,3 +1,10 @@
+""" Implementation of BERT, using ALiBi and Flash Attention
+
+The implementation was adopted from
+https://github.com/Dao-AILab/flash-attention/blob/43950dda456e095969d842fca7a73c5bfe3cecd0/flash_attn/models/bert.py
+and made modifications to use ALiBi.
+"""
+
 # Copyright (c) 2022, Tri Dao.
 # This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation.
 # https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
@@ -297,12 +304,10 @@ class BertPreTrainedModel(nn.Module):
 
     def __init__(self, config, *inputs, **kwargs):
         super().__init__()
-        if not isinstance(config, JinaBertConfig):
+        if not config.__class__.__name__ == 'JinaBertConfig':
             raise ValueError(
-                "Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
-                "To create a model from a Google pretrained model use "
-                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
-                    self.__class__.__name__, self.__class__.__name__
+                "Parameter config in `{}(config)` should be an instance of class `JinaBertConfig`.".format(
+                    self.__class__.__name__,
                 )
             )
         self.config = config
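
Why the `isinstance` check was loosened: when a model repo is loaded with `trust_remote_code=True`, transformers imports the configuration class dynamically from the hub cache, so the `JinaBertConfig` instance handed to the model can be a different class object than the one imported inside `modeling_bert.py`, and `isinstance` can then fail for a perfectly valid config. Comparing the class name accepts both. A minimal usage sketch under that assumption follows; the repo id is hypothetical, not part of this commit:

from transformers import AutoModel

# trust_remote_code=True makes transformers download and import
# modeling_bert.py (and its config class) from the model repo itself.
model = AutoModel.from_pretrained(
    "org/alibi-flash-bert",  # hypothetical repo id, for illustration only
    trust_remote_code=True,
)

# BertPreTrainedModel.__init__ now validates the config by name:
#     config.__class__.__name__ == 'JinaBertConfig'
# which passes for the dynamically imported class, whereas
# isinstance(config, JinaBertConfig) against a locally imported class
# can fail even though the config is correct.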