KevinGeertjens commited on
Commit
91c61bf
1 Parent(s): 2868a91

Upload model

Browse files
bert_classification_model.py CHANGED
@@ -4,8 +4,11 @@ from typing import Tuple
4
  from transformers import PreTrainedModel, BertModel
5
  from torch import nn
6
 
 
 
7
 
8
  class BertClassificationModel(PreTrainedModel):
 
9
 
10
  def __init__(self, config, num_main_segment=None, num_sub_segment=None):
11
  super(BertClassificationModel, self).__init__(config=config)
 
4
  from transformers import PreTrainedModel, BertModel
5
  from torch import nn
6
 
7
+ from bert_classification_config import BertClassificationConfig
8
+
9
 
10
  class BertClassificationModel(PreTrainedModel):
11
+ config_class = BertClassificationConfig
12
 
13
  def __init__(self, config, num_main_segment=None, num_sub_segment=None):
14
  super(BertClassificationModel, self).__init__(config=config)
config.json CHANGED
@@ -14,7 +14,7 @@
14
  "intermediate_size": 3072,
15
  "layer_norm_eps": 1e-12,
16
  "max_position_embeddings": 512,
17
- "model_type": "bert",
18
  "num_attention_heads": 12,
19
  "num_hidden_layers": 12,
20
  "num_main_segment": 0,
 
14
  "intermediate_size": 3072,
15
  "layer_norm_eps": 1e-12,
16
  "max_position_embeddings": 512,
17
+ "model_type": "bert-classification",
18
  "num_attention_heads": 12,
19
  "num_hidden_layers": 12,
20
  "num_main_segment": 0,
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:18b607e5e17d75afbb4b65e8c03fda9e28d87a1c9259a4c9cc5f291696317003
3
  size 671857125
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fc4263b34d694e1985ba5f30ff508f1d307614e09914270b24136661769254c9
3
  size 671857125