Commit
•
91c61bf
1
Parent(s):
2868a91
Upload model
Browse files- bert_classification_model.py +3 -0
- config.json +1 -1
- pytorch_model.bin +1 -1
bert_classification_model.py
CHANGED
@@ -4,8 +4,11 @@ from typing import Tuple
|
|
4 |
from transformers import PreTrainedModel, BertModel
|
5 |
from torch import nn
|
6 |
|
|
|
|
|
7 |
|
8 |
class BertClassificationModel(PreTrainedModel):
|
|
|
9 |
|
10 |
def __init__(self, config, num_main_segment=None, num_sub_segment=None):
|
11 |
super(BertClassificationModel, self).__init__(config=config)
|
|
|
4 |
from transformers import PreTrainedModel, BertModel
|
5 |
from torch import nn
|
6 |
|
7 |
+
from bert_classification_config import BertClassificationConfig
|
8 |
+
|
9 |
|
10 |
class BertClassificationModel(PreTrainedModel):
|
11 |
+
config_class = BertClassificationConfig
|
12 |
|
13 |
def __init__(self, config, num_main_segment=None, num_sub_segment=None):
|
14 |
super(BertClassificationModel, self).__init__(config=config)
|
config.json
CHANGED
@@ -14,7 +14,7 @@
|
|
14 |
"intermediate_size": 3072,
|
15 |
"layer_norm_eps": 1e-12,
|
16 |
"max_position_embeddings": 512,
|
17 |
-
"model_type": "bert",
|
18 |
"num_attention_heads": 12,
|
19 |
"num_hidden_layers": 12,
|
20 |
"num_main_segment": 0,
|
|
|
14 |
"intermediate_size": 3072,
|
15 |
"layer_norm_eps": 1e-12,
|
16 |
"max_position_embeddings": 512,
|
17 |
+
"model_type": "bert-classification",
|
18 |
"num_attention_heads": 12,
|
19 |
"num_hidden_layers": 12,
|
20 |
"num_main_segment": 0,
|
pytorch_model.bin
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 671857125
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fc4263b34d694e1985ba5f30ff508f1d307614e09914270b24136661769254c9
|
3 |
size 671857125
|